How to mock random.choice in python?

I want choice() to return the same value, 1000, every time in my unit test, but the following code doesn't work:

import unittest
from random import choice

from mock import mock

def a():
    return choice([1, 2, 3])

class mockobj(object):
    @classmethod
    def choice(cls, li):
        return 1000

class testMock(unittest.TestCase):

    def test1(self):
        with mock.patch('random.choice', mockobj.choice):
            self.assertEqual(a(), 1000)

The error message is as follows:

Failure
Traceback (most recent call last):
  File "test.py", line 15, in test1
    self.assertEqual(a(), 1000)
AssertionError: 3 != 1000

How should I modify it to make it work? I'm using Python 2.7.

Searene asked Oct 08 '16 08:10


3 Answers

I'd like to expand on @Alex's answer with a full script, which makes it easier to understand and to adapt to other cases.

import random
from unittest import TestCase, mock

letters = ['A', 'B', 'C', 'D']

def get_random_words():  # Simple function using choice
    l = []
    for _ in range(3):
        l.append(random.choice(letters))

    return "".join(l)

class TestRandom(TestCase):

    @mock.patch('random.choice')  # *(1)
    def test_get_random_words(self, mock_choice):

        mock_choice.side_effect = ['A', 'b', 'D', 'Z']  # *(2)
        result = get_random_words()

        self.assertEqual(result, 'AbD', 'Does not generate correct string')

Considerations

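*(1) For this example, the function is in the same file as the test. If it lives in another module, patch the name where that module looks it up. Because get_random_words calls random.choice through the random module, @mock.patch('random.choice') still works. If your_file instead did from random import choice, you would need @mock.patch('your_package.your_file.choice').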

*(2) In this case, get_random_words calls random.choice three times, so mock_choice.side_effect must contain at least three items. With fewer items, the mock raises StopIteration once the list is exhausted.
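To see note *(2) in action, here is a minimal sketch, assuming Python 3's unittest.mock. The helper run_with_side_effect is hypothetical, added only for illustration: it patches random.choice with side_effect lists of different lengths and reports what happens.

```python
import random
from unittest import mock

letters = ['A', 'B', 'C', 'D']


def get_random_words():
    chars = []
    for _ in range(3):
        chars.append(random.choice(letters))  # three calls to random.choice
    return "".join(chars)


def run_with_side_effect(values):
    """Call get_random_words() while random.choice returns `values` in order."""
    with mock.patch('random.choice') as mock_choice:
        mock_choice.side_effect = values
        try:
            return get_random_words()
        except StopIteration:
            # The mock raises StopIteration once the side_effect list runs out
            return 'StopIteration'


print(run_with_side_effect(['A', 'b', 'D', 'Z']))  # AbD (the extra 'Z' is unused)
print(run_with_side_effect(['A', 'b']))            # StopIteration
```

An explicit loop is used on purpose: if the calls were made inside a generator expression, Python 3.7+ would turn the StopIteration into a RuntimeError.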

Alex Montoya answered Oct 16 '22 04:10


The problem here is that a() is using an unpatched version of random.choice.

Compare functions a and b:

import random
from random import choice

def a():
    return choice([1, 2, 3])

def b():
    return random.choice([1, 2, 3])

def choice1000(values):
    return 1000

import unittest.mock as mock

with mock.patch('random.choice', choice1000):
    print('a', a())
    print('b', b())

It prints e.g.:

a 3
b 1000

Why?

This line is the problem:

from random import choice

It imports random and then binds random.choice to a new name, choice, in the current module.

Later, mock.patch replaces the choice attribute on the random module, but the local name choice still refers to the original function.
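A small check of the two bindings, assuming Python 3's unittest.mock (the helper name_bindings is hypothetical). It records whether choice and random.choice are the same object before, during and after the patch.

```python
import random
from random import choice
from unittest import mock


def name_bindings():
    """Report which names point at the mock while random.choice is patched."""
    before = choice is random.choice            # both names share one function object
    with mock.patch('random.choice') as fake:
        during = (choice is fake, random.choice is fake)
    after = choice is random.choice             # the patch restored the module attribute
    return before, during, after


print(name_bindings())  # (True, (False, True), True)
```

While the patch is active, only the attribute on the random module points at the mock; the local name choice keeps the original function.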

Can I patch the local one? Yes:

with mock.patch('__main__.choice', choice1000):
    print('a', a())
    print('b', b())

Now it prints e.g.

a 1000
b 1

(I used '__main__' because I put this code in the top-level script; in your case, use the name of the module that does from random import choice.)

So what to do?

Either patch every name that refers to the function (random.choice and the local choice), or take a different approach, for example patching a() instead of choice().
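A sketch of the "patch every name" option, assuming Python 3. So that it does not depend on how the file is run, a() and b() are placed in a hypothetical module named game, built in memory with types.ModuleType; call_both is likewise a hypothetical helper. Patching only random.choice changes b() but not a(); patching game.choice as well changes both.

```python
import contextlib
import sys
import types
from unittest import mock

# Hypothetical module "game" holding functions a and b from this answer.
GAME_SOURCE = """
import random
from random import choice

def a():
    return choice([1, 2, 3])          # uses the name bound at import time

def b():
    return random.choice([1, 2, 3])   # looks up random.choice at call time
"""
game = types.ModuleType("game")
exec(GAME_SOURCE, game.__dict__)
sys.modules["game"] = game  # lets mock.patch resolve targets such as "game.choice"


def choice1000(values):
    return 1000


def call_both(*targets):
    """Call game.a() and game.b() while every dotted target in `targets` is patched."""
    with contextlib.ExitStack() as stack:
        for target in targets:
            stack.enter_context(mock.patch(target, choice1000))
        return game.a(), game.b()


print(call_both("random.choice"))                 # a() unaffected, b() returns 1000
print(call_both("random.choice", "game.choice"))  # (1000, 1000)
```

In a real test suite you would stack two mock.patch decorators or context managers instead of ExitStack; the helper just makes it easy to compare one target against two.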

Alternative Solution

In this case, where you want to test the behaviour of random functions, it may be better to use a seed:

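import random
import unittest

def a():
    return random.choice([1, 2, 3, 1000])

class TestA(unittest.TestCase):
    def test1(self):
        random.seed(0)  # same seed, same sequence of "random" values
        self.assertEqual(a(), 1000)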

You can't know beforehand which values a given seed will produce, but you can be sure that the same seed always produces the same values (on the same Python version), which is exactly what you need in tests.

In the last example above, I tested a() after random.seed(0) once and it returned 1000, so I can be sure it will do so every time:

>>> import random
>>> random.seed(0)
>>> print (random.choice([1, 2, 3, 1000]))
1000
>>> random.seed(0)
>>> print (random.choice([1, 2, 3, 1000]))
1000
>>> random.seed(0)
>>> print (random.choice([1, 2, 3, 1000]))
1000
>>> random.seed(0)
>>> print (random.choice([1, 2, 3, 1000]))
1000
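A variation on the seed approach, not part of the original answer: instead of seeding the global generator, a() accepts an optional generator. The rng parameter and the first_picks helper are hypothetical. A dedicated random.Random(seed) gives a reproducible sequence without touching the global random state. Because the exact values may differ between Python versions, the sketch compares two runs rather than hard-coding 1000.

```python
import random

VALUES = [1, 2, 3, 1000]


def a(rng=random):
    # rng may be the random module itself or any random.Random instance
    return rng.choice(VALUES)


def first_picks(seed, n=5):
    """Return the first n results of a() using a generator seeded with `seed`."""
    rng = random.Random(seed)  # private generator; global random state is untouched
    return [a(rng) for _ in range(n)]


print(first_picks(0) == first_picks(0))  # True: same seed, same sequence
```

Production code keeps calling a() with no argument, while tests pass in a seeded generator.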
zvone answered Oct 16 '22 04:10


I'm not sure what mockobj is in your tests, but here is what you can do:

    @mock.patch('random.choice')
    def test1(self, choice_mock):
        choice_mock.return_value = 1000
        self.assertEqual(a(), 1000)
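For reference, a self-contained version of this test, as a sketch assuming Python 3's unittest.mock (with Python 2.7 you would use the mock backport package instead). It only works because this a() calls random.choice through the random module; with from random import choice, as in the question, a() would keep the original function, as explained above.

```python
import random
import unittest
from unittest import mock


def a():
    # Looks up choice on the random module at call time, so the patch applies
    return random.choice([1, 2, 3])


class TestMock(unittest.TestCase):
    @mock.patch('random.choice')
    def test1(self, choice_mock):
        choice_mock.return_value = 1000
        self.assertEqual(a(), 1000)
        choice_mock.assert_called_once_with([1, 2, 3])


if __name__ == '__main__':
    unittest.main()
```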
Alex answered Oct 16 '22 06:10