I have a Django command that runs a loop until the database becomes available:
import time
from django.db import connections
from django.db.utils import OperationalError
from django.core.management.base import BaseCommand
class Command(BaseCommand):
    """Django command to pause execution until database is available.

    Repeatedly looks up ``connections['default']``. Whenever the lookup
    raises ``OperationalError``, it logs a message and sleeps for
    ``retry_interval`` seconds before trying again.
    """

    # Seconds to wait between attempts. Kept in one place so the log message
    # always reports the interval that is actually slept (the original
    # message claimed "1 second" while sleeping 0.1 s).
    retry_interval = 0.1

    def handle(self, *args, **options):
        """Handle the command: loop until the default connection is obtained."""
        self.stdout.write('Waiting for database...')
        db_conn = None
        while not db_conn:
            try:
                # NOTE(review): Django appears to create the connection wrapper
                # lazily, so this lookup may succeed even when the server is
                # unreachable. Calling ``db_conn.ensure_connection()`` would be
                # a stricter check, but it would break tests that patch
                # ``ConnectionHandler.__getitem__`` to return ``True``.
                # Confirm which behaviour is wanted.
                db_conn = connections['default']
            except OperationalError:
                self.stdout.write(
                    'Database unavailable, waiting {} seconds...'.format(
                        self.retry_interval
                    )
                )
                time.sleep(self.retry_interval)
        self.stdout.write(self.style.SUCCESS('Database available!'))
I want to create unit tests for this code.
I've managed to test the case where the database is available from the start as follows:
def test_wait_for_db_ready(self):
    """Test waiting for db when db is available.

    The patched connection lookup succeeds on the first call, so the command
    should perform exactly one lookup and never retry.
    """
    with patch('django.db.utils.ConnectionHandler.__getitem__') as gi:
        gi.return_value = True
        call_command('wait_for_db')
        # Assert on the mock; the original assertTrue(True) could never fail.
        self.assertEqual(gi.call_count, 1)
Is there a way to test that the command waits for the DB to be available before returning?
So far I've tried the following; however, it doesn't work because `attempt` cannot be modified from inside `getitem()`: assigning to it there makes it a new local variable of the nested function.
def test_wait_for_db(self):
    """Test waiting for db (the question's non-working attempt).

    NOTE: this version fails. Because ``getitem`` assigns to ``attempt``
    (``attempt += 1``), Python treats ``attempt`` as a local variable of
    ``getitem``. Reading it in ``attempt < 5`` before any local assignment
    therefore raises ``UnboundLocalError``. That is not an
    ``OperationalError``, so the command does not catch it and the test
    errors out. Adding ``nonlocal attempt`` inside ``getitem`` fixes this.
    """
    attempt = 0
    def getitem(alias):
        # ``attempt`` here is local to getitem (see docstring), so this read
        # raises UnboundLocalError on the first call.
        if attempt < 5:
            attempt += 1
            raise OperationalError()
        else:
            return True
    with patch('django.db.utils.ConnectionHandler.__getitem__') as gi:
        gi.side_effect = getitem
        call_command('wait_for_db')
        self.assertGreaterEqual(attempt, 5)
There are a few ways to achieve this. The simplest approach may be to abandon the getitem() nested function and set up the side effect using a sequence of OperationalErrors. You could then verify the number of attempts with the patched gi object's call_count. For example:
def test_wait_for_db(self):
    """Test that the command retries until the database is available.

    The patched connection lookup raises ``OperationalError`` five times and
    then returns ``True``, so the command should look it up exactly six times.
    """
    # Also patch time.sleep so the retries do not actually pause the test.
    with patch('django.db.utils.ConnectionHandler.__getitem__') as gi, \
            patch('time.sleep'):
        gi.side_effect = [OperationalError] * 5 + [True]
        call_command('wait_for_db')
        # Five failed lookups plus the final successful one.
        self.assertEqual(gi.call_count, 6)
If you wish to keep the getitem() function, then I think you just need to declare the attempt variable nonlocal so that it can be rebound inside the nested function:
def test_wait_for_db(self):
    """Test that the command retries until the database is available.

    ``getitem`` raises ``OperationalError`` on its first five calls and returns
    ``True`` afterwards. ``attempt`` counts the failed calls.
    """
    attempt = 0

    def getitem(alias):
        # Rebind the enclosing ``attempt`` instead of creating a new local.
        nonlocal attempt
        if attempt < 5:
            attempt += 1
            raise OperationalError()
        return True

    # Also patch time.sleep so the retries do not actually pause the test.
    with patch('django.db.utils.ConnectionHandler.__getitem__') as gi, \
            patch('time.sleep'):
        gi.side_effect = getitem
        call_command('wait_for_db')
        # Exactly five failures occurred before the lookup succeeded.
        self.assertEqual(attempt, 5)
Thirdly, and as suggested in the comments, you could create a class which has an attempt attribute, and use an instance of the class as the side effect:
def test_wait_for_db(self):
    """Test that the command retries until the database is available.

    A callable ``Getitem`` instance serves as the mock's side effect. Its
    ``attempt`` attribute records how many lookups failed before one succeeded.
    """
    class Getitem:
        """Fail the first five lookups with OperationalError, then succeed."""

        def __init__(self):
            # Number of lookups that have raised OperationalError so far.
            self.attempt = 0

        def __call__(self, item):
            if self.attempt < 5:
                self.attempt += 1
                raise OperationalError()
            return True

    # Also patch time.sleep so the retries do not actually pause the test.
    with patch('django.db.utils.ConnectionHandler.__getitem__') as gi, \
            patch('time.sleep'):
        getitem = Getitem()
        gi.side_effect = getitem
        call_command('wait_for_db')
        # Access the attempts from the instance: exactly five failures.
        self.assertEqual(getitem.attempt, 5)
You can achieve the same more efficiently with the code below, which patches `time.sleep` so the test does not actually wait between retries.
from unittest.mock import patch
from django.core.management import call_command
from django.db.utils import OperationalError
#gives error when db isn't available
from django.test import TestCase
class CommandTests(TestCase):
    """Tests for the ``wait_for_db`` management command."""

    def test_wait_for_db_ready(self):
        """Test waiting for the db when db is available"""
        with patch('django.db.utils.ConnectionHandler.__getitem__') as gi:
            gi.return_value = True
            call_command('wait_for_db')
            # One successful lookup and no retries.
            self.assertEqual(gi.call_count, 1)

    # Replace time.sleep so the retries below do not really pause the test.
    # The real time.sleep returns None, so the mock does the same.
    @patch('time.sleep', return_value=None)
    def test_wait_for_db(self, ts):
        """Test waiting for db"""
        with patch('django.db.utils.ConnectionHandler.__getitem__') as gi:
            gi.side_effect = [OperationalError] * 5 + [True]
            call_command('wait_for_db')
            # Five failed lookups followed by one successful lookup.
            self.assertEqual(gi.call_count, 6)
            # The command pauses once after each failed attempt, which is
            # the "waiting" behaviour the test is meant to verify.
            self.assertEqual(ts.call_count, 5)
If you love us, you can donate to us via PayPal or buy me a coffee so we can maintain and grow. Thank you!
Donate Us With