
How to make case insensitive choices using Python's enum and FastAPI?


I have this application:

import enum
from typing import Annotated, Literal

import uvicorn
from fastapi import FastAPI, Query, Depends
from pydantic import BaseModel

app = FastAPI()


class MyEnum(enum.Enum):
    ab = "ab"
    cd = "cd"


class MyInput(BaseModel):
    q: Annotated[MyEnum, Query(...)]


@app.get("/")
def test(inp: MyInput = Depends()):
    return "Hello world"


def main():
    uvicorn.run("run:app", host="0.0.0.0", reload=True, port=8001)


if __name__ == "__main__":
    main()

Running curl http://127.0.0.1:8001/?q=ab or curl http://127.0.0.1:8001/?q=cd returns "Hello world".

But any of these

returns 422 Unprocessable Entity, which makes sense.

How can I make this validation case insensitive?


Solution

  • You could make the enum values case insensitive by overriding the Enum's _missing_ method. As per the documentation, this classmethod, which by default does nothing, can be used to look up values not found in cls. This lets you run your own search for the enum member by value, e.g., comparing values case-insensitively.

    Note that one could extend from the str class when declaring the enumeration class (e.g., class MyEnum(str, Enum)), which would indicate that all members in the enum must have values of the specified type (e.g., str). This would also allow comparing a string to an enum member (using the equality operator ==), without having to use the .value attribute on the enum member (e.g., if member.lower() == value). Otherwise, if the enumeration class was declared as class MyEnum(Enum) (without str subclass), one would need to use the .value attribute on the enum member (e.g., if member.value.lower() == value) to safely compare the enum member to a string.
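    A quick stdlib-only check of the comparison behaviour described above (PlainEnum and StrMixinEnum are illustrative names, not part of the original code):

```python
from enum import Enum


class PlainEnum(Enum):
    # No str mixin: members are Enum objects, not strings
    ab = 'ab'


class StrMixinEnum(str, Enum):
    # str mixin: each member is also a str equal to its value
    ab = 'ab'


# A plain Enum member never compares equal to a bare string...
print(PlainEnum.ab == 'ab')        # False
# ...so the comparison has to go through .value
print(PlainEnum.ab.value == 'ab')  # True
# A str-mixin member compares equal to the string directly
print(StrMixinEnum.ab == 'ab')     # True
# and str methods such as upper()/lower() work on the member itself
print(StrMixinEnum.ab.upper())     # AB
```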

    Also, note that calling the lower() method on the enum member (i.e., member.lower()) is only necessary if the values of your enum class include uppercase letters (or a mix of uppercase and lowercase letters), e.g., ab = 'aB', cd = 'Cd', etc. Hence, for the example below, where only lowercase letters are used, you could skip it and simply use if member == value to compare the enum member to the value, which saves calling the lower() method on every member in the class.
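    To see why member.lower() matters once the values contain uppercase letters, here is a sketch contrasting a _missing_ that skips it with one that uses it (Naive and Fixed are hypothetical class names; this runs without FastAPI):

```python
from enum import Enum


class Naive(str, Enum):
    # Mixed-case values, but _missing_ does NOT lowercase the members
    ab = 'aB'
    cd = 'Cd'

    @classmethod
    def _missing_(cls, value):
        value = value.lower()
        for member in cls:
            if member == value:  # 'aB' == 'ab' is False, so nothing matches
                return member
        return None


class Fixed(str, Enum):
    # Same values, but each member is lowercased before comparing
    ab = 'aB'
    cd = 'Cd'

    @classmethod
    def _missing_(cls, value):
        value = value.lower()
        for member in cls:
            if member.lower() == value:
                return member
        return None


print(Fixed('AB').name)  # ab

try:
    Naive('AB')
except ValueError:
    print("Naive: no match for 'AB'")
```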

    Example 1

    from enum import Enum
    
    class MyEnum(str, Enum):
        ab = 'ab'
        cd = 'cd'
        
        @classmethod
        def _missing_(cls, value):
            value = value.lower()
            for member in cls:
                if member.lower() == value:
                    return member
            return None
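    Calling the enum directly (no FastAPI involved) shows the hook at work: exact matches never reach _missing_, and values that still match nothing fall through to Enum's usual ValueError:

```python
from enum import Enum


class MyEnum(str, Enum):
    ab = 'ab'
    cd = 'cd'

    @classmethod
    def _missing_(cls, value):
        # Called only when the exact value lookup fails
        value = value.lower()
        for member in cls:
            if member.lower() == value:
                return member
        return None  # no match: Enum raises ValueError


print(MyEnum('ab').name)  # ab  (exact match, _missing_ not called)
print(MyEnum('AB').name)  # ab  (found via _missing_)
print(MyEnum('Cd').name)  # cd

try:
    MyEnum('ef')
except ValueError as exc:
    print(exc)  # 'ef' is not a valid MyEnum
```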
    

    Generic Version (with FastAPI example)

    from fastapi import FastAPI
    from enum import Enum
    
    
    app = FastAPI()
    
    
    class CaseInsensitiveEnum(str, Enum):
        @classmethod
        def _missing_(cls, value):
            value = value.lower()
            for member in cls:
                if member.lower() == value:
                    return member
            return None
            
    
    class MyEnum(CaseInsensitiveEnum):
        ab = 'aB'
        cd = 'Cd'
    
    
    @app.get("/")
    def main(q: MyEnum):
        return q
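    One caveat outside the FastAPI case: _missing_ receives whatever was passed to MyEnum(...), so a non-string lookup such as MyEnum(1) makes value.lower() raise AttributeError, which Enum re-raises instead of the usual ValueError. Query parameters always arrive as strings, so the endpoint above should not hit this, but if the enum is used elsewhere too, an isinstance guard (a small hardening step, not part of the original answer) keeps the error type consistent:

```python
from enum import Enum


class CaseInsensitiveEnum(str, Enum):
    @classmethod
    def _missing_(cls, value):
        # Non-strings cannot be lowercased; returning None lets Enum
        # raise its usual "... is not a valid ..." ValueError
        if not isinstance(value, str):
            return None
        value = value.lower()
        for member in cls:
            if member.lower() == value:
                return member
        return None


class MyEnum(CaseInsensitiveEnum):
    ab = 'aB'
    cd = 'Cd'


print(MyEnum('AB').name)  # ab

try:
    MyEnum(1)
except ValueError as exc:
    print(type(exc).__name__)  # ValueError
```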
    

    In case you needed the Enum query parameter to be defined using Pydantic's BaseModel, you could then use the below (see this answer and this answer for more details):

    from fastapi import Query, Depends
    from pydantic import BaseModel
    
    ...
    
    class MyInput(BaseModel):
        q: MyEnum = Query(...)
    
    
    @app.get("/")
    def main(inp: MyInput = Depends()):
        return inp.q
    

    In both cases, the endpoint could be called as follows:

    http://127.0.0.1:8000/?q=ab
    http://127.0.0.1:8000/?q=aB
    http://127.0.0.1:8000/?q=cD
    http://127.0.0.1:8000/?q=CD
    ...
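    To check which member each of those URLs maps to without starting a server, you could parse the query string with urllib.parse and hand the raw value to the enum, which is roughly the step FastAPI performs before validation (resolve() is a hypothetical helper; FastAPI's real request parsing is more involved):

```python
from enum import Enum
from urllib.parse import parse_qs, urlsplit


class CaseInsensitiveEnum(str, Enum):
    @classmethod
    def _missing_(cls, value):
        value = value.lower()
        for member in cls:
            if member.lower() == value:
                return member
        return None


class MyEnum(CaseInsensitiveEnum):
    ab = 'aB'
    cd = 'Cd'


urls = [
    "http://127.0.0.1:8000/?q=ab",
    "http://127.0.0.1:8000/?q=aB",
    "http://127.0.0.1:8000/?q=cD",
    "http://127.0.0.1:8000/?q=CD",
]


def resolve(url):
    # Hypothetical helper: extract the raw ?q= string, then let the
    # enum (and its _missing_ hook) pick the matching member
    raw = parse_qs(urlsplit(url).query)["q"][0]
    return MyEnum(raw)


for url in urls:
    member = resolve(url)
    print(f"{url} -> MyEnum.{member.name} (value {member.value!r})")
```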
    

    Example 2

    In Python 3.11+, one could instead use the newly introduced StrEnum, which allows using the auto() feature, resulting in the lower-cased version of the member's name as the value.

    from enum import StrEnum, auto
    
    class MyEnum(StrEnum):    
        AB = auto()
        CD = auto()
        
        @classmethod
        def _missing_(cls, value):
            value = value.lower()
            for member in cls:
                if member == value:
                    return member
            return None