My tests take a long time to run, so I am trying to roll back transactions between tests instead of dropping and re-creating the tables between tests.
The issue is that some of my tests commit multiple times.
EDIT: How do I roll back transactions between tests so that the tests run faster?
Here is the Base class used for testing.
import unittest
from app import create_app
from app.core import db
from test_client import TestClient, TestResponse
class TestBase(unittest.TestCase):
    """Base test case that builds a fresh app and database schema per test.

    Every test gets a new application (``create_app('testing')``), its own
    pushed app context, and freshly created tables, which are dropped again
    afterwards. Simple, but slow: the schema is rebuilt for every test.
    """

    def setUp(self):
        self.app = create_app('testing')
        self.app_context = self.app.app_context()
        self.app_context.push()
        # Register the pop immediately: unittest skips tearDown when setUp
        # raises but always runs cleanups, so the context can no longer leak
        # onto the stack if create_all() (or tearDown itself) fails.
        self.addCleanup(self.app_context.pop)
        self.app.response_class = TestResponse
        self.app.test_client_class = TestClient
        db.create_all()

    def tearDown(self):
        # Runs before the cleanup registered in setUp, so the app context is
        # still active while the session is removed and the schema dropped.
        db.session.remove()
        db.drop_all()
        db.get_engine(self.app).dispose()
Here is my attempt at rolling back transactions.
class TestBase(unittest.TestCase):
    """Base class for tests that hit the database, isolated by rollback.

    The app, its context and the schema are set up once per test class.
    Each test then runs on a dedicated connection inside an outer
    transaction, with ``db.session`` rebound to that connection. Everything
    the test writes, even data it commits several times, is discarded when
    the outer transaction is rolled back after the test.

    NOTE(review): written against the Flask-SQLAlchemy 2.x API
    (``create_scoped_session``, ``get_engine(app)``) and SQLAlchemy < 2.0,
    where a Session bound to a Connection that already has a transaction in
    progress does not commit that outer transaction. Confirm the versions.
    A test that calls ``db.session.rollback()`` may end the outer
    transaction early. Add a SAVEPOINT plus an ``after_transaction_end``
    listener if tests need that.
    """

    @classmethod
    def setUpClass(cls):
        cls.app = create_app('testing')
        cls.app_context = cls.app.app_context()
        cls.app_context.push()
        cls.app.response_class = TestResponse
        cls.app.test_client_class = TestClient
        db.create_all()

    @classmethod
    def tearDownClass(cls):
        # Must be named tearDownClass: a classmethod called ``tearDown`` is
        # replaced by the instance ``tearDown`` and never runs.
        db.session.remove()
        db.drop_all()
        db.get_engine(cls.app).dispose()
        cls.app_context.pop()

    def setUp(self):
        # Cleanups run in reverse order of registration, and they also run
        # when setUp fails part-way, so every resource opened here is
        # released exactly once.
        self.test_app_context = self.app.app_context()
        self.test_app_context.push()
        self.addCleanup(self.test_app_context.pop)

        # One connection per test, with an outer transaction that is always
        # rolled back, discarding everything the test wrote.
        self.connection = db.engine.connect()
        self.addCleanup(self.connection.close)
        self.transaction = self.connection.begin()
        self.addCleanup(self.transaction.rollback)

        # Rebind db.session to that connection, so commits made by the test
        # (or by the code under test) stay inside the outer transaction.
        # Restore the application's own session afterwards.
        original_session = db.session
        db.session = db.create_scoped_session(options={'bind': self.connection})
        self.addCleanup(setattr, db, 'session', original_session)
        self.addCleanup(db.session.remove)
This is the code we use to do this. Make sure that __start_transaction is called in your setUp and __close_transaction in your tearDown (with an app context pushed if you're using Flask-SQLAlchemy). As a further hint, only inherit this code in test cases that actually hit the database, and separate the tests that check your database functions from the tests that check your business logic, because the latter will still run WAY faster.
def __start_transaction(self):
    """Open an outer transaction and route ``db.session`` through it.

    Call from ``setUp`` (with an app context pushed when using
    Flask-SQLAlchemy) and pair with ``__close_transaction`` in ``tearDown``.
    NOTE(review): relies on ``event`` (presumably ``from sqlalchemy import
    event``) being imported at module level. Confirm.
    """
    # Create a connection outside of the ORM whose transaction we can roll back.
    self.connection = db.engine.connect()
    self.trans = self.connection.begin()
    # Remember the application's session so __close_transaction can restore
    # it. Otherwise db.session stays bound to this test's dead connection.
    self.__original_session = db.session
    # Bind db.session to that connection and start a nested transaction (SAVEPOINT).
    db.session = db.create_scoped_session(options={'bind': self.connection})
    db.session.begin_nested()

    # Whenever the savepoint ends, via commit() or rollback() in the test,
    # start a new one. The test can then commit or roll back repeatedly
    # without ending the outer transaction.
    @event.listens_for(db.session, "after_transaction_end")
    def restart_savepoint(session, transaction):
        if transaction.nested and not transaction._parent.nested:
            session.begin_nested()

    self.__after_transaction_end_listener = restart_savepoint


def __close_transaction(self):
    """Discard everything done since ``__start_transaction`` and clean up.

    The outer transaction is rolled back and the connection invalidated
    even if an earlier cleanup step raises.
    """
    try:
        # Remove the savepoint-restarting listener.
        event.remove(db.session, "after_transaction_end", self.__after_transaction_end_listener)
        # Close the test session and drop it from the scoped registry.
        db.session.remove()
        # The app was holding the db connection even after the session was closed.
        # This caused the db to run out of connections before the tests finished.
        # Disposing of the engine from each created app handles this.
        db.get_engine(self.app).dispose()
    finally:
        # Put the application's own session back.
        db.session = self.__original_session
        try:
            # Roll back the outer transaction, discarding all test writes.
            self.trans.rollback()
        finally:
            # Discard the DBAPI connection rather than returning it to the pool.
            self.connection.invalidate()
You could use Session.begin_nested. As long as all your tests properly call commit to close out their sub-transactions, I think you can simply do:
# Run one test inside a SAVEPOINT and discard its work afterwards.
# NOTE(review): ``session`` and ``run_test`` are placeholders for your
# Session and your test callable.
session.begin_nested()
try:
    run_test(session)
finally:
    # Roll back even when the test raises, so the next test does not start
    # inside this test's leftover transaction.
    # Caveat: commit() ends the innermost (nested) transaction first. A test
    # that commits more than once can therefore commit the enclosing
    # transaction and escape this rollback.
    session.rollback()
In my view, this should be faster, although it probably depends on your database to some extent.
If you're using pytest, you can create the following fixtures:
@pytest.fixture(scope="session")
def app():
    """Create the application once per pytest session.

    Its application context stays pushed until every test has finished,
    and is then popped again.
    """
    flask_app = create_app('config.TestingConfig')
    log.info('Initializing Application context.')
    app_ctx = flask_app.app_context()
    app_ctx.push()
    yield flask_app
    # Everything after the yield is this fixture's teardown.
    log.info('Destroying Application context.')
    app_ctx.pop()
@pytest.fixture(scope='session')
def db(app):
    """Recreate the schema and seed data once per pytest session.

    Requests ``app`` so that the application context it pushes is active
    before ``_db`` touches the database (presumably required by
    Flask-SQLAlchemy's app lookup; confirm). Without this, it only held
    when a test happened to request ``app`` before ``db``.
    Yields the ``_db`` object.
    """
    log.info('Initializing the database')
    _db.drop_all()
    _db.create_all()
    session = _db.session
    seed_data_if_not_exists(session)
    # Commit the seed data so it is visible to every test.
    session.commit()
    yield _db
    log.info('Destroying the database')
    session.rollback()
    # Uncomment to drop the schema at the end of the run, if necessary:
    # _db.drop_all()
@pytest.fixture(scope='function')
def session(app, db):
    """Yield a database session whose changes are discarded after the test.

    The session is bound to a dedicated connection on which an outer
    transaction is opened first. ``db.session`` is rebound to it, so code
    under test shares the same session. Commits made by the test stay
    inside that outer transaction, so rolling it back afterwards discards
    them even when the test commits more than once. A bare
    begin_nested()/rollback() pair does not guarantee this.
    NOTE(review): relies on Flask-SQLAlchemy 2.x ``create_scoped_session``
    and pre-2.0 SQLAlchemy session/connection semantics. Confirm versions.
    """
    log.info("Creating database session")
    connection = db.engine.connect()
    outer_transaction = connection.begin()
    original_session = db.session
    session = db.create_scoped_session(options={'bind': connection})
    db.session = session
    # The SAVEPOINT lets a single rollback() inside the test undo only the
    # test's own work while leaving the outer transaction open.
    session.begin_nested()
    try:
        yield session
    finally:
        log.info("Rolling back database session")
        session.remove()
        db.session = original_session
        outer_transaction.rollback()
        connection.close()
If you like what we do, you can donate to us via PayPal or buy us a coffee so we can maintain and grow! Thank you!
Donate Us With