pythondjangounit-testinggoogle-cloud-firestorefirebase-admin

Django Mock database not working properly


I was coding a unit test method for a post method, but for some reason, it keeps modifying the actual database instead of the mock database. This is my test method:

def test_post_contest(self):
    mock_collection = self.mock_db.collection.return_value
    mock_document = mock_collection.document.return_value

    mock_document.set.return_value = None

    response = self.client.post(
        reverse('contest-list'),
        data={
            'ONI2024F2': {
                'name': 'ONI2024F2', 
                'problems': ['problem1', 'problem2'],
                'administratorId': '0987654321'
            }
        },
        format='json'
    )
    self.assertEqual(response.status_code, 201)
    self.assertEqual(response.json(), {'id': 'ONI2024F2'})

This is the setUp method:

    def setUp(self):
        self.client = APIClient()
        self.mock_db = patch('firebase_config.db').start()
        self.addCleanup(patch.stopall)

And this is the firebase_config.py file

import os
import firebase_admin
from firebase_admin import credentials, firestore

#Path to Firebase credentials
cred_path = os.path.join(os.path.dirname(__file__), 'credentials','firebase_credentials.json')

#Initialize Firebase
cred = credentials.Certificate(cred_path)
firebase_admin.initialize_app(cred)

#Firestore client
db = firestore.client()

This is the whole testing file:

from django.test import TestCase
from django.urls import reverse
from unittest.mock import patch, MagicMock
from rest_framework.test import APIClient

class ContestViewTest(TestCase):
    
    def setUp(self):
        self.client = APIClient()
        self.mock_db = patch('firebase_config.db').start()
        self.addCleanup(patch.stopall)

    @patch('OIECApp.views.ContestView.referenceToJson')
    def test_get_contest_list(self, mock_referenceToJson):
        mock_referenceToJson.return_value = {'name': 'ONI2024F1'}

        mock_collection = self.mock_db.collection.return_value
        mock_doc = MagicMock()
        mock_doc.id = 'ONI2024F1'
        mock_collection.stream.return_value = [mock_doc]

        response = self.client.get(reverse('contest-list'))

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json(), {'IOI2023': {'name': 'ONI2024F1'}, 'ONI2024F1': {'name': 'ONI2024F1'}})

    def test_post_contest(self):
        mock_collection = self.mock_db.collection.return_value
        mock_document = mock_collection.document.return_value

        mock_document.set.return_value = None

        response = self.client.post(
            reverse('contest-list'),
            data={
                'ONI2024F2': {
                    'name': 'ONI2024F2', 
                    'problems': ['problem1', 'problem2'],
                    'administratorId': '0987654321'
                }
            },
            format='json'
        )
        self.assertEqual(response.status_code, 201)
        self.assertEqual(response.json(), {'id': 'ONI2024F2'})

    @patch('OIECApp.views.ContestView.referenceToJson')
    def test_get_contest_detail(self, mock_referenceToJson):
        mock_referenceToJson.return_value = {'name': 'ONI2024F1'}

        mock_collection = self.mock_db.collection.return_value
        mock_doc = MagicMock()
        mock_doc.id = 'ONI2024F1'
        mock_collection.stream.return_value = [mock_doc]

        response = self.client.get(reverse('contest-detail', args=['ONI2024F1']))

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.json(), {'ONI2024F1': {'name': 'ONI2024F1'}})

    def test_put_contest(self):
        mock_collection = self.mock_db.collection.return_value
        mock_document = mock_collection.document.return_value

        # Mock the update operation
        mock_document.update.return_value = None

        response = self.client.put(
            reverse('contest-detail', args=['ONI2025F1']),
            data={
                'ONI2025F1': {
                    'name': 'ONI2025F1',
                    'problems': ['problem1', 'problem3'],
                    'administratorId': '0987654321'
                }
            },
            format='json'
        )
        self.assertEqual(response.status_code, 201)
        self.assertEqual(response.json(), {'id': 'ONI2025F1'})

    def test_delete_contest(self):
        mock_collection = self.mock_db.collection.return_value
        mock_document = mock_collection.document.return_value

        # Mock the delete operation
        mock_document.delete.return_value = None

        response = self.client.delete(reverse('contest-detail', args=['ONI2024F2']))
        self.assertEqual(response.status_code, 204)

And this is the file with the http methods:

from django.http import Http404, JsonResponse
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView
from firebase_config import db
from firebase_admin import firestore

collection_name = "Contest"

def referenceToJson(element):
    doc_dict = element.to_dict()
    for x in doc_dict:
        tmp = doc_dict[x]
        if isinstance(tmp, list):
            counter = 0
            for y in tmp:
                if isinstance(y, firestore.firestore.DocumentReference):
                    doc_ref = y.get().id
                    tmp[counter]=doc_ref
                counter += 1
        else:
            if isinstance(tmp, firestore.firestore.DocumentReference):
                doc_ref = tmp.get().id
                doc_dict[x]=doc_ref
    return doc_dict

class ContestList(APIView):

    def get(self, request):
        contest_ref = db.collection(collection_name)
        contests = dict ()
        
        for doc in contest_ref.stream():
            idContest = doc.id
            doc_dict = referenceToJson(doc)
            contests[idContest]= doc_dict
        return JsonResponse(contests)
    
    def post(self, request):
        try: 
            datos = request.data
            clave = list(datos.keys())[0]
            valores = datos[clave]
            for c in valores.keys():
                if(c == 'Problems') or (c == 'problems'):
                    if isinstance(valores[c], list):
                        contador = 0
                        for element in valores[c]:
                            print(element)
                            ingreso = db.collection('Problems').document(valores[c][contador])
                            valores[c][contador]=ingreso
                            contador += 1
                    else:
                        ingreso = db.collection('Problems').document(valores[c])
                        valores[c] = ingreso
                elif(c == 'administratorId'):
                    ingreso = db.collection('Administrator').document(valores[c])
                    valores[c] = ingreso
            db.collection(collection_name).document(clave).set(valores)
            return JsonResponse({'id':clave}, status=201)
        except Exception as e:  
            return JsonResponse({'error':str(e)}, status=500)

class ContestDetail(APIView):

    def get(self, request, id):
        contest_ref = db.collection(collection_name)
        contest = dict ()
        
        for doc in contest_ref.stream():
            if(doc.id == id):
                idContest = doc.id
                doc_dict = referenceToJson(doc)
                contest[idContest] = doc_dict

        return JsonResponse(contest)

    def put(self,request,id):
        try: 
            datos = request.data
            clave = list(datos.keys())[0]
            valores = datos[clave]
            for c in valores.keys():
                if(c == 'Problems') or (c == 'problems'):
                    if isinstance(valores[c], list):
                        list_ref = []
                        for element in valores[c]:
                            ingreso = db.collection('Problems').document(element)
                            list_ref.append(ingreso)
                        valores[c] = list_ref
                    else:
                        ingreso = db.collection('Problems').document(valores[c])
                        valores[c] = ingreso
                elif(c == 'administratorId'):
                    ingreso = db.collection('Administrator').document(valores[c])
                    valores[c] = ingreso
            db.collection(collection_name).document(clave).update(valores)
            return JsonResponse({'id':clave}, status=201)
        except Exception as e:  
            return JsonResponse({'error':str(e)}, status=500)

    def delete(self, request, id):
        try:
            contest_ref = db.collection(collection_name).document(id)
            contest_ref.delete()
            return JsonResponse({'message': 'Deleted'}, status=204)
        except Exception as e:
            return JsonResponse({'error': str(e)}, status=500)

I thought that it could be the firebase-config file, so I tried changing it but then all the tests raised exceptions.


Solution

  • Instead of from firebase_config import db, use import firebase_config and refer to db using firebase_config.db.

    Once you from firebase_config import db, you need to patch my_module.db and not firebase_config.db.
    But, that isn't scalable; You don't want to patch every usage explicitly.