python, unit-testing, pytest

Using logical operators in pytest expected results


I'm trying to develop pytest tests for a project, and while I'm not the most familiar with pytest, I feel like I have a fairly basic understanding.

In this particular case I am testing some code that does route optimization and I wish to implement a bunch of different tests to ensure that the code performs as it should.

To help with this I have defined a dataclass which I want to use with pytest to specify what to expect for a given scenario.

from dataclasses import dataclass, fields
from typing import Optional


@dataclass(slots=True)
class VRPResults:
    """A dataclass for storing the results of a VRP problem."""

    solver_time: Optional[float] = None
    total_travel_time: Optional[int] = None
    route_lengths: Optional[list[int]] = None
    route_indices: Optional[list[list[int]]] = None

    def __post_init__(self):
        # Derive the total travel time from the route lengths if it was not given explicitly.
        if self.route_lengths is not None and self.total_travel_time is None:
            self.total_travel_time = sum(self.route_lengths)

    def __eq__(self, other):
        if not isinstance(other, VRPResults):
            raise TypeError(f"Cannot compare {type(self)} with {type(other)}")
        for field in fields(VRPResults):
            val = getattr(self, field.name)
            val_other = getattr(other, field.name)
            # A field that is None on either side means "don't care" and is skipped.
            if val is None or val_other is None:
                continue
            if val_other != val:
                return False
        return True

The idea with this dataclass is that I can specify that for scenario 1 the total travel time should be 10 minutes, while for scenario 2 the route_indices should be [0, 1, 2] and the route_lengths should be [5, 2, 4]. So basically it gives me an easy framework for comparing each scenario's expected values with what my model produces. My code for using the above dataclass would look something like this:

    vrp_data, vrp_results_expected = create_data_model()
    vrp_results_predicted = solve_vrp(vrp_data)
    assert vrp_results_expected == vrp_results_predicted, "Solution does not match expected results"

where vrp_results_predicted and vrp_results_expected both are instances of the above dataclass.
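
For example, the expected results for the scenarios above could look something like this (hypothetical values, with None meaning "not checked"):

    # Hypothetical expected results for scenarios 1 and 2; fields left as None
    # are effectively "don't care" thanks to the custom __eq__ above.
    scenario_1_expected = VRPResults(total_travel_time=10)
    scenario_2_expected = VRPResults(route_lengths=[5, 2, 4], route_indices=[[0, 1, 2]])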

The main problem is that this code only checks whether these parameters are equal. Instead, I would like some way to specify exactly how each parameter should be evaluated.

For instance, in scenario 3 I do not know what the actual best travel time is, but I would be happy with anything below 20 minutes. To accommodate this, I'm thinking of adding additional parameters that specify the comparison operator to use when evaluating each field, something along the lines of the sketch below, but I'm not sure exactly how best to express these operators in Python. Is there a better way of doing something like this? Maybe pytest has some clever tools available for this?
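
For example, something along these lines using the operator module (just a sketch of the idea, not code I actually have):

    import operator

    # Sketch of the idea: a per-field comparison operator, so scenario 3 would pass
    # as long as predicted.total_travel_time < 20; fields not listed would default to equality.
    expected = VRPResults(total_travel_time=20)
    comparison_ops = {"total_travel_time": operator.lt}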


Solution

  • I suggest going with this approach: instead of overriding the __eq__ method in your dataclass, it’s better to create a separate comparison function where you can pass in custom check rules for each field. This is especially helpful when different test scenarios require different validation logic — like in one case, you might want to compare route lists exactly, but in another, you just want to make sure total travel time is under a certain threshold.

    Here’s how you could set it up:

    from dataclasses import dataclass, fields
    from typing import Any, Callable, Optional
    
    @dataclass(slots=True)
    class VRPResults:
        solver_time: Optional[float] = None
        total_travel_time: Optional[int] = None
        route_lengths: Optional[list[int]] = None
        route_indices: Optional[list[list[int]]] = None
    
        def __post_init__(self):
            if self.route_lengths is not None and self.total_travel_time is None:
                self.total_travel_time = sum(self.route_lengths)
    

    Then, here’s the custom comparison function:

    def compare_vrp_results(predicted: VRPResults, expected: VRPResults,
                            custom_checks: Optional[dict[str, Callable[[Any], bool]]] = None):
        custom_checks = custom_checks or {}
        for field in fields(VRPResults):
            value_predicted = getattr(predicted, field.name)
            value_expected = getattr(expected, field.name)
    
            if field.name in custom_checks:
                if not custom_checks[field.name](value_predicted):
                    return False, f"{field.name} failed check: Got {value_predicted}"
            elif value_expected is not None and value_predicted != value_expected:
                return False, f"{field.name} does not match: Expected {value_expected}, got {value_predicted}"
        return True, ""
    

    And a test example could look like this:

    def test_scenario_3():
        vrp_data, vrp_results_expected = create_data_model()
        vrp_results_predicted = solve_vrp(vrp_data)
    
        custom_checks = {
            "total_travel_time": lambda x: x < 20  # Something like anything under 20 is fine. You can have any checks.
        }
    
        result, message = compare_vrp_results(vrp_results_predicted, vrp_results_expected, custom_checks)
        assert result, message
    

    This approach should give you a clean and flexible way to define how results should be validated per scenario, without bloating your dataclass logic.
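
    As for pytest itself: pytest.mark.parametrize pairs nicely with this comparison function if you want to drive several scenarios through one test. A minimal sketch, assuming each scenario supplies its own data-model factory and custom checks (the factory entries below are placeholders for however you actually build your scenarios):

    import pytest

    # Hypothetical scenario table: each entry pairs a data-model factory with the
    # custom checks that scenario needs; fields without a check fall back to equality.
    SCENARIOS = [
        (create_data_model, {}),  # e.g. scenario 1: plain comparison against expected results
        (create_data_model, {"total_travel_time": lambda x: x < 20}),  # e.g. scenario 3
    ]

    @pytest.mark.parametrize("make_model, custom_checks", SCENARIOS)
    def test_vrp_scenarios(make_model, custom_checks):
        vrp_data, vrp_results_expected = make_model()
        vrp_results_predicted = solve_vrp(vrp_data)
        ok, message = compare_vrp_results(vrp_results_predicted, vrp_results_expected, custom_checks)
        assert ok, message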