
Pattern matching in python to catch enum-flag combination


When using Flag in Python, pattern matching only catches direct equality, not inclusion. This can be worked around with a guard using in, provided you handle the 0 flag first if you consider it a special case (the empty flag is contained in every flag):

from enum import Flag

class Test(Flag):
    ONE = 1
    TWO = 2
    THREE = ONE | TWO
    NONE = 0

def match_one(pattern):
    match pattern:
        case Test.THREE:
            print(pattern, 'Three')
        case Test.ONE:
            print(pattern, 'ONE')
        case v if v in Test.THREE:  # subset test: every bit of v is set in THREE
            print(pattern, 'in THREE')
        case _:
            print(pattern, 'Other')

Calling match_one on Test.ONE, Test.TWO, Test.THREE and Test.NONE prints:

Test.ONE ONE
Test.TWO in THREE
Test.THREE Three
Test.NONE in THREE

However, such guards become limiting when several flags are involved. A minimal example: match OBJ_1 with TYPE_1 (ATTR_1, ATTR_2, or TYPE_1 itself), or OBJ_2 with TYPE_2 (ATTR_3, ATTR_4, or TYPE_2 itself). The real problem is not limited to such a simple case, but this is a good minimal representation:

class Flag_1(Flag):
    OBJ_1 = 1
    OBJ_2 = 2
    NONE = 0

class Flag_2(Flag):
    ATTR_1 = 1
    ATTR_2 = 2
    ATTR_3 = 4
    ATTR_4 = 8
    TYPE_1 = ATTR_1 | ATTR_2
    TYPE_2 = ATTR_3 | ATTR_4
    NONE = 0

This can be done either as

match flag1, flag2:
    case (Flag_1.NONE, _) | (_, Flag_2.NONE):
        print('Useless')
    case (Flag_1.OBJ_1, c) if c in Flag_2.TYPE_1:
        print('Do 1')
    case (Flag_1.OBJ_2, c) if c in Flag_2.TYPE_2:
        print('Do 1')
    case _:
        print('Other stuff')

or

match flag1, flag2:
    case (Flag_1.NONE, _) | (_, Flag_2.NONE):
        print('Useless')
    case (Flag_1.OBJ_1, Flag_2.ATTR_1 | Flag_2.ATTR_2 | Flag_2.TYPE_1) \
        | (Flag_1.OBJ_2, Flag_2.ATTR_3 | Flag_2.ATTR_4 | Flag_2.TYPE_2):
        print('Do 1')
    case _:
        print('Other stuff')

However, solution 1 becomes unwieldy with more combinations, adding many lines that repeat the same action. Solution 2, on the other hand, breaks down if TYPE_1 contains, say, 10 flags, since every one of them must be spelled out in a very long pattern.

Something like the following is not possible with pattern matching, because a case can have only one guard, placed after the whole pattern (this is probably better suited to an if statement):

case (Flag_1.OBJ_1, c) if c in Flag_2.TYPE_1 \
    | (Flag_1.OBJ_2, c) if c in Flag_2.TYPE_2:  # invalid syntax

The following does not work either, as explained earlier, because a value pattern only matches by equality (so Flag_2.ATTR_1 does not match Flag_2.TYPE_1):

case (Flag_1.OBJ_1, Flag_2.TYPE_1) \
    | (Flag_1.OBJ_2, Flag_2.TYPE_2):

Is there a better way to match combined flags than solution 1 (one case per combination, even though they all have the same outcome) or an if statement?


Solution

The match statement is very powerful, but I don't think it is the right tool here. You can accomplish this in a much more understandable way with a few if statements. The nice part of using flags is that you can test whether a value shares any bits with a combined flag using &:

def check_two(flag1: Flag_1, flag2: Flag_2):
    # An empty flag (value 0) is falsy.
    if (not flag1) or (not flag2):
        return 'Useless'
    # flag & MASK is truthy when the two share at least one bit.
    if (flag1 & Flag_1.OBJ_1) and (flag2 & Flag_2.TYPE_1):
        return 'Obj-1, Type-1'
    if (flag1 & Flag_1.OBJ_2) and (flag2 & Flag_2.TYPE_2):
        return 'Obj-2, Type-2'
    return 'Other mix'
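Note that & tests for overlap, while the question's in tests for inclusion. The two differ for mixed values such as ATTR_1 | ATTR_3, which overlaps TYPE_1 without being contained in it. A minimal sketch illustrating this with the answer's check_two; the variable name mixed is hypothetical:

```python
from enum import Flag


class Flag_1(Flag):
    OBJ_1 = 1
    OBJ_2 = 2
    NONE = 0


class Flag_2(Flag):
    ATTR_1 = 1
    ATTR_2 = 2
    ATTR_3 = 4
    ATTR_4 = 8
    TYPE_1 = ATTR_1 | ATTR_2
    TYPE_2 = ATTR_3 | ATTR_4
    NONE = 0


def check_two(flag1: Flag_1, flag2: Flag_2):
    # An empty flag (value 0) is falsy.
    if (not flag1) or (not flag2):
        return 'Useless'
    # flag & MASK is truthy when the two share at least one bit.
    if (flag1 & Flag_1.OBJ_1) and (flag2 & Flag_2.TYPE_1):
        return 'Obj-1, Type-1'
    if (flag1 & Flag_1.OBJ_2) and (flag2 & Flag_2.TYPE_2):
        return 'Obj-2, Type-2'
    return 'Other mix'


mixed = Flag_2.ATTR_1 | Flag_2.ATTR_3        # value 5: one bit from each type

print(check_two(Flag_1.OBJ_1, Flag_2.ATTR_2))  # Obj-1, Type-1
print(check_two(Flag_1.OBJ_1, mixed))          # Obj-1, Type-1 (overlap is enough)
print(mixed in Flag_2.TYPE_1)                  # False: not a subset of TYPE_1
print(bool(mixed & Flag_2.TYPE_1))             # True: shares the ATTR_1 bit
```

If the subset semantics of the question are wanted instead, flag2 & Flag_2.TYPE_1 can be replaced by flag2 in Flag_2.TYPE_1; the zero flag is already rejected by the first if.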