How to Trigger a Specific Error for an Enum Deserialization Failure in Axum with Serde?


I’m working on an Axum-based application in Rust, where one of my handlers receives a JSON payload that is deserialized into a MyRequest struct. The struct contains a field my_enum of type MyEnum, which is an enum with variants Foo and Bar.

Here’s a simplified version of my code:

use axum::{extract::Json, response::IntoResponse};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum MyEnum {
    Foo,
    Bar,
}

#[derive(Deserialize)]
struct MyRequest {
    my_enum: MyEnum,
    // other fields...
}

async fn my_handler(
    payload: Result<Json<MyRequest>, axum::extract::rejection::JsonRejection>,
) -> impl IntoResponse {
    match payload {
        Ok(Json(request)) => {
            // Handle the request...
        }
        Err(e) => {
            // Handle the error...
            // Right now, I'm just checking the error message.
            if e.to_string().contains("my_enum") {
                // Specific handling for enum deserialization error
            } else {
                // Generic error handling
            }
        }
    }
}

The Problem

Currently, if the incoming JSON contains an invalid value for my_enum, Serde fails to deserialize it, and Axum returns a JsonRejection. To differentiate between errors caused by an invalid my_enum value and other potential errors, I’m inspecting the error message with a string check like e.to_string().contains("my_enum").

This approach feels brittle because it relies on the specific wording of the error message, which could change and isn't guaranteed to be consistent.

My Question

Is there a way to configure Serde to return a specific error (or otherwise identify this deserialization failure more reliably) when the value of the my_enum field in MyRequest cannot be deserialized? Ideally, I'd like to handle this case without resorting to fragile string matching.

What I've Tried

I’m currently using Axum's standard pattern for handling deserialization:

payload: Result<Json<MyRequest>, axum::extract::rejection::JsonRejection>

And then inspecting the error like this:

if e.to_string().contains("my_enum") {
    // Specific handling for enum deserialization error
}

However, as mentioned, this approach isn’t ideal due to its dependency on error message content.

What I'm Looking For

I’m seeking a more robust solution that allows me to reliably detect when the deserialization of my_enum fails, preferably by leveraging Serde’s capabilities or customizing the deserialization process.


Solution

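  • One way to do this is a two-step deserialization. First deserialize into an intermediate struct in which the fields you want to treat specially are typed as serde_json::Value, then deserialize each of those values into its real type. Because the two steps are separate calls, each failure can be mapped to its own error variant. For example: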

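    use serde::{Deserialize, Serialize};

    #[derive(Serialize, Deserialize)]
    #[serde(rename_all = "lowercase")]
    enum Data {
        Foo,
        Bar,
    }
    
    // The final, fully typed request.
    struct Request {
        data: Data,
        field: i32,
    }
    
    // First stage: `data` is kept as raw JSON, so an invalid
    // variant does not fail the whole body.
    #[derive(Deserialize)]
    struct RequestIntermediate {
        data: serde_json::Value,
        field: i32,
    }
    
    // One variant per stage, so the caller can tell the failures apart.
    #[derive(Debug)]
    enum APIError {
        RequestError(serde_json::Error),
        DataError(serde_json::Error),
    }
    
    fn deserialize_request(input: &[u8]) -> Result<Request, APIError> {
        // Stage 1: malformed JSON, missing fields, wrong type for `field`, ...
        let RequestIntermediate { data, field }: RequestIntermediate =
            serde_json::from_slice(input).map_err(APIError::RequestError)?;
        // Stage 2: only an invalid `data` value can fail here.
        let data: Data = serde_json::from_value(data).map_err(APIError::DataError)?;
    
        Ok(Request { data, field })
    }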
    

    If you want axum to run this automatically for a handler argument, you can implement FromRequest for a wrapper type. Here is a quick example (without proper error handling):

    use axum::{
        async_trait,
        body::Bytes,
        extract::{rejection::BytesRejection, FromRequest},
    };

    struct RequestExtractor(pub Result<Request, APIError>);
    
    // Written against axum 0.7. axum 0.8 removed `#[async_trait]` and
    // uses native `async fn` in the `FromRequest` trait instead.
    #[async_trait]
    impl<S: Send + Sync> FromRequest<S> for RequestExtractor {
        // You should probably also use some other rejection.
        type Rejection = BytesRejection;
    
        async fn from_request(req: axum::extract::Request, state: &S) -> Result<Self, Self::Rejection> {
            // Ignoring Content-Type. You might want to add this check.
            // See axum's implementation of FromRequest for axum::Json.
            let bytes = Bytes::from_request(req, state).await?;
            let request = deserialize_request(bytes.as_ref());
            Ok(RequestExtractor(request))
        }
    }
    
    async fn handler(RequestExtractor(payload): RequestExtractor) {
        match payload {
            Ok(request) => todo!(),
            Err(APIError::RequestError(_)) => todo!(),
            Err(APIError::DataError(_)) => todo!(),
        }
    }