I have set up my unit tests as per the FastAPI documentation, but it only covers the case where the database is persisted between tests.
What if I want to build and tear down the database for each test? (For example, the second test below will fail, because the database will no longer be empty after the first test.)
I am currently doing it by calling `create_all` and `drop_all` (commented out in the code below) at the beginning and end of each test, but this is obviously not ideal (if a test fails, the database will never be torn down, impacting the result of the next test).
How can I do it properly? Should I create some kind of pytest fixture around the `override_get_db` dependency?
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from main import app, get_db
from database import Base
# SQLite database stored in the file ./test.db.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
# connect_args is handed to the database driver; check_same_thread=False
# presumably lets a connection be used from a thread other than the one that
# created it (confirm against the sqlite3 documentation).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
# Session factory bound to the test engine, with autocommit and autoflush off.
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Schema creation at import time is currently disabled:
# Base.metadata.create_all(bind=engine)
def override_get_db():
    """Yield a session from ``TestingSessionLocal`` and always close it.

    Generator dependency registered in place of the application's ``get_db``
    so that requests made during tests use the test engine.

    Yields:
        The session created by ``TestingSessionLocal()``.
    """
    # Create the session before entering ``try``: if construction fails there
    # is nothing to close, and touching ``db`` in ``finally`` would raise
    # UnboundLocalError and hide the original error.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
# Make the app resolve its get_db dependency through override_get_db.
app.dependency_overrides[get_db] = override_get_db
# Module-level test client used by the test functions below.
client = TestClient(app)
def test_get_todos():
    # Posts two todos through the API, then checks that GET /todos/ answers
    # 200 with exactly those two items, in the order they were created.
    # Base.metadata.create_all(bind=engine)
    # create
    response = client.post('/todos/', json={'text': 'some new todo'})
    data1 = response.json()
    response = client.post('/todos/', json={'text': 'some even newer todo'})
    data2 = response.json()
    # Both todos are expected to carry the same user_id.
    assert data1['user_id'] == data2['user_id']
    response = client.get('/todos/')
    assert response.status_code == 200
    assert response.json() == [
        {'id': data1['id'], 'user_id': data1['user_id'], 'text': data1['text']},
        {'id': data2['id'], 'user_id': data2['user_id'], 'text': data2['text']}
    ]
    # Base.metadata.drop_all(bind=engine)
def test_get_empty_todos_list():
    # Expects GET /todos/ to answer 200 with an empty list, which only holds
    # if no todos are stored when this test runs.
    # Base.metadata.create_all(bind=engine)
    response = client.get('/todos/')
    assert response.status_code == 200
    assert response.json() == []
    # Base.metadata.drop_all(bind=engine)
Mocking and stubbing are the cornerstones of having quick and simple unit tests. Mocks are useful if you have a dependency on an external system, file reading takes too long, the database connection is unreliable, or if you don't want to send an email after every test.
ORMs. FastAPI works with any database and any style of library to talk to the database. A common pattern is to use an "ORM": an "object-relational mapping" library. An ORM has tools to convert ("map") between objects in code and database tables ("relations").
For cleaning up after tests even when they fail (and setting up before tests), pytest provides `pytest.fixture`.
In your case you want to create all tables before each test, and drop them again afterwards. This can be achieved with the following fixture:
@pytest.fixture()
def test_db():
    """Create all tables before the requesting test and drop them afterwards."""
    schema = Base.metadata
    schema.create_all(bind=engine)
    # Hand control to the test; the line after ``yield`` is the teardown.
    yield
    schema.drop_all(bind=engine)
And then use it in your tests like so:
def test_get_empty_todos_list(test_db):
    """With freshly created tables, GET /todos/ returns 200 and an empty list."""
    listing = client.get('/todos/')
    assert listing.status_code == 200
    assert listing.json() == []
For each test that has `test_db` in its argument list, pytest first runs `Base.metadata.create_all(bind=engine)`, then yields to the test code, and afterwards makes sure that `Base.metadata.drop_all(bind=engine)` gets run, even when the test fails.
The full code:
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from main import app, get_db
from database import Base
# SQLite database stored in the file ./test.db.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
# connect_args is handed to the database driver; check_same_thread=False
# presumably lets a connection be used from a thread other than the one that
# created it (confirm against the sqlite3 documentation).
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
# Session factory bound to the test engine, with autocommit and autoflush off.
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def override_get_db():
    """Yield a session from ``TestingSessionLocal`` and always close it.

    Generator dependency registered in place of the application's ``get_db``
    so that requests made during tests use the test engine.

    Yields:
        The session created by ``TestingSessionLocal()``.
    """
    # Create the session before entering ``try``: if construction fails there
    # is nothing to close, and touching ``db`` in ``finally`` would raise
    # UnboundLocalError and hide the original error.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
@pytest.fixture()
def test_db():
    """Build the schema for one test and tear it down once the test is done."""
    tables = Base.metadata
    tables.create_all(bind=engine)
    # Everything after ``yield`` runs as teardown for the requesting test.
    yield
    tables.drop_all(bind=engine)
# Make the app resolve its get_db dependency through override_get_db.
app.dependency_overrides[get_db] = override_get_db
# Module-level test client used by the test functions below.
client = TestClient(app)
def test_get_todos(test_db):
    """Create two todos via the API and check the listing returns both in order."""
    first = client.post("/todos/", json={"text": "some new todo"}).json()
    second = client.post("/todos/", json={"text": "some even newer todo"}).json()
    assert first["user_id"] == second["user_id"]

    listing = client.get("/todos/")
    assert listing.status_code == 200
    # Compare only the id / user_id / text fields of each created todo.
    assert listing.json() == [
        {field: todo[field] for field in ("id", "user_id", "text")}
        for todo in (first, second)
    ]
def test_get_empty_todos_list(test_db):
    """With freshly created tables, GET /todos/ returns 200 and an empty list."""
    listing = client.get("/todos/")
    assert listing.status_code == 200
    assert listing.json() == []
As your application grows, setting up and tearing down the whole database for each test might get slow.
A solution for that is to only set up the db once and then never actually commit anything to it. This can be achieved using nested transactions and rollbacks:
import pytest
import sqlalchemy as sa
from fastapi.testclient import TestClient
from sqlalchemy.orm import sessionmaker
from database import Base
from main import app, get_db
# SQLite database stored in the file ./test.db.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"

engine = sa.create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


# These two event listeners are only needed for sqlite for proper
# SAVEPOINT / nested transaction support. Other databases like postgres
# don't need them.
# From: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
#
# They are registered before anything connects to the engine: a listener only
# sees events fired after it is attached, so a connection opened earlier
# (for example by the initial DDL below) would miss the "connect" hook.
@sa.event.listens_for(engine, "connect")
def do_connect(dbapi_connection, connection_record):
    """Turn off the driver's own transaction handling on each new connection."""
    # disable pysqlite's emitting of the BEGIN statement entirely.
    # also stops it from emitting COMMIT before any DDL.
    dbapi_connection.isolation_level = None


@sa.event.listens_for(engine, "begin")
def do_begin(conn):
    """Emit an explicit BEGIN whenever SQLAlchemy starts a transaction."""
    # emit our own BEGIN
    conn.exec_driver_sql("BEGIN")


# Set up the database once, now that every connection is configured the same way.
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
# This fixture is the main difference to before. It wraps each test in an
# outer transaction plus a SAVEPOINT, re-opens the SAVEPOINT whenever the
# application code calls session.commit, and rolls everything back at the end.
# Based on: https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#joining-a-session-into-an-external-transaction-such-as-for-test-suites
@pytest.fixture()
def session():
    """Yield a session whose work is rolled back once the test has finished."""
    conn = engine.connect()
    outer_txn = conn.begin()
    db_session = TestingSessionLocal(bind=conn)

    # Start a nested transaction (SAVEPOINT) inside the outer one.
    savepoint = conn.begin_nested()

    # A commit issued by the application code ends the current SAVEPOINT,
    # so a fresh one is opened whenever that has happened.
    @sa.event.listens_for(db_session, "after_transaction_end")
    def restart_savepoint(_session, _ended_transaction):
        nonlocal savepoint
        if savepoint.is_active:
            return
        savepoint = conn.begin_nested()

    yield db_session

    # Undo the outer transaction, restoring the state from before the test.
    db_session.close()
    outer_txn.rollback()
    conn.close()
# Test-client fixture built on top of the session fixture. Rather than
# opening a new session inside the dependency override, the override hands
# out the session that the session fixture provides.
@pytest.fixture()
def client(session):
    """Yield a TestClient whose ``get_db`` dependency returns ``session``."""

    def _provide_fixture_session():
        yield session

    app.dependency_overrides[get_db] = _provide_fixture_session
    api_client = TestClient(app)
    yield api_client
    # Remove the override again once the test is done.
    app.dependency_overrides.pop(get_db)
def test_get_empty_todos_list(client):
    """GET /todos/ through the client fixture returns 200 and an empty list."""
    listing = client.get("/todos/")
    assert listing.status_code == 200
    assert listing.json() == []
Having two fixtures (`session` and `client`) here has an additional advantage:
If a test only talks to the API, then you don't need to remember to add the db fixture explicitly (but it will still be invoked implicitly). And if you want to write a test that talks directly to the db, you can do that as well:
def test_something(session):
    # Placeholder example: a test that requests only the session fixture and
    # queries the database directly, without going through the API client.
    session.query(...)
Or both, if you for example want to prepare the db state before an API call:
def test_something_else(client, session):
    # Placeholder example: prepare database state through the session
    # fixture, then exercise the API through the client fixture.
    session.add(...)
    session.commit()
    client.get(...)
Both the application code and test code will see the same state of the db.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With