I have a small class for which I am writing tests.
The result of my function contains nan, which causes my test to fail because nan is not equal to any other nan. How do I write a proper test for this?
from math import nan
import unittest

import pandas as pd


class holder:
    def __init__(self, source: list) -> None:
        self.source = source
        self.data = []

    def calculate_rolling_average(self):
        self.data = pd.Series(self.source).rolling(3).mean().to_list()


class test_value(unittest.TestCase):
    def setUp(self):
        self.hold = holder(source=[10, 20, 30, 40])

    def test_rolling_simple_average_for_list(self):
        expected_result = [nan, nan, 20.0, 30.0]
        self.hold.calculate_rolling_average()
        self.assertSequenceEqual(self.hold.data, expected_result)


unittest.main()
This results in:
>>> AssertionError: Sequences differ: [nan, nan, 20.0, 30.0] != [nan, nan, 20.0, 30.0]
First differing element 0:
nan
nan
[nan, nan, 20.0, 30.0]
I read "How to assert if the two possibly NaN values are equal", but that question tests a single numpy.nan value, so it does not apply here.
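The specific problem here seems to be caused by pd.Series.to_list returning new float("nan") objects, not math.nan; these values all represent NaN but are not identical to one another, hence the list comparison fails. List comparison checks each pair of elements for identity before equality, so [math.nan] == [math.nan] passes, whereas two distinct NaN objects are neither identical nor equal. An actual MRE of this question would be: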
import math
import unittest


class TestValue(unittest.TestCase):
    def test_nan_in_list(self):
        self.assertEqual([float("nan")], [math.nan])


if __name__ == "__main__":
    unittest.main()
AssertionError: Lists differ: [nan] != [nan]
First differing element 0:
nan
nan
[nan]
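The identity point can be checked directly. The following is a small sketch using only the standard library; the variable names are illustrative and not part of the question's code.

```python
import math

# Both lists hold the very same object (math.nan), so the identity check succeeds
same_object = [math.nan] == [math.nan]

# float("nan") builds a new object, so identity fails and NaN == NaN is false
distinct_objects = [float("nan")] == [math.nan]

print(same_object)           # True
print(distinct_objects)      # False
print(math.nan == math.nan)  # False: a bare NaN never equals itself
```

This is why a test that happens to reuse math.nan on both sides can pass while one comparing against freshly created NaN values fails.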
Note: you can generally call self.assertEqual and have it dispatch to the appropriate type-specific implementation (for example, assertListEqual when both arguments are lists).
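As a quick illustration of that dispatch, the sketch below calls assertEqual on two lists outside of a test run and prints the first line of the failure message. Instantiating unittest.TestCase directly is done here only for demonstration.

```python
import unittest

# A bare TestCase instance is enough to call the assertion helpers
case = unittest.TestCase()

try:
    case.assertEqual([1, 2], [1, 3])
except AssertionError as exc:
    first_line = str(exc).splitlines()[0]

# The list-specific wording shows that the list comparison was used
print(first_line)  # Lists differ: [1, 2] != [1, 3]
```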
However, it's easy enough in Python to write a simple matcher:
import math
import typing as tp


class NaN:
    def __eq__(self, other: tp.Any) -> bool:
        return math.isnan(other)

    def __repr__(self) -> str:
        return "nan"
Then you can use that in the expected value; both math.nan == NaN() and any arbitrary float("nan") == NaN() evaluate to True:
class TestValue(unittest.TestCase):
    def test_nan_in_list(self):
        self.assertEqual([float("nan")], [NaN()])
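Applied to the data from the question, the expected list would hold NaN() wherever nan appeared. The sketch below avoids pandas so that it runs with the standard library alone: rolling_mean is a hypothetical stand-in for pd.Series(...).rolling(3).mean().to_list(), and the isinstance guard is an added assumption so that comparing the matcher with a non-numeric value returns False instead of raising TypeError.

```python
import math
import typing as tp
import unittest


class NaN:
    """Matcher that compares equal to any float NaN."""

    def __eq__(self, other: tp.Any) -> bool:
        # Assumption: guard added because math.isnan raises TypeError for non-numbers
        return isinstance(other, float) and math.isnan(other)

    def __repr__(self) -> str:
        return "nan"


def rolling_mean(values: list, window: int) -> list:
    """Hypothetical stand-in for the pandas call: NaN until the window is full."""
    result = []
    for i in range(len(values)):
        if i + 1 < window:
            result.append(float("nan"))
        else:
            result.append(sum(values[i + 1 - window : i + 1]) / window)
    return result


class TestRollingAverage(unittest.TestCase):
    def test_rolling_average(self):
        actual = rolling_mean([10, 20, 30, 40], 3)
        self.assertEqual(actual, [NaN(), NaN(), 20.0, 30.0])
```

The comparison works because float("nan") == NaN() first asks the float, which does not know the NaN type, and then falls back to the matcher's own __eq__.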
As well as passing when the code works, this offers useful diagnostics on failure:
AssertionError: Lists differ: [nan, nan, 20.0, 30.0] != [nan, nan, 20.0, 35.0]
First differing element 3:
30.0
35.0
- [nan, nan, 20.0, 30.0]
?                   ^
+ [nan, nan, 20.0, 35.0]
?                   ^
This is considerably better than flattening to a single boolean, as this answer suggests, which loses all diagnostics entirely:
AssertionError: False is not true
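For comparison, here is a sketch of what a boolean-flattening check might look like. The referenced answer's exact code is not reproduced here, so the same helper and the all(...) expression are assumptions chosen only to show how the message arises.

```python
import math
import unittest

actual = [float("nan"), float("nan"), 20.0, 30.0]
expected = [math.nan, math.nan, 20.0, 35.0]


def same(a: float, b: float) -> bool:
    # Hypothetical helper: treat two NaNs as equal, otherwise fall back to ==
    return (math.isnan(a) and math.isnan(b)) or a == b


case = unittest.TestCase()
try:
    # The whole comparison collapses to one bool before unittest ever sees it
    case.assertTrue(all(same(a, b) for a, b in zip(actual, expected)))
except AssertionError as exc:
    message = str(exc)

print(message)  # False is not true
```

Nothing in the message says which element differed or what the two lists contained, which is the information the matcher-based assertion keeps.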