
Compare (assert equality of) two complex data structures containing numpy arrays in unittest

I use Python's unittest module and want to check if two complex data structures are equal. The objects can be lists of dicts with all sorts of values: numbers, strings, Python containers (lists/tuples/dicts) and numpy arrays. The latter are the reason for asking the question, because I cannot just do

self.assertEqual(big_struct1, big_struct2) 

because it produces a

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() 
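A minimal reproduction outside unittest may help show where the error comes from. The data is made up for illustration, and nested_eq_error is a hypothetical name. The container comparison ends up asking for bool(array == array), which numpy refuses for arrays with more than one element. The last lines show the identity shortcut that the answers below mention: the same array object on both sides never reaches numpy's ==.

```python
import numpy as np


def nested_eq_error():
    """Return the message of the error raised by == on containers holding arrays."""
    left = [{'x': np.array([1, 2])}, 3]
    right = [{'x': np.array([1, 2])}, 3]  # same values, distinct array objects
    try:
        left == right
    except ValueError as exc:
        # dict comparison calls bool() on the elementwise result array
        return str(exc)
    return None


print(nested_eq_error())

# The same array *object* on both sides passes, because container comparison
# checks identity ("is") before falling back to ==.
shared = np.array([1, 2])
print([shared] == [shared])  # True
```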

I imagine that I need to write my own equality test for this. It should work for arbitrary structures. My current idea is a recursive function that:

  • tries direct comparison of the current "node" of arg1 to the corresponding node of arg2;
  • if no exception is raised, moves on ("terminal" nodes/leaves are processed here, too);
  • if ValueError is caught, goes deeper until it finds a numpy.array;
  • compares the arrays (e.g. like this).

What seems a little problematic is keeping track of "corresponding" nodes of two structures, but perhaps zip is all I need here.
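A minimal sketch of the recursive idea above, assuming the only containers to walk are dicts, lists and tuples (structures_equal is a hypothetical name). It follows the "try == first, go deeper on ValueError" plan, with one deliberate deviation: arrays are detected with isinstance instead of waiting for the ValueError, because a one-element array, for instance, never raises it. zip is indeed enough for pairing corresponding nodes, as long as the lengths are checked first.

```python
import numpy as np


def structures_equal(a, b):
    """Return True if two nested structures are equal, numpy arrays included."""
    # An array only equals another array with the same shape and values.
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        return (isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
                and bool(np.array_equal(a, b)))
    try:
        # Leaves (numbers, strings) and array-free containers end here.
        return bool(a == b)
    except ValueError:
        # Some array is nested inside: walk corresponding nodes in parallel.
        if isinstance(a, dict) and isinstance(b, dict):
            return (a.keys() == b.keys()
                    and all(structures_equal(a[k], b[k]) for k in a))
        if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            # zip() silently stops at the shorter input, so compare lengths.
            return (type(a) is type(b) and len(a) == len(b)
                    and all(structures_equal(x, y) for x, y in zip(a, b)))
        raise  # unknown container type: let the caller see the error
```

In a test it would be used as self.assertTrue(structures_equal(big_struct1, big_struct2)). The trade-off is that a failure only reports False, not where the two structures differ.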

The question is: are there good (simpler) alternatives to this approach? Maybe numpy presents some tools for this? If no alternatives are suggested, I will implement this idea (unless I have a better one) and post it as an answer.

P.S. I have a vague feeling that I might have seen a question addressing this problem, but I can't find it now.

P.P.S. An alternative approach would be a function that traverses the structure and converts all numpy.arrays to lists, but is this any easier to implement? Seems the same to me.
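A sketch of the P.P.S. alternative, again assuming that only dicts, lists and tuples need to be traversed (arrays_to_lists is a hypothetical name). It relies on ndarray.tolist(), which returns nested Python lists of Python scalars.

```python
import numpy as np


def arrays_to_lists(obj):
    """Return a copy of a nested structure with every ndarray turned into a list.

    Only dicts, lists and tuples are traversed; other values are returned as is.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()  # nested Python lists of Python scalars
    if isinstance(obj, dict):
        return {key: arrays_to_lists(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [arrays_to_lists(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(arrays_to_lists(item) for item in obj)
    return obj
```

Both sides would be converted before the assertion, e.g. self.assertEqual(arrays_to_lists(big_struct1), arrays_to_lists(big_struct2)), which also keeps unittest's normal diff output on failure. The price is that an array and a list with the same contents become indistinguishable, dtype is lost, and NaN values still compare unequal.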


Edit: Subclassing numpy.ndarray sounds very promising, but obviously I don't have both sides of the comparison hard-coded into a test. One of them, though, is indeed hardcoded, so I can:

  • populate it with custom subclasses of numpy.array;
  • change isinstance(other, SaneEqualityArray) to isinstance(other, np.ndarray) in jterrace's answer;
  • always use it as LHS in comparisons.
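A sketch of the modified subclass described in this edit; the sane helper is hypothetical, and __ne__ is added because otherwise != (used, for example, by assertNotEqual) would still be numpy's elementwise version. On the LHS question: for comparisons, Python tries the right operand's method first when its type is a subclass of the left operand's type, so against a plain ndarray the subclass's __eq__ runs whichever side it is on. When the other side is a list or a Python scalar, its __eq__ returns NotImplemented and Python falls back to the subclass's method anyway.

```python
import numpy as np


class SaneEqualityArray(np.ndarray):
    """ndarray subclass whose == and != return a single truth value."""

    def __eq__(self, other):
        # Accept any ndarray on the other side, not only this subclass.
        return (isinstance(other, np.ndarray)
                and self.shape == other.shape
                and bool(np.ndarray.__eq__(self, other).all()))

    def __ne__(self, other):
        # Otherwise != would still be numpy's elementwise comparison.
        return not self.__eq__(other)


def sane(values):
    """Hypothetical helper: wrap values for the hard-coded expected side."""
    return np.asarray(values).view(SaneEqualityArray)


expected = [{'foo': sane([3, 4]), 'd': {'doo': 3}}, sane([5, 6]), 34]
actual = [{'foo': np.array([3, 4]), 'd': {'doo': 3}}, np.array([5, 6]), 34]
print(expected == actual)  # True
print(actual == expected)  # True: the subclass on the right still wins
```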

My questions in this regard are:

  1. Will it work (I mean, it sounds all right to me, but maybe some tricky edge cases will not be handled correctly)? Will my custom object always end up as LHS in the recursive equality checks, as I expect?
  2. Again, are there better ways (given that at least one of the structures I get contains real numpy arrays)?

Edit 2: I tried it out; the (seemingly) working implementation is shown in this answer.

Lev Levitsky asked Jan 09 '13 21:01




2 Answers

Would have commented, but it gets too long...

Fun fact: you cannot use == to test whether arrays are the same. I would suggest using np.testing.assert_array_equal instead:

  1. it checks shape, etc. (dtype is only compared if you pass strict=True, which newer NumPy versions support);
  2. it doesn't fail on NaNs in the same positions, despite the neat little math of (float('nan') == float('nan')) == False (normal Python sequence == has an even more fun way of sometimes ignoring this, because it uses PyObject_RichCompareBool, which does a quick identity ("is") check first; that is incorrect for NaNs, but for testing it is of course perfect);
  3. there is also assert_allclose, because floating point equality can get very tricky if you do actual calculations: you usually want almost the same values, since the values can become hardware dependent or possibly random depending on what you do with them.
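A short illustration of these points, assuming a reasonably recent NumPy; raises_assertion is a hypothetical helper used only for the demo. Note that by default dtype is not compared: the int and float arrays below pass.

```python
import numpy as np


def raises_assertion(func, *args, **kwargs):
    """Hypothetical demo helper: True if the call raises AssertionError."""
    try:
        func(*args, **kwargs)
    except AssertionError:
        return True
    return False


with_nan = np.array([1.0, float('nan')])

# Plain == is elementwise, and NaN never equals itself.
print((with_nan == with_nan.copy()).all())  # False

# assert_array_equal accepts NaNs in the same positions.
np.testing.assert_array_equal(with_nan, with_nan.copy())

# Shape mismatches are reported as assertion failures.
print(raises_assertion(np.testing.assert_array_equal,
                       np.array([1, 2]), np.array([1, 2, 3])))  # True

# dtype is not compared by default: int and float arrays with equal values pass.
np.testing.assert_array_equal(np.array([1, 2]), np.array([1.0, 2.0]))

# Computed floats usually need a tolerance: 0.1 + 0.2 is not exactly 0.3.
computed = np.array([0.1 + 0.2])
print(raises_assertion(np.testing.assert_array_equal, computed, [0.3]))  # True
np.testing.assert_allclose(computed, [0.3])  # passes with the default rtol
```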

I would almost suggest serializing it with pickle if you want to compare something this insanely nested, but that is overly strict (and point 3 is of course completely broken then): for example, the memory layout of your array does not matter for equality, but it does matter for its serialization.
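A quick illustration of why the pickle comparison is overly strict; pickled_equal is a hypothetical helper, not a recommended test utility. The two arrays hold the same values in different memory layouts (C order vs Fortran order), so they are equal but serialize differently.

```python
import pickle

import numpy as np


def pickled_equal(a, b):
    """Hypothetical helper: compare two objects by their pickled bytes."""
    return pickle.dumps(a) == pickle.dumps(b)


c_order = np.array([[1, 2], [3, 4]])
f_order = np.asfortranarray(c_order)  # same values, column-major layout

print(np.array_equal(c_order, f_order))  # True
print(pickled_equal(c_order, f_order))   # False
```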

seberg answered Sep 26 '22 01:09


The assertEqual function will invoke the __eq__ method of objects, which recurses into built-in containers such as lists and dicts. The exception is numpy, whose ndarray.__eq__ returns an elementwise boolean array instead of a single True/False. Using a numpy subclass from this question, you can restore sanity to the equality behavior:

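import copy
import numpy
import unittest


class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

    def __ne__(self, other):
        # Without this, != would still be numpy's elementwise comparison,
        # which breaks unittest's failure report for lists
        return not self.__eq__(other)


class TestAsserts(unittest.TestCase):

    def testAssert(self):
        # SaneEqualityArray([1, 2]) would call the ndarray constructor, which
        # allocates an *uninitialized* array of shape (1, 2); viewing a real
        # array as the subclass gives the intended values instead
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            numpy.array([1, 2]).view(SaneEqualityArray),
            {'foo': {'hey': numpy.array([2, 3]).view(SaneEqualityArray)}},
            [{'foo': numpy.array([3, 4]).view(SaneEqualityArray),
              'd': {'doo': 3}},
             numpy.array([5, 6]).view(SaneEqualityArray), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))


if __name__ == '__main__':
    unittest.main()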

This test passes.

jterrace answered Sep 25 '22 01:09