I have two numpy arrays containing integers which I'm comparing with numpy.testing.assert_array_equal. The arrays are "equal enough", i.e. a few elements differ, but given the size of my arrays, that's OK (in this specific case). But of course the test fails:
AssertionError:
Arrays are not equal
(mismatch 0.0010541406645359075%)
x: array([[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.],...
y: array([[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.],
[ 0., 0., 0., ..., 0., 0., 0.],...
----------------------------------------------------------------------
Ran 1 test in 0.658s
FAILED (failures=1)
Of course one might argue that the (long-term) clean solution to this would be to adapt the reference solution or whatnot, but what I'd prefer is to simply allow for some mismatch without the test failing. I would have hoped for assert_array_equal to have an option for this, but this is not the case.
I've written a function which allows me to do exactly what I want, so the problem might be considered solved, but I'm just wondering whether there is a better, more elegant way to do this. Also, the approach of parsing the error string feels pretty hacky, but I haven't found a better way to get the mismatch percentage value.
import re
import numpy as np

def assert_array_equal_tolerant(arr1, arr2, threshold):
    """Compare equality of two arrays while allowing a certain mismatch.

    Arguments:
     - arr1, arr2: Arrays to compare.
     - threshold: Mismatch (in percent) above which the test fails.
    """
    try:
        np.testing.assert_array_equal(arr1, arr2)
    except AssertionError as e:
        # Look for the mismatch percentage in the error message.
        for arg in e.args[0].split("\n"):
            match = re.search(r'mismatch ([0-9.]+)%', arg)
            if match:
                mismatch = float(match.group(1))
                break
        else:
            # No percentage found: re-raise the original error.
            raise
        if mismatch > threshold:
            raise
Just to be clear: I'm not talking about assert_array_almost_equal, and using it is also not feasible, because the errors are not small; they might be huge for a single element, but are confined to a very small number of elements.
If they are integers, you could count the number of elements that are not equal, without regular expressions:
unequal_pos = np.where(arr1 != arr2)
len(unequal_pos[0]) # gives you the number of elements that are not equal.
I don't know if you consider this more elegant.
Since the result of np.where can be used as an index, you can get the elements that do not match with
arr1[unequal_pos]
So you can do pretty much any test you like with that result. It depends on how you want to define the mismatch: by the number of differing elements, by the size of the difference between the elements, or by something even fancier.
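To make those options concrete, here is a small sketch of two possible mismatch definitions built on the np.where result. The example arrays and the variable names (n_unequal, fraction_unequal, max_diff) are made up for illustration; they are not from the question.

```python
import numpy as np

# Hypothetical example data: two integer arrays that differ in two places.
arr1 = np.arange(12).reshape(3, 4)
arr2 = arr1.copy()
arr2[0, 1] = 100
arr2[2, 3] = -7

unequal_pos = np.where(arr1 != arr2)

# Definition 1: number (or fraction) of differing elements.
n_unequal = len(unequal_pos[0])
fraction_unequal = n_unequal / arr1.size

# Definition 2: size of the differences at the mismatching positions.
diffs = np.abs(arr1[unequal_pos] - arr2[unequal_pos])
max_diff = diffs.max() if n_unequal else 0

print(n_unequal, fraction_unequal, max_diff)
```

Which of these you compare against a threshold is up to you; in the question's case (few but possibly huge errors) the first definition is probably the relevant one.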
Here's a crude comparison, but it seems to be in the spirit of what numpy.testing.assert_array_equal does:
In [71]: x=np.arange(100).reshape(10,10)
In [72]: y=np.arange(100).reshape(10,10)
In [73]: y[(5,7),(3,5)]=(3,5)
In [74]: np.sum(np.abs(x-y)>1)
Out[74]: 2
In [80]: np.sum(x!=y)
Out[80]: 2
count_nonzero is a faster counter (because it is used frequently in other numpy code to allocate space):
In [90]: np.count_nonzero(x!=y)
Out[90]: 2
The function that you are using does:
assert_array_compare(operator.__eq__, x, y, err_msg=err_msg)
np.testing.utils.assert_array_compare is a longish function, but most of it has to do with testing shape, and handling nan and inf. Otherwise it comes down to doing
x==y
and doing a count on the number of mismatches, and generating the err_msg. Note that the err_msg can be customized, so parsing it could be simplified.
If you know the shapes match, and you aren't worried about nan-like values, then just filtering the numeric difference should work just fine.
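Putting this together, a tolerant assertion that avoids parsing the error string could look roughly like the sketch below. The function name assert_array_mostly_equal and its message format are hypothetical; the threshold is in percent, as in the question's function. It assumes matching shapes and no special nan handling (nan != nan is True, so every nan counts as a mismatch).

```python
import numpy as np


def assert_array_mostly_equal(arr1, arr2, threshold):
    """Fail only if more than `threshold` percent of the elements differ.

    Hypothetical helper, not part of numpy.testing.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)
    # Shapes must match exactly; no broadcasting is attempted here.
    if arr1.shape != arr2.shape:
        raise AssertionError(f"Shape mismatch: {arr1.shape} vs {arr2.shape}")
    if arr1.size == 0:
        return
    # Count differing elements directly instead of parsing an error message.
    n_unequal = np.count_nonzero(arr1 != arr2)
    mismatch = 100.0 * n_unequal / arr1.size
    if mismatch > threshold:
        raise AssertionError(
            f"Arrays differ in {n_unequal} of {arr1.size} elements "
            f"(mismatch {mismatch:g}%, allowed {threshold:g}%)"
        )


# Same example arrays as in the session above: 2 of 100 elements differ.
x = np.arange(100).reshape(10, 10)
y = x.copy()
y[(5, 7), (3, 5)] = (3, 5)

assert_array_mostly_equal(x, y, threshold=2.0)  # passes: mismatch is not above 2%
```

With threshold=1.0 the same call raises an AssertionError, since 2% of the elements differ.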