python, django, django-rest-framework, django-filter

DRF: filtering child objects linked to a parent by foreign key from the parent's API route


I'm working on an e-commerce shop (purely as a learning project). I have a parent model, Category, and a child model, Product, whose category field references a Category object through a many-to-one relation (ForeignKey).

The idea is to filter from the parent side, i.e. from a Category route, and retrieve all Product objects that match the given parameter(s).

For example, when I send a request to "localhost/api/categories/" I get all categories and subcategories. (For context: subcategory routing also works. If I send a request to localhost/api/categories/smartphones/, the response contains the subcategories that are children of smartphones, and so on.)

Now I want to implement filtering so that a request to "localhost/api/categories/smartphones/?brand=apple" returns all Product objects with the brand field equal to "apple". This pattern has to work for any category, so that I don't have to hardcode apple as a Category object under every category that might contain Apple devices.
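As a quick sanity check on the URL shape, here is a small sketch using only Python's standard library (the helper name query_params is made up for illustration; the host and path are the ones from the example above). It suggests that the parameter is only seen as brand when it is written after the ?, as in ?brand=apple:

```python
from urllib.parse import urlsplit, parse_qs


def query_params(url: str) -> dict:
    """Return the parsed query string of a URL as {name: [values]}."""
    return parse_qs(urlsplit(url).query)


# Parameter name after the "?": the query string is "brand=apple".
good = query_params("http://localhost/api/categories/smartphones/?brand=apple")

# "brand" before the "?": it becomes part of the path, not a parameter name.
bad = query_params("http://localhost/api/categories/smartphones/brand?=apple")

print(good)            # {'brand': ['apple']}
print("brand" in bad)  # False
```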

My current code (not every line is included, only the most important ones):

products/models.py

class Product(models.Model):
    """Model representing a product."""
    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=255)
    category = models.ForeignKey(Category, on_delete=models.CASCADE, related_name='products')
    brand = models.ForeignKey('Brand', on_delete=models.CASCADE, related_name='products')
    slug = models.SlugField(unique=True, blank=True, default=id)

class Brand(models.Model):
    """A model representing the brand of the product."""
    id = models.AutoField(primary_key=True, editable=False)
    name = models.CharField(max_length=255)
    slug = models.SlugField(unique=True, editable=False)

products/serializers.py

class ProductSerializer(serializers.ModelSerializer):
    class Meta:
        model = Product
        fields = '__all__'

products/views.py

class ProductViewSet(viewsets.ModelViewSet):
    queryset = Product.objects.all()
    serializer_class = ProductSerializer
    lookup_field = 'slug'

products/urls.py

router = DefaultRouter()
router.register("", ProductViewSet)

urlpatterns = [
    path("", include(router.urls)),
]

categories/models.py

class Category(MPTTModel):
    """ Category model that inherits from MPTTModel, a class from a third-party
    application that makes it more convenient to work with a model tree structure. """

    objects = CategoryManager()  # Use of a custom manager defined above

    id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
    name = models.CharField(max_length=120, db_index=True, unique=True)
    slug = models.SlugField(unique=True)
    parent = TreeForeignKey(   # MPTT model field: the parent of a subcategory (if any) in the tree structure.
        "self",
        blank=True,
        null=True,
        related_name="child",
        on_delete=models.CASCADE
    )

categories/serializers.py

class RecursiveField(serializers.Serializer):
    def to_representation(self, value):
        serializer = self.parent.parent.__class__(value, context=self.context)
        return serializer.data


class CategorySerializer(serializers.ModelSerializer):
    child = RecursiveField(many=True, read_only=True)

categories/views.py

class CategoriesAPIViewSet(viewsets.ModelViewSet):
    """ Standard ModelViewSet which implements CRUD with minor changes. """

    queryset = Category.objects.all()
    serializer_class = CategorySerializer
    lookup_field = 'slug'

categories/urls.py

router = SimpleRouter()

router.register("", CategoriesAPIViewSet)

urlpatterns += router.urls

Solution

  • As described in the DRF documentation for ModelViewSet (https://www.django-rest-framework.org/api-guide/viewsets/#modelviewset), you can override the get_queryset method to return a different queryset depending on the request.

    You can read the brand sent by the client from self.request.query_params (https://www.django-rest-framework.org/api-guide/requests/#query_params).

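    So your method might look like this. Because brand is a ForeignKey, the filter has to target a field of the related Brand (its slug here) rather than the relation itself:

    def get_queryset(self):
        # brand__slug follows the foreign key to Brand.slug; filtering with
        # brand='apple' would compare the string against Brand's integer id.
        return Product.objects.filter(brand__slug=self.request.query_params['brand'])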
    

    Or similar :)

    EDIT based on comment:

    If the query param is optional, you could do this:

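    def get_queryset(self):
        queryset = Product.objects.all()
        # .get() returns None instead of raising when the parameter is absent
        brand = self.request.query_params.get('brand')
        if brand:
            # follow the foreign key to Brand.slug; brand itself holds an integer id
            queryset = queryset.filter(brand__slug=brand)
        return queryset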
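
To keep this working for any category without hardcoding brands, one option is to build the filter arguments from the category slug in the URL plus whatever optional query parameters were sent. The following is only a sketch in plain Python so it can be run on its own: build_product_filters and ALLOWED_FILTERS are hypothetical names, and the lookups assume the Category.slug and Brand.slug fields shown in the question.

```python
from typing import Mapping

# Hypothetical whitelist: public query parameter name -> ORM lookup on Product.
# "brand" follows the Product.brand foreign key to Brand.slug (assumption).
ALLOWED_FILTERS = {
    "brand": "brand__slug",
}


def build_product_filters(category_slug: str, query_params: Mapping[str, str]) -> dict:
    """Translate a category slug and optional query params into filter() kwargs."""
    filters = {"category__slug": category_slug}
    for param, lookup in ALLOWED_FILTERS.items():
        value = query_params.get(param)
        if value:  # skip parameters that are missing or empty
            filters[lookup] = value
    return filters


print(build_product_filters("smartphones", {"brand": "apple"}))
# {'category__slug': 'smartphones', 'brand__slug': 'apple'}
```

Inside a view the result could be passed along as Product.objects.filter(**build_product_filters(slug, self.request.query_params)). Note that category__slug only matches products attached directly to that category; if products of subcategories should be included as well, the MPTT tree would have to be taken into account, for example with get_descendants(include_self=True) on the category.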