pythonauthenticationsecurityfastapi

How to add user authentication to specific endpoints in FastAPI?


Could you please tell me how to create endpoints that only work after user authentication? I need only the endpoints located below the line files = [], in the example below, to work this way. At the moment, everything works without authentication, even though it shouldn't.

import uvicorn
from fastapi import FastAPI, HTTPException, Response, Depends, UploadFile
from fastapi.responses import StreamingResponse, FileResponse
from pydantic import BaseModel, Field
from typing import Annotated, List
from authx import AuthX, AuthXConfig
from sqlalchemy import select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from passlib.context import CryptContext


# Хэширование паролей
password_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


# Конфигурация FastAPI и создание сессии БД
app = FastAPI() # Объявляем приложение FastAPI
engine = create_async_engine("sqlite+aiosqlite:///users.db", echo=True) # Объявляем асинхронный движок нашей БД
new_session = async_sessionmaker(engine, expire_on_commit=False) # Запускаем асинхронную сессию с сохранением объектов


async def get_session():
    "Обращение к сессии БД"
    async with new_session() as session:
        yield session


SessionDep = Annotated[AsyncSession, Depends(get_session)] # Создаём зависимость сессии БД от FastAPI


class Base(DeclarativeBase):
    pass


class UserModel(Base): # Создаём таблицу в БД, в которой будут находиться данные пользователей
    __tablename__ = "users"

    # Привязываем данные к стоблцам таблицы
    uid: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, nullable=False)
    username: Mapped[str] = mapped_column(unique=True, index=True)
    password: Mapped[str]


class UserSchema(BaseModel):
    username: str = Field(max_length=10)
    password: str = Field(min_length=8, max_length=16)


# Конфигурация AuthX
config = AuthXConfig() # Конфигурация AuthX
config.JWT_SECRET_KEY = "kirieshki" # Секретный ключ для создания JWT-токенов
config.JWT_ACCESS_COOKIE_NAME = "access_token" # Название куки-токена
config.JWT_TOKEN_LOCATION = ['cookies'] # Расположение куки-токенов
security = AuthX(config=config) # Обозначаем конфиг для модуля AuthX


@app.post("/register",
          tags=["Авторизация"],
          summary="Зарегистрировать аккаунт")
async def register(creds: UserSchema,
                    response: Response,
                    session: SessionDep):
    "Регистрация пользователя"
    # Проверка на существование пользователя с таким же username
    check_user = select(UserModel).where(UserModel.username == creds.username)
    existing_user = await session.execute(check_user)
    if existing_user.scalar_one_or_none():
        raise HTTPException(status_code=400, detail="User already exist")
    
    # Хэширование паролей перед сохранением
    hashed_password = password_context.hash(creds.password)

    # Регистрация нового пользователя, если уже существующй в БД username не найден
    new_user = UserModel(username=creds.username, password=hashed_password) # Создание нового пользователя
    session.add(new_user) # Добавление нового пользователя в БД
    await session.commit() # Сохранение данных в БД
    await session.refresh(new_user) # Обновление информации в БД

    # Создание JWT-токена с payload
    access_token = security.create_access_token(uid=str(new_user.uid))
    security.set_access_cookies(access_token, response)
    
    return {"message": "User has been registered"}



@app.post("/login",
          tags=["Авторизация"],
          summary="Войти в аккаунт")
async def login(creds: UserSchema,
                response: Response,
                session: SessionDep):
    "Вход пользователя"
    # Проверка на существование в БД пользователя с таким же именем и паролем
    check_user = select(UserModel).where(UserModel.username == creds.username)
    existing_user = await session.execute(check_user)
    user = existing_user.scalar_one_or_none()

    # Проверяем, существует ли пользователь и совпадает ли хэш пароля с паролем
    if not user or not password_context.verify(creds.password, user.password):
        raise HTTPException(status_code=401, detail="Invalid username or password")

    # Создание куки-токена
    access_token = security.create_access_token(uid=str(user.uid))
    security.set_access_cookies(access_token, response)

    return {"Success": True}


@app.post("/reset",
          tags=["Работа с БД"],
          summary="Сброс БД")
async def reset_database():
    "Сброс БД"
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
        await conn.run_sync(Base.metadata.create_all)
    return {"База данных успешно сброшена": True}


files = [] # Список файлов, загружаемых на сервер


@app.get("/get_filenames",
        tags=["Получение файлов"],
        summary="Получить названия файлов")
async def get_filenames():
    if len(files) == 0: # Проверка на наличие загруженных файлов
        return {"File list is empty"}
    else:
        return files


def iterfile(filename: str): # Обработчик файлов. Нарезает файлы на чанки для постепенной загрузки
    "Обработка файлов / нарезка на чанки"
    with open(filename, "rb") as file:
        while chunk := file.read(1024 * 1024):
            yield chunk


@app.get("/get_files",
         tags=["Получение файлов"],
         summary="Получить файл")
async def get_file(filename: str):
    return FileResponse(filename)


@app.get("/streaming/{filename}",
         tags=["Получение файлов"],
         summary="Получение файла в стриминге")
async def get_streaming_file(filename: str):
    return StreamingResponse(iterfile(filename)) 


@app.post("/upload",
          tags=["Добавление файлов"],
          summary="Загрузка файла")
async def upload_file(uploaded_file: UploadFile):
    "Загрузка файла"
    file = uploaded_file.file
    filename = uploaded_file.filename
    with open(f"1_{filename}", "wb") as f:
        f.write(file.read())
    files.append(f"1_{filename}")
    return {"File was uploaded": True}


@app.post("/upload_multiple",
          tags=["Добавление файлов"],
          summary="Загрузка нескольких файлов")
async def upload_files(uploaded_files: list[UploadFile]):
    for uploaded_file in upload_files:
        file = uploaded_file.file
        filename = uploaded_file.filename
        with open(f"1_{filename}", "wb") as f:
            f.write(file.read())
        files.append(f"1_{filename}")
    return {"Multiple files were uploaded": True}


if __name__ == "__main__":
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)

I tried to set up authentication using the AuthX module, but it didn't work.


Solution

  • You could use Dependency Injection for each individual endpoint you wish having it protected:

    from fastapi import FastAPI, APIRouter, Depends
    from fastapi.security import HTTPBasic, HTTPBasicCredentials
    
    app = FastAPI()
    security = HTTPBasic()
    
    
    @app .get("/")
    async def main():
        return "OK"
    
    
    @app.get("/protected")
    async def protected(credentials: HTTPBasicCredentials = Depends(security)):
        return {"username": credentials.username, "password": credentials.password}
    

    or, create separate APIRouter instances (see this as well), and then assign the required security dependencies to the one that requires authentication (i.e., auth_router, in the example below). You could then register endpoints on these router instances, in the same way as you would do with the app object.

    The test acoount credentials used in the example below are: admin and password (for the username and password properties, respectively).

    Working Example

    from fastapi import FastAPI, APIRouter, Depends, Request, HTTPException, status
    from fastapi.security import HTTPBasic, HTTPBasicCredentials
    import hashlib
    from secrets import compare_digest
    from pydantic import BaseModel
    
    app = FastAPI()
    security = HTTPBasic()
    
    
    fake_users_db = {
        "admin": {
            "username": "admin",
            "full_name": "Test Admin",
            "email": "admin@example.com",
            "hashed_password": "c255f02d5cb0812495fcf301d3239f80693f349d6b423a4f1868997c6b211eda",  # Plain Password: password
            "salt": "fc96c333c19cfbf9f99b916c0dc82db1e2f8d88fa72e5f534147e25a19341fa3",
            "disabled": False,
            "priviliged": True,
        }
    }
    
    
    # Class to represent user data
    class User(BaseModel):
        username: str
        full_name: str
        email: str
        disabled: bool
        priviliged: bool
    
    
    # Class to represent user data stored in the database
    class UserInDB(User):
        hashed_password: str
        salt: str
    
    
    # Used to hash a new plain password with a randomly generated salt. Returns the salt and hashed password
    def get_password_hash_and_salt(plain_password: str):
        salt = secrets.token_hex(32)
        hashed_password = hashlib.pbkdf2_hmac('sha256', plain_password.encode('utf-8'), salt.encode('utf-8'), 100000)
        return hashed_password.hex(), salt
        
    
    # Used to verify a given plain password with a previously stored salt and hashed password
    def verify_password(plain_password: str, user: UserInDB):
        hashed_password = hashlib.pbkdf2_hmac('sha256', plain_password.encode('utf-8'), user.salt.encode('utf-8'), 100000)
        current_password_bytes = hashed_password.hex().encode('utf-8')
        correct_password_bytes = user.hashed_password.encode('utf-8')
        is_correct_password = compare_digest(current_password_bytes, correct_password_bytes)
        return is_correct_password
        
    
    def get_user(username: str):
        user = fake_users_db.get(username)
        if user:
            return UserInDB(**user)  # or UserInDB.model_construct(**user)
    
    
    def verify_credentials(request: Request, credentials: HTTPBasicCredentials = Depends(security)):
        user = get_user(credentials.username)
        if user and verify_password(credentials.password, user):
            request.state.user = User(**user.model_dump())
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Basic"},
            )
        
    
    unauth_router = APIRouter()
    auth_router = APIRouter(dependencies=[Depends(verify_credentials)])
    
       
    @unauth_router.get("/")
    async def main():
        return "OK"
    
    
    @auth_router.get("/protected")
    async def protected(request: Request):
        return request.state.user
    
    
    app.include_router(unauth_router)
    app.include_router(auth_router)
    

    The example above demonstrates the authentication process as well, and makes use of request.state—as demonstrated here, as well as here and here—in order to get the returned object(s)/value(s) from the dependency function inside the relevant endpoint. In this case, this is the user object, once the credentials are verified. Using return User(**user.model_dump()) inside the verify_credentials dependency function would not actually return the object to the endpoint, as in the example above the verify_credentials dependency is assigned to the entire router (i.e., APIRouter(dependencies=[Depends(verify_credentials)])). Thus, one would need to use request.state to store the relevant returned object to the state that should be accessible from the endpoint. As explained in this answer (please have a look at it for more details), the state received on the requests is a shallow copy of the state received on the lifespan handler. Hence, storing new data to request.state (compared to updating existing data structures that might have been added to state in the lifespan handler; again, see the linked answer above for more details) would only be accessible from the current request instance.

    If one would like to get the returned data directly from the dependency function instead, they should then drop dependencies=[Depends(verify_credentials)] from APIRouter() instatiation, and define the dependency directly in the endpoint, as shown below. In this case, one might not use separate APIRouter instances, and choose to keep using the FastAPI app instance to register endpoints instead. The downside of this approach, however, is that one should need to define the dependency for every individual endpoint that needs to be protected. Example (rest of the code is the same as given earlier):

    # ...
    
    app = FastAPI()
    
    # ...
    
    def verify_credentials(credentials: HTTPBasicCredentials = Depends(security)):
        user = get_user(credentials.username)
        if user and verify_password(credentials.password, user):
            return User(**user.model_dump())  
        else:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Incorrect username or password",
                headers={"WWW-Authenticate": "Basic"},
            )
    
    
    @app.get("/protected")
    async def protected(user: User = Depends(verify_credentials)):
        return user
    

    You should also check OAuth2 scheme (using, for instance, OAuth2PasswordBearer instead of HTTPBasic that was used for demo purposes in the example above) and OAuth2 scopes (start by looking at this, then this and its following chapters, as well as this. An example can be seen below, but note that this is not a complete working example, and one should thus take a look at the detailed tutorial in the links provided above on how to implement the token generation and authentication process.

    from fastapi import FastAPI, APIRouter, Depends
    from fastapi.security import OAuth2PasswordBearer
    
    app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
    unauth_router = APIRouter()
    auth_router = APIRouter(dependencies=[Depends(oauth2_scheme)])
    
    
    @unauth_router .get("/")
    async def main():
        return "OK"
    
    
    @auth_router.get("/protected")
    async def protected(token: str = Depends(oauth2_scheme)):
        return {"token": token}
        
    
    app.include_router(unauth_router)
    app.include_router(auth_router)
    

    Further related posts that might prove helpful can be found here, as well as here and here.