I am using Python 3.11 and need to detect whether an optional class attribute's type is an Enum (i.e. a subclass of Enum). With typing.get_type_hints() I can get the type hints as a dict, but how do I check whether a field's type is an optional Enum (subclass)? Even better would be a way to get the type of any optional field, regardless of whether it is Optional[str], Optional[int], Optional[Class_X], etc.
from typing import Optional, get_type_hints
from enum import IntEnum, Enum

class TestEnum(IntEnum):
    foo = 1
    bar = 2

class Foo():
    opt_enum : TestEnum | None = None

types = get_type_hints(Foo)['opt_enum']
(ipython)
In [4]: Optional[TestEnum] == types
Out[4]: True
(yes, these are desperate attempts)
In [6]: Optional[IntEnum] == types
Out[6]: False
and
In [11]: issubclass(Enum, types)
Out[11]: False
and
In [12]: issubclass(types, Enum)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [12], line 1
----> 1 issubclass(types, Enum)
TypeError: issubclass() arg 1 must be a class
and
In [13]: issubclass(types, Optional[Enum])
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [13], line 1
----> 1 issubclass(types, Optional[Enum])
File /usr/lib/python3.10/typing.py:1264, in _UnionGenericAlias.__subclasscheck__(self, cls)
1262 def __subclasscheck__(self, cls):
1263 for arg in self.__args__:
-> 1264 if issubclass(cls, arg):
1265 return True
TypeError: issubclass() arg 1 must be a class
and
In [7]: IntEnum in types
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In [7], line 1
----> 1 IntEnum in types
TypeError: argument of type 'types.UnionType' is not iterable
I have several cases where I am importing data from CSV files and creating objects of a class from each row. csv.DictReader() returns a dict[str, str] and I need to fix the types of the fields before attempting to create the object. However, some of the object fields are Optional[int], Optional[bool], Optional[EnumX] or Optional[ClassX]. I have several such classes that inherit (via multiple inheritance) from my CSVImportable() class/interface. I want to implement the logic once in the CSVImportable() class instead of writing roughly the same code in a field-aware way in every subclass. This CSVImportable._field_type_updater() should:
Optional[ClassX] fields
Naturally, I am thankful for better designs too.
When you are dealing with a parameterized type (generic or special like typing.Optional), you can inspect it via get_args/get_origin.
Doing that, you'll see that T | S is implemented slightly differently than typing.Union[T, S]. The origin of the former is types.UnionType, while that of the latter is typing.Union. Unfortunately, this means that to cover both variants, we need two distinct checks.
from types import UnionType
from typing import Union, get_origin

def is_union(t: object) -> bool:
    origin = get_origin(t)
    return origin is Union or origin is UnionType
typing.Optional just uses typing.Union under the hood, so the origin is the same. Here is a working demo:
from enum import IntEnum
from types import UnionType
from typing import Optional, get_type_hints, get_args, get_origin, Union

class TestEnum(IntEnum):
    foo = 1
    bar = 2

class Foo:
    opt_enum1: TestEnum | None = None
    opt_enum2: Optional[TestEnum] = None
    opt_enum3: TestEnum
    opt4: str

def is_union(t: object) -> bool:
    # `X | Y` has origin types.UnionType; Union[X, Y] and Optional[X] have origin typing.Union
    origin = get_origin(t)
    return origin is Union or origin is UnionType

if __name__ == "__main__":
    for name, type_ in get_type_hints(Foo).items():
        # Match the bare type as well as any union that lists TestEnum as a member
        if type_ is TestEnum or (is_union(type_) and TestEnum in get_args(type_)):
            print(name, "accepts TestEnum")
Output:
opt_enum1 accepts TestEnum
opt_enum2 accepts TestEnum
opt_enum3 accepts TestEnum