How to set up and tear down a database between tests in FastAPI?

I have set up my unit tests following the FastAPI documentation, but it only covers the case where the database persists across tests.

What if I want to build and tear down the database for each test? (For example, the second test below will fail, because the database is no longer empty after the first test.)

I am currently doing it by calling create_all and drop_all (commented out in the code below) at the beginning and end of each test, but this is obviously not ideal: if a test fails, the database is never torn down, which affects the result of the next test.

How can I do this properly? Should I create some kind of pytest fixture around the override_get_db dependency?

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base.metadata.create_all(bind=engine)

def override_get_db():
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()

app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)

def test_get_todos():
    # Base.metadata.create_all(bind=engine)

    # create
    response = client.post('/todos/', json={'text': 'some new todo'})
    data1 = response.json()
    response = client.post('/todos/', json={'text': 'some even newer todo'})
    data2 = response.json()

    assert data1['user_id'] == data2['user_id']

    response = client.get('/todos/')

    assert response.status_code == 200
    assert response.json() == [
        {'id': data1['id'], 'user_id': data1['user_id'], 'text': data1['text']},
        {'id': data2['id'], 'user_id': data2['user_id'], 'text': data2['text']}
    ]

    # Base.metadata.drop_all(bind=engine)

def test_get_empty_todos_list():
    # Base.metadata.create_all(bind=engine)

    response = client.get('/todos/')

    assert response.status_code == 200
    assert response.json() == []

    # Base.metadata.drop_all(bind=engine)
asked Apr 25 '21 16:04 by barciewicz



1 Answer

For setting up before tests and cleaning up after them, even when they fail, pytest provides fixtures (pytest.fixture).

In your case you want to create all tables before each test, and drop them again afterwards. This can be achieved with the following fixture:

@pytest.fixture()
def test_db():
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)

And then use it in your tests like so:

def test_get_empty_todos_list(test_db):
    response = client.get('/todos/')

    assert response.status_code == 200
    assert response.json() == []

For each test that has test_db in its argument list, pytest first runs Base.metadata.create_all(bind=engine), then yields to the test code, and afterwards makes sure that Base.metadata.drop_all(bind=engine) gets run, even when the test fails.
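To make the "even when the test fails" part concrete, here is a stdlib-only sketch of how a yield fixture's setup and teardown wrap a test. `run_test` is a hypothetical, heavily simplified stand-in for pytest's fixture machinery (pytest itself does much more), an in-memory `sqlite3` connection stands in for the shared test.db, and the raw CREATE TABLE / DROP TABLE statements play the role of create_all / drop_all.

```python
import sqlite3

# Shared connection, standing in for the file-based test.db used by every test.
conn = sqlite3.connect(":memory:", isolation_level=None)
events = []


def fresh_tables():
    """Yield fixture: code before `yield` is setup, code after it is teardown."""
    conn.execute("CREATE TABLE todos (id INTEGER PRIMARY KEY, text TEXT NOT NULL)")
    events.append("create")
    yield conn
    conn.execute("DROP TABLE todos")
    events.append("drop")


def run_test(fixture, test):
    """Hypothetical, simplified model of how pytest drives a yield fixture."""
    gen = fixture()
    resource = next(gen)  # run the setup code up to the `yield`
    try:
        test(resource)
        return "passed"
    except AssertionError:
        return "failed"
    finally:
        # Resume the generator so the teardown after `yield` runs,
        # whether the test passed or failed.
        next(gen, None)


def failing_check(db):
    db.execute("INSERT INTO todos (text) VALUES ('some new todo')")
    # 1 != 2, so this assertion fails on purpose.
    assert db.execute("SELECT COUNT(*) FROM todos").fetchone()[0] == 2


def empty_list_check(db):
    assert db.execute("SELECT COUNT(*) FROM todos").fetchone()[0] == 0


first = run_test(fresh_tables, failing_check)
# Without the teardown, this second CREATE TABLE would raise,
# because `todos` would still exist from the failed run.
second = run_test(fresh_tables, empty_list_check)
```

Here `first` is "failed", `second` is "passed", and `events` records create/drop around each run, so the failed test did not leak its row into the next one.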

The full code:

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from main import app, get_db
from database import Base


SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


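def override_get_db():
    # Create the session outside the try block, so `db` is always bound
    # when the finally clause runs.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()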


@pytest.fixture()
def test_db():
    Base.metadata.create_all(bind=engine)
    yield
    Base.metadata.drop_all(bind=engine)

app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_get_todos(test_db):
    response = client.post("/todos/", json={"text": "some new todo"})
    data1 = response.json()
    response = client.post("/todos/", json={"text": "some even newer todo"})
    data2 = response.json()

    assert data1["user_id"] == data2["user_id"]

    response = client.get("/todos/")

    assert response.status_code == 200
    assert response.json() == [
        {"id": data1["id"], "user_id": data1["user_id"], "text": data1["text"]},
        {"id": data2["id"], "user_id": data2["user_id"], "text": data2["text"]},
    ]


def test_get_empty_todos_list(test_db):
    response = client.get("/todos/")

    assert response.status_code == 200
    assert response.json() == []

As your application grows, setting up and tearing down the whole database for each test might get slow.

One solution is to set up the db only once and never actually commit anything to it. This can be achieved with nested transactions (SAVEPOINTs) and rollbacks:

import pytest
import sqlalchemy as sa
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker

from database import Base
from main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = sa.create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# These two event listeners are only needed for sqlite for proper
# SAVEPOINT / nested transaction support. Other databases like postgres
# don't need them.
# From: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
# They are registered before the first connection is made, so the
# "connect" listener also applies to any connection the pool reuses later.
@sa.event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
    # disable pysqlite's emitting of the BEGIN statement entirely.
    # also stops it from emitting COMMIT before any DDL.
    dbapi_connection.isolation_level = None


@sa.event.listens_for(engine, "begin")
def do_begin(conn):
    # emit our own BEGIN
    conn.exec_driver_sql("BEGIN")


# Set up the database once
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)

# This fixture is the main difference from before. It creates a nested
# transaction, recreates it whenever the application code calls session.commit,
# and rolls everything back at the end.
# Based on: https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
@pytest.fixture()
def session():
    connection = engine.connect()
    transaction = connection.begin()
    session = TestingSessionLocal(bind=connection)

    # Begin a nested transaction (using SAVEPOINT).
    nested = connection.begin_nested()

    # If the application code calls session.commit, it will end the nested
    # transaction. Need to start a new one when that happens.
    @sa.event.listens_for(session, "after_transaction_end")
    def end_savepoint(session, transaction):
        nonlocal nested
        if not nested.is_active:
            nested = connection.begin_nested()

    yield session

    # Rollback the overall transaction, restoring the state before the test ran.
    session.close()
    transaction.rollback()
    connection.close()


# A fixture for the fastapi test client which depends on the
# previous session fixture. Instead of creating a new session in the
# dependency override as before, it uses the one provided by the
# session fixture.
@pytest.fixture()
def client(session):
    def override_get_db():
        yield session

    app.dependency_overrides[get_db] = override_get_db
    yield TestClient(app)
    del app.dependency_overrides[get_db]


def test_get_empty_todos_list(client):
    response = client.get("/todos/")

    assert response.status_code == 200
    assert response.json() == []
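The SAVEPOINT handling in the session fixture above can be modelled with plain SQL. This stdlib-only sketch uses `sqlite3` with `isolation_level=None` (the same setting the do_connect listener applies), so the module issues no implicit BEGIN/COMMIT. `SavepointHarness` and its method names are hypothetical; `app_commit()` roughly stands for what session.commit() amounts to inside the fixture: the current SAVEPOINT is released and, as in the after_transaction_end listener, a new one is opened. For contrast, `commit_without_savepoint` shows why a plain outer transaction is not enough once application code commits.

```python
import sqlite3


def setup_once():
    """Create the schema a single time, like the module-level create_all."""
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute("CREATE TABLE todos (id INTEGER PRIMARY KEY, text TEXT NOT NULL)")
    return conn


def count(conn):
    return conn.execute("SELECT COUNT(*) FROM todos").fetchone()[0]


class SavepointHarness:
    """Hypothetical stand-in for the bookkeeping done by the session fixture."""

    def __init__(self, conn):
        self.conn = conn
        conn.execute("BEGIN")              # outer transaction, never committed
        conn.execute("SAVEPOINT test_sp")  # nested transaction

    def app_commit(self):
        # Ending the nested transaction only releases the SAVEPOINT;
        # the outer transaction is still open.
        self.conn.execute("RELEASE SAVEPOINT test_sp")
        # Mirror of the after_transaction_end listener: open a new SAVEPOINT.
        self.conn.execute("SAVEPOINT test_sp")

    def rollback(self):
        # Roll back the outer transaction, discarding even "committed" rows.
        self.conn.execute("ROLLBACK")


def commit_without_savepoint(conn):
    """Contrast: a real COMMIT ends the outer transaction, so the row persists."""
    conn.execute("BEGIN")
    conn.execute("INSERT INTO todos (text) VALUES ('leaked todo')")
    conn.execute("COMMIT")
    return conn.in_transaction  # False: nothing is left to roll back


conn = setup_once()
tx = SavepointHarness(conn)
conn.execute("INSERT INTO todos (text) VALUES ('some new todo')")
tx.app_commit()
conn.execute("INSERT INTO todos (text) VALUES ('some even newer todo')")
tx.app_commit()
during = count(conn)  # both rows are visible inside the test
tx.rollback()
after = count(conn)   # back to the empty state

still_open = commit_without_savepoint(conn)
leaked = count(conn)  # the committed row survived
```

Inside the test both "committed" rows are visible, and the final ROLLBACK still removes them, whereas a genuine COMMIT on the outer transaction leaves the row behind.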

Having two fixtures (session and client) here has an additional advantage:

If a test only talks to the API, you don't need to remember to add the session fixture explicitly (it is still invoked implicitly, because client depends on it). And if you want to write a test that talks to the db directly, you can do that as well:

def test_something(session):
    session.query(...)

Or both, for example to prepare the db state before an API call:

def test_something_else(client, session):
    session.add(...)
    session.commit()
    client.get(...)

Both the application code and the test code see the same state of the db, because they share one session.

answered Sep 28 '22 04:09 by mihi