
Reusing choice definition in swagger generation with Django Rest Framework


I have a Django (Django Rest Framework) web service that uses drf-yasg to generate a swagger.json file. In the model, I have a couple of enums/choice fields that are used in more than one place. By default, drf-yasg defines the field inline for each occurrence:

from rest_framework import serializers

Choices = serializers.ChoiceField(choices=['a', 'b', 'c'])

class SomeObject(serializers.Serializer):
    field_1 = Choices
    field_2 = Choices

This produces the following definitions in the Swagger file:

{
  "definitions": {
    "SomeObject": {
      "required": [ "field_1", "field_2" ],
      "type": "object",
      "properties": {
        "field_1": {
          "title": "Field 1",
          "type": "string",
          "enum": [ "a", "b", "c" ]
        },
        "field_2": {
          "title": "Field 2",
          "type": "string",
          "enum": [ "a", "b", "c" ]
        }
      }
    }
  }
}

This is a slight problem, since it makes client code generation tools generate each enum as its own type instead of reusing the definition. So instead I would like to create a Swagger file like so:

{
  "definitions": {
    "Choices": {
      "title": "Field 1",
      "type": "string",
      "enum": [ "a", "b", "c" ]
    },
    "SomeObject": {
      "required": [ "field_1", "field_2" ],
      "type": "object",
      "properties": {
        "field_1": {
          "$ref": "#/definitions/Choices"
        },
        "field_2": {
          "$ref": "#/definitions/Choices"
        }
      }
    }
  }
}
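
To make the requested change concrete, here is a small library-independent sketch that post-processes the first document into the second: the inline enum of each listed field is hoisted into one shared definition and replaced by a $ref. This is not something drf-yasg does; hoist_enum is a hypothetical helper, and (as in the desired output) it keeps the first field's title on the shared definition.

```python
import copy
import json
from typing import Any, Dict, Iterable


def hoist_enum(spec: Dict[str, Any], object_name: str,
               field_names: Iterable[str], ref_name: str) -> Dict[str, Any]:
    """Return a copy of `spec` in which the inline enum schemas of `field_names`
    are replaced by a `$ref` to one shared definition called `ref_name`.

    Hypothetical helper for illustration; not part of drf-yasg."""
    result = copy.deepcopy(spec)  # leave the caller's spec untouched
    definitions = result["definitions"]
    properties = definitions[object_name]["properties"]
    for field_name in field_names:
        inline = properties[field_name]
        # The first occurrence becomes the shared definition (title included).
        shared = definitions.setdefault(ref_name, inline)
        # Only merge fields whose type and enum values actually match.
        if (shared["type"], shared["enum"]) != (inline["type"], inline["enum"]):
            raise ValueError(f"{field_name} does not match definition {ref_name!r}")
        properties[field_name] = {"$ref": f"#/definitions/{ref_name}"}
    return result


choice_schema = {"type": "string", "enum": ["a", "b", "c"]}
spec = {
    "definitions": {
        "SomeObject": {
            "required": ["field_1", "field_2"],
            "type": "object",
            "properties": {
                "field_1": {"title": "Field 1", **choice_schema},
                "field_2": {"title": "Field 2", **choice_schema},
            },
        }
    }
}

print(json.dumps(hoist_enum(spec, "SomeObject", ["field_1", "field_2"], "Choices"), indent=2))
```

Post-processing like this only works when you know which fields share an enum; doing it at generation time, where the field object itself carries that information, avoids having to maintain such a list by hand.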

Is it possible to enable this behaviour within the Django Rest Framework?


Solution

  • In case anybody sees this and needs some pointers, I ended up implementing this as follows:

    First, create a subclass of ChoiceField that signals that the enum should be emitted as a reference and performs a couple of sanity checks. The check for str ensures that the values can be serialized to JSON:

    from enum import Enum
    
    from rest_framework import serializers
    
    
    class ReferenceEnumField(serializers.ChoiceField):
        """ChoiceField whose enum should be emitted as a shared $ref definition."""
    
        def __init__(self, enum_type, **kwargs):
            if not issubclass(enum_type, Enum):
                raise TypeError("enum_type should be an Enum")
            # Members of a str-mixin enum (class X(str, Enum)) are JSON-serializable.
            if not issubclass(enum_type, str):
                raise TypeError("enum_type should inherit from str in order to be json-serializable.")
            # Used as the name of the shared definition in the generated swagger file.
            self.enum_name = enum_type.__name__
            super().__init__(choices=[member.name for member in enum_type], **kwargs)
    

    Then there is the inspector, which can be passed to the swagger_auto_schema decorator via its field_inspectors argument:

    from drf_yasg import openapi
    from drf_yasg.errors import SwaggerGenerationError
    from drf_yasg.inspectors.base import NotHandled
    from drf_yasg.inspectors.field import ReferencingSerializerInspector
    
    from .ReferenceEnumfield import ReferenceEnumField
    
    
    class EnumAsReferenceInspector(ReferencingSerializerInspector):
        # Class-level flag: the probe_*_inspectors() helpers instantiate fresh
        # inspector objects, so the recursion guard has to live on the class.
        accepting_objects = True
    
        @classmethod
        def set_accepting_objects(cls, value):
            cls.accepting_objects = value
    
        def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
            SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
    
            if EnumAsReferenceInspector.accepting_objects and isinstance(field, ReferenceEnumField):
                try:
                    # Avoid infinite recursion: while the flag is off, this inspector returns
                    # NotHandled, so the other inspectors build the plain enum schema.
                    EnumAsReferenceInspector.set_accepting_objects(False)
                    if swagger_object_type != openapi.Schema:
                        raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
    
                    ref_name = field.enum_name
    
                    def make_schema_definition(enum=field):
                        return self.probe_field_inspectors(enum, ChildSwaggerType, use_references)
    
                    if not ref_name or not use_references:
                        return make_schema_definition()
    
                    # Build the enum schema only the first time this name is seen; later fields reuse it.
                    definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
                    actual_schema = definitions.setdefault(ref_name, make_schema_definition)
                    actual_schema._remove_read_only()
    
                    return openapi.SchemaRef(definitions, ref_name)
                finally:
                    EnumAsReferenceInspector.set_accepting_objects(True)
    
            return NotHandled
    

    It's cobbled together from some other code in the library, so I'm not sure whether some lines could be omitted or done differently, but it does the trick.
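
The str check in ReferenceEnumField can be illustrated with the standard library alone. In this sketch the Choices and PlainChoices enums are hypothetical, chosen to mirror the question's ['a', 'b', 'c']: a member of a str-mixin enum is encoded by json.dumps as its value, while a member of a plain Enum raises TypeError.

```python
import json
from enum import Enum


class Choices(str, Enum):
    """Hypothetical enum mirroring the question; the str mixin makes members JSON-serializable."""
    a = "a"
    b = "b"
    c = "c"


class PlainChoices(Enum):
    """Same members without the str mixin (ReferenceEnumField would reject this)."""
    a = "a"
    b = "b"
    c = "c"


def is_json_serializable(value) -> bool:
    """Return True if the stdlib json encoder accepts `value` as-is."""
    try:
        json.dumps(value)
    except TypeError:
        return False
    return True


# ReferenceEnumField builds its choices from the member *names*.
choices = [member.name for member in Choices]

print(choices)                                                  # ['a', 'b', 'c']
print(issubclass(Choices, str), issubclass(PlainChoices, str))  # True False
print(json.dumps(Choices.a))                                    # "a"
print(is_json_serializable(Choices.a))                          # True
print(is_json_serializable(PlainChoices.a))                     # False
```

One detail worth keeping in mind: the field's choices are member names, while json.dumps emits member values. If your objects hold enum members, the names listed in the schema and the values in the JSON payload therefore only agree when each member's name equals its value, as they do here; building choices from member.value may be preferable when they differ.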