Hello!
In this blog post I will talk about patching Python objects when using unittest.mock.patch.
I don't know about you, but I'm always confused about which path is the right one to patch, and it used to take me some trial and error to find it. My hope is that if I come across this again, a search engine will take me to my own blog post so that I don't have to waste any more time!
So, let's start with an example. Let's imagine that we need a class to represent 2D points that are created randomly on the screen. We have this lovely file called point.py:
from dataclasses import dataclass
from random import randint

MAX_X = 100
MAX_Y = 100


@dataclass
class Point:
    x: float
    y: float


def create_random_point():
    return Point(
        x=randint(0, MAX_X),
        y=randint(0, MAX_Y)
    )
Great! So the next step is to write tests. Let's write this in test_point.py:
from point import create_random_point, MAX_X, MAX_Y


def test_create_point():
    point = create_random_point()
    # randint includes both endpoints, so the upper bound is inclusive
    assert 0 <= point.x <= MAX_X
    assert 0 <= point.y <= MAX_Y
All good so far. If we then run pytest in the directory where the code is, we will see that the test passes.
Now, let's say that instead of checking that x and y are within a certain range, we want our test to compare against exact values of x and y. In this example it's less obvious why we would want that, but in real life we often want to remove any randomness from our tests.
In order to achieve that, we need to mock the randint function so that it always returns the same value. We decided to use the patch decorator from the mock library (available in the standard library as unittest.mock) to achieve it.
Since we import the function like this:
from random import randint
it can be tempting to try to use patch like this:
from unittest.mock import patch

from point import create_random_point


@patch('random.randint', return_value=59)
def test_create_point(_):
    point = create_random_point()
    assert point.x == 59
    assert point.y == 59
However, if you try this, you will see that it doesn't work: the mock is never called, and the assertions fail because x and y are still random.
So the question is, of course: why doesn't it work?
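Actually, the documentation of Python's unittest.mock library provides us with an answer, but it might be a bit hard to understand at first. Its "Where to patch" section says that the basic principle is to patch where an object is looked up, which is not necessarily the same place as where it is defined.
So what does it mean, where the object is looked up?
In our simple example, we are looking up, or using, the function randint from the file point.py, that is, from the module point. The statement from random import randint creates a separate name randint inside point that refers to the same function. What the documentation is saying is that the name patch has to replace is the one in the module point, and not the one where the function is defined, which is in the module random.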
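To see the two names concretely, here is a minimal sketch. So that it runs on its own, it builds a tiny stand-in for point.py in memory; the module name point_demo and the function roll are made up for this illustration and are not part of the example above.

```python
import random
import sys
import types
from unittest.mock import patch

# In-memory stand-in for point.py ("point_demo" and roll() are hypothetical).
source = """
from random import randint

def roll():
    return randint(0, 100)
"""
point_demo = types.ModuleType("point_demo")
sys.modules["point_demo"] = point_demo  # lets patch() find the module by name
exec(source, point_demo.__dict__)

# Two names, one function object: random.randint and point_demo.randint.
same_object_before = point_demo.randint is random.randint

# Patching the name in random leaves the name in point_demo untouched.
with patch("random.randint", return_value=59):
    still_original = point_demo.randint is not random.randint
    unpatched_result = point_demo.roll()  # still a real random number

# Patching the name in point_demo is what roll() actually sees.
with patch("point_demo.randint", return_value=59):
    patched_result = point_demo.roll()

print(same_object_before, still_original, patched_result)
```

The first patch swaps the attribute on the random module only, so roll() keeps calling the original function through its own name.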
Let's try to change the code so that we take this into account:
@patch('point.randint', return_value=59)
def test_create_point(_):
    point = create_random_point()
    assert point.x == 59
    assert point.y == 59
And voilà, if we now run pytest again, the test passes!
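As a side note, the argument that the decorator passes into the test is the mock itself, so instead of ignoring it as _ we can give it a name and check how it was called. This is only a sketch: the module name point_calls and the function check_create_point are hypothetical, and the module is built in memory so the snippet is self-contained.

```python
import sys
import types
from unittest.mock import patch

# In-memory copy of point.py (the module name "point_calls" is hypothetical).
source = """
from dataclasses import dataclass
from random import randint

MAX_X = 100
MAX_Y = 100

@dataclass
class Point:
    x: float
    y: float

def create_random_point():
    return Point(x=randint(0, MAX_X), y=randint(0, MAX_Y))
"""
point_calls = types.ModuleType("point_calls")
sys.modules["point_calls"] = point_calls  # register before exec and patch
exec(source, point_calls.__dict__)


@patch("point_calls.randint", return_value=59)
def check_create_point(mock_randint):
    point = point_calls.create_random_point()
    assert (point.x, point.y) == (59, 59)
    # The mock records every call it receives.
    assert mock_randint.call_count == 2
    mock_randint.assert_any_call(0, point_calls.MAX_X)
    mock_randint.assert_any_call(0, point_calls.MAX_Y)
    return mock_randint.call_args_list


calls = check_create_point()
```

In a real test file the function would be named test_... and would import from point as before; the in-memory module is only there to keep the sketch runnable.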
The mnemonic I use to remember this is:
The object should be patched using the path it is imported TO, not the path it is imported FROM.
Now, let's see what happens if we import randint a little bit differently.
Say we now import the whole random module instead:
from dataclasses import dataclass
import random

MAX_X = 100
MAX_Y = 100


@dataclass
class Point:
    x: float
    y: float


def create_random_point():
    return Point(
        x=random.randint(0, MAX_X),
        y=random.randint(0, MAX_Y)
    )
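The test will stop working again! This time patch itself raises an AttributeError, because the module point no longer has a name randint.
In this case, the name randint is being looked up in the module random again, since we access it as random.randint. We need to change the test correspondingly: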
@patch('random.randint', return_value=59)
def test_create_point(_):
    point = create_random_point()
    assert point.x == 59
    assert point.y == 59
And everything works again!
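Here is a small self-contained sketch of the same situation. Again the module is created in memory, and the names point_module_demo and roll are invented for the illustration. It tries the old target, then two targets that both reach the attribute roll() really looks up.

```python
import sys
import types
from unittest.mock import patch

# In-memory stand-in for the "import random" version of point.py
# ("point_module_demo" and roll() are hypothetical names).
source = """
import random

def roll():
    return random.randint(0, 100)
"""
point_mod = types.ModuleType("point_module_demo")
sys.modules["point_module_demo"] = point_mod
exec(source, point_mod.__dict__)

# The module has no name randint any more, so the old target is rejected.
try:
    with patch("point_module_demo.randint", return_value=59):
        pass
    old_target_error = None
except AttributeError as exc:
    old_target_error = exc

# roll() looks randint up on the random module at call time.
with patch("random.randint", return_value=59):
    via_random = point_mod.roll()

# The same attribute, reached through the module's own reference to random.
with patch("point_module_demo.random.randint", return_value=59):
    via_point = point_mod.roll()

print(type(old_target_error).__name__, via_random, via_point)
```

Both working targets replace the attribute randint on the one shared random module object, which is why either spelling does the job here.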
BONUS TRACK
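If you want to test for different values of x and y, you can do it by passing an iterable to the side_effect parameter of patch. With the import random version of point.py, the target 'point.random.randint' reaches the same attribute as 'random.randint':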
@patch('point.random.randint', side_effect=[59, 63])
def test_create_point(_):
    point = create_random_point()
    assert point.x == 59
    assert point.y == 63
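To make the behaviour of side_effect a bit more tangible, here is a tiny sketch with a bare Mock, without patch at all. The name fake_randint is just for illustration.

```python
from unittest.mock import Mock

# Each call returns the next item of the iterable.
fake_randint = Mock(side_effect=[59, 63])

first = fake_randint(0, 100)
second = fake_randint(0, 100)

# Once the values run out, a further call raises StopIteration.
try:
    fake_randint(0, 100)
    exhausted = False
except StopIteration:
    exhausted = True

print(first, second, exhausted)
```

In the test above, x gets the first value because Python evaluates the keyword arguments of Point(...) from left to right, so the call for x happens before the call for y.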
That's it for today! Thanks for reading and I hope this was useful.