Testing functions returning iterables in Python

I'm having difficulty testing Python functions that return an iterable, such as generator functions (functions that yield) or functions that simply return an iterable, like return imap(f, some_iter) or return permutations([1,2,3]).

Take the permutations example: I expect the output of the function to be [(1, 2, 3), (1, 3, 2), ...]. So I start testing my code:

from itertools import permutations

def perm3():
  return permutations([1,2,3])

# Let's ignore the test framework and such details
def test_perm3():
  assertEqual(perm3(), [(1, 2, 3), (1, 3, 2), ...])

This does not work, since perm3() returns an iterable (an itertools.permutations object), not a list, and it does not compare equal to a list. We can fix this particular example:

def test_perm3():
  assertEqual(list(perm3()), [(1, 2, 3), (1, 3, 2), ...])

And this works fine. But what if I have nested iterables, i.e. iterables yielding iterables? Take, say, the expression product(permutations([1, 2]), permutations([3, 4])). This is probably not useful, but it is clear that, once the iterators are unrolled, it will be something like [((1, 2), (3, 4)), ((1, 2), (4, 3)), ...]. However, we cannot just wrap list around the result, as that only turns iterable<blah> into [iterable<blah>, iterable<blah>, ...]. Of course I can do map(list, product(...)), but that only works for a nesting level of 2.
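A minimal sketch that makes the nesting problem concrete, using a hypothetical generator nested() (not from the question) whose items are themselves generators: list() materializes only the outer level, and map(list, ...) handles exactly one more. As an aside, itertools.product and itertools.permutations yield tuples, so for the product(...) example a single list() already gives comparable values; the problem really bites with generators that yield generators.

```python
import types
from itertools import permutations, product

def nested():
    # hypothetical example: a generator whose items are themselves generators
    for start in (1, 3):
        yield (n for n in range(start, start + 2))

# list() only unrolls the outer level: the items are still generator objects
outer = list(nested())
assert all(isinstance(item, types.GeneratorType) for item in outer)

# map(list, ...) unrolls exactly one more level
assert list(map(list, nested())) == [[1, 2], [3, 4]]

# itertools.product already yields tuples, so one list() is enough there
assert list(product(permutations([1, 2]), permutations([3, 4]))) == [
    ((1, 2), (3, 4)), ((1, 2), (4, 3)), ((2, 1), (3, 4)), ((2, 1), (4, 3))
]
```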

So, does the Python testing community have a solution for testing iterables? Naturally, some iterables can't be tested this way (an infinite generator, for instance), but the issue still seems common enough that somebody must have thought about it.
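For the infinite case, one workable compromise is to compare only a finite prefix, taken with itertools.islice. A sketch, where naturals() and first_n() are hypothetical helpers used only for illustration:

```python
from itertools import count, islice

def naturals():
    # hypothetical infinite generator: 0, 1, 2, ...
    yield from count()

def first_n(iterable, n):
    """Materialize at most n items, so even an infinite iterable can be compared."""
    return list(islice(iterable, n))

assert first_n(naturals(), 5) == [0, 1, 2, 3, 4]
```

Finite iterables shorter than n simply yield all their items, so the same helper works for both cases.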

asked Sep 28 '12 16:09 by Tarrasch





3 Answers

I use KennyTM's assertRecursiveEq:

import unittest
import collections.abc
import itertools

class TestCase(unittest.TestCase):
    def assertRecursiveEq(self, first, second, *args, **kwargs):
        """
        https://stackoverflow.com/a/3124155/190597 (KennyTM)
        """
        # Python 3: Iterable lives in collections.abc, izip_longest is zip_longest.
        # Strings are excluded because a 1-char string iterates to itself forever.
        if (isinstance(first, collections.abc.Iterable)
            and isinstance(second, collections.abc.Iterable)
            and not isinstance(first, str)):
            for first_, second_ in itertools.zip_longest(
                    first, second, fillvalue=object()):
                self.assertRecursiveEq(first_, second_, *args, **kwargs)
        else:
            # If first = np.nan and second = np.nan, I want them to
            # compare equal. np.isnan raises TypeErrors on some inputs,
            # so I use `first != first` as a proxy. I avoid dependency on numpy
            # as a bonus.
            if not (first != first and second != second):
                self.assertAlmostEqual(first, second, *args, **kwargs)

def perm3():
    return itertools.permutations([1,2,3])

class Test(TestCase):
    def test_perm3(self):
        self.assertRecursiveEq(perm3(),
            [(1, 2, 3), (1, 3, 2), (2, 1, 3), (2, 3, 1), (3, 1, 2), (3, 2, 1)])

if __name__ == '__main__':
    import sys
    sys.argv.insert(1, '--verbose')
    unittest.main(argv = sys.argv)
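The first != first test above relies on NaN being unequal to itself, and the fallback to assertAlmostEqual (default places=7) accepts floats whose difference rounds to zero at 7 decimal places. A quick standalone check of both facts; the bare unittest.TestCase() instance is only a convenient way to call the assert methods outside a test run:

```python
import math
import unittest

nan = float("nan")

# NaN compares unequal to itself, so `x != x` flags NaN without numpy
assert nan != nan and math.isnan(nan)
assert not (1.5 != 1.5)

tc = unittest.TestCase()  # bare instance, used only for its assert methods

# assertAlmostEqual passes when the difference rounds to zero at
# 7 decimal places, even though plain == is False
assert 0.1 + 0.2 != 0.3
tc.assertAlmostEqual(0.1 + 0.2, 0.3)

# ...but NaN vs NaN fails it, which is why the answer skips NaN pairs
tc.assertRaises(AssertionError, tc.assertAlmostEqual, nan, nan)
```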
answered Sep 23 '22 01:09 by unutbu



1. If the order of results doesn't matter

Use unittest.TestCase.assertItemsEqual() (in Python 3 it is called assertCountEqual()). It checks that both arguments contain the same items, each the same number of times, regardless of order. This works on your one-level-deep example. It also works on a 2-deep example that I concocted.
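A minimal sketch of the order-insensitive check in Python 3 spelling (assertCountEqual), reusing perm3 from the question; the class name TestPerm3 is just for illustration, and the reference list is deliberately shuffled:

```python
import unittest
from itertools import permutations

def perm3():
    return permutations([1, 2, 3])

class TestPerm3(unittest.TestCase):
    def test_perm3_any_order(self):
        # assertCountEqual turns both arguments into lists first, so an
        # iterator is fine; the reference order differs from permutations()
        self.assertCountEqual(perm3(), [(3, 2, 1), (1, 2, 3), (2, 3, 1),
                                        (1, 3, 2), (3, 1, 2), (2, 1, 3)])
```

Saved in a module, it runs with python -m unittest like any other test case.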

2. If the order of results matters

I suggest never converting the results of perm3() to a list. Instead, compare the elements directly as you iterate. Here's an assertion method that works for your example; I added it to a subclass of unittest.TestCase:

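from itertools import zip_longest

def assertEqualIterables(self, itable1, itable2):
    missing = object()  # pads the shorter iterable so length mismatches fail
    for ival1, ival2 in zip_longest(itable1, itable2, fillvalue=missing):
        if ival1 is missing or ival2 is missing:
            self.fail("iterables have different lengths")
        # recurse into nested iterables; str is skipped to avoid infinite recursion
        if hasattr(ival1, "__iter__") and not isinstance(ival1, str):
            self.assertEqualIterables(ival1, ival2)
        else:
            self.assertEqual(ival1, ival2)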

Use it like:

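from itertools import permutations, product

def test_nested_product(self):
    # reference for the nested product example from the question
    reference = [((1, 2), (3, 4)), ((1, 2), (4, 3)),
                 ((2, 1), (3, 4)), ((2, 1), (4, 3))]

    self.assertEqualIterables(
        product(permutations([1, 2]), permutations([3, 4])), reference)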
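One caveat when pairing elements with the built-in zip: it stops at the shorter input, so a result that is missing trailing items can still pass an element-by-element check. itertools.zip_longest pads the shorter side instead, which exposes the mismatch. A small demonstration with made-up lists:

```python
from itertools import zip_longest

got = [1, 2]          # e.g. a result that lost its last item
expected = [1, 2, 3]

# zip silently drops the unmatched 3, so every pair looks equal
assert all(a == b for a, b in zip(got, expected))

# zip_longest pads with a sentinel, so the extra item shows up as a mismatch
missing = object()
pairs = list(zip_longest(got, expected, fillvalue=missing))
assert pairs[-1] == (missing, 3)
assert not all(a == b for a, b in pairs)
```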
answered Sep 22 '22 01:09 by dbn



You could extend your suggestion to record each item's type as well (which lets you distinguish between lists, tuples, etc.), like so:

def unroll(item):
  # str is skipped because a 1-char string iterates to itself forever
  if hasattr(item, "__iter__") and not isinstance(item, str):
    return list(map(unroll, item)), type(item)  # list(): map is lazy in Python 3
  else:
    return item, type(item)

For example:

from itertools import permutations

got = unroll(permutations([1,2]))
# got is:
# ([([(1, <class 'int'>), (2, <class 'int'>)], <class 'tuple'>), ([(2, <class 'int'>), (1, <class 'int'>)], <class 'tuple'>)], <class 'itertools.permutations'>)
# note the final: <class 'itertools.permutations'>
expected = [(1, 2), (2, 1)]
assertEqual(got[0], unroll(expected)[0])  # check underlying
assertEqual(got[1], type(permutations([])))  # check type


One caveat: type is a coarse way to distinguish objects; for example, in Python 2 every old-style class has type <type 'classobj'>.

answered Sep 24 '22 01:09 by Andy Hayden
