Get started
FastAPI uses Python's asyncio module to improve its I/O performance.
According to the official documentation, path operation functions and dependencies (Depends) are always handled asynchronously, regardless of whether you declare them with async def
(to run as coroutines) or def
(to run in the thread pool).
When you use
async def
for your function, you MUST use the await
keyword to avoid "sequential" (blocking) behavior.
This behavior is slightly different from JavaScript's async-await, which could be the subject of another significant discussion.
The code
Suppose your application looks like this
# Create the async database engine from the configured connection string.
engine = create_async_engine(
    echo=True,
    url=get_db_settings().async_connection_string,
)
# Application-wide session factory bound to ``engine``; it produces
# ``AsyncSession`` objects.
async_session_global = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
)
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one database session to the caller.

    The session comes from ``async_session_global.begin()``. If an exception
    is raised at the ``yield`` point, the session is rolled back and the
    exception is re-raised. The session is closed in every case.
    """
    async with async_session_global.begin() as session:
        try:
            yield session
        except BaseException:
            # Catches everything, exactly like a bare ``except``: undo pending
            # work, then let the original exception propagate.
            await session.rollback()
            raise
        finally:
            await session.close()
# Define the FastAPI application and the router the endpoints below attach to.
app = FastAPI()
router = APIRouter()
@router.get('/api/async-examples/{id}')
async def get_example(id: int, db=Depends(get_async_session)):
    """Return all ``Example`` rows read through the injected session.

    Declared with ``async def`` because the body awaits the session. The
    result of ``db.execute`` is awaited *before* ``.scalars().all()`` is
    called on it; calling a result method on the un-awaited coroutine fails.

    Args:
        id: Path parameter taken from the URL.
        db: Session provided by the ``get_async_session`` dependency.

    Returns:
        A list of ``Example`` objects.
    """
    # NOTE(review): ``id`` is accepted from the path but is not used to
    # filter the query -- confirm whether a single row was intended.
    result = await db.execute(select(Example))
    return result.scalars().all()
@router.put('/api/async-examples/{id}')
async def put_example(id: int, db=Depends(get_async_session)):
    """Overwrite ``name`` and ``age`` of one ``Example`` row and return it.

    Declared with ``async def`` because the body awaits the session.

    Args:
        id: Primary key of the row to update, taken from the URL.
        db: Session provided by the ``get_async_session`` dependency.

    Returns:
        The ``Example`` row with the given ``id``, re-read after the commit.
        ``scalar_one()`` raises if no row, or more than one row, matches.
    """
    # ``where()`` takes SQL expressions, not keyword arguments.
    statement = (
        update(Example)
        .where(Example.id == id)
        .values(name='testtest', age=123)
    )
    await db.execute(statement)
    # NOTE(review): ``get_async_session`` opens the session via ``begin()``;
    # confirm that committing explicitly inside that block is intended.
    await db.commit()
    # The row is re-selected after the commit, so no ``refresh`` call is
    # needed (``refresh`` expects a loaded instance, not the mapped class).
    result = await db.execute(select(Example).filter_by(id=id))
    return result.scalar_one()
# Register the routes declared on ``router`` with the application.
app.include_router(router)
Firstly, we need fixtures for our tests. Here, I'll be using asyncpg as my async database connector.
# conftest.py
import asyncio
from typing import AsyncIterator

import pytest
from fastapi import FastAPI
from httpx import AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
# Engine for the test database (asyncpg driver).
engine = create_async_engine(
    echo=True,
    url='postgresql+asyncpg://...',
)
# Create all tables once per test session and drop them when it completes.
@pytest.fixture(scope='session')
async def async_db_engine():
    """Yield the module-level test ``engine`` with the schema created.

    Every table in ``SQLModel.metadata`` is created before the first test
    that uses this fixture and dropped again after the test session ends.
    """
    # Use the ``engine`` defined at the top of this module; the previous name
    # ``async_engine`` is not defined anywhere in this file.
    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    yield engine

    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)
# Roll back and truncate every table after each test to isolate tests.
@pytest.fixture(scope='function')
async def async_db(async_db_engine):
    """Yield a fresh ``AsyncSession`` for one test, then wipe all tables.

    Args:
        async_db_engine: Engine fixture the session factory is bound to.

    Yields:
        A session on which ``begin()`` has already been awaited.

    After the test, pending work is rolled back. Every table in
    ``SQLModel.metadata`` is then truncated, in reverse ``sorted_tables``
    order, and the truncation is committed.
    """
    async_session = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_db_engine,
        class_=AsyncSession,
    )
    async with async_session() as session:
        await session.begin()
        yield session
        await session.rollback()
        for table in reversed(SQLModel.metadata.sorted_tables):
            # Textual SQL has to be wrapped in ``text()``; a plain string is
            # not an executable statement for the session.
            await session.execute(text(f'TRUNCATE {table.name} CASCADE;'))
        await session.commit()
@pytest.fixture(scope='session')
async def async_client() -> AsyncIterator[AsyncClient]:
    """Yield an ``AsyncClient`` for the test session and close it afterwards.

    The client is created inside ``async with`` so it is closed when the test
    session finishes, not left open.
    """
    # NOTE(review): ``FastAPI()`` creates a new application here; routes
    # registered on the application module's ``app`` are not attached to it.
    # Confirm whether the real ``app`` object should be imported and passed.
    async with AsyncClient(app=FastAPI(), base_url='http://localhost') as client:
        yield client
# Provide a session-scoped event loop for the async tests and fixtures.
@pytest.fixture(scope='session')
def event_loop():
    """Create a new event loop, yield it, and close it on teardown."""
    session_loop = asyncio.get_event_loop_policy().new_event_loop()
    yield session_loop
    session_loop.close()
# Assumes an ``Example`` model is available.
@pytest.fixture
async def async_example_orm(async_db: AsyncSession) -> Example:
    """Insert one ``Example`` row via ``async_db`` and return it refreshed."""
    record = Example(name='test', age=18, nick_name='my_nick')
    async_db.add(record)
    await async_db.commit()
    # Reload the instance from the database after the commit.
    await async_db.refresh(record)
    return record
Then, write our tests
# test_what_ever_you_want.py
# Mark every test in this module with `asyncio`.
pytestmark = pytest.mark.asyncio
async def test_get_example(async_client: AsyncClient, async_db: AsyncSession,
                           async_example_orm: Example) -> None:
    """GET the example endpoint and check the seeded row is in the database."""
    url = f'/api/async-examples/{async_example_orm.id}'
    response = await async_client.get(url)
    assert response.status_code == status.HTTP_200_OK

    # Look the seeded row up directly through the test session.
    query = select(Example).filter_by(id=async_example_orm.id)
    stored = (await async_db.execute(query)).scalar_one()
    assert stored.id == async_example_orm.id
async def test_update_example(async_client: AsyncClient, async_db: AsyncSession,
                              async_example_orm: Example) -> None:
    """PUT new values to the example endpoint and compare with the database."""
    payload = {'name': 'updated_name', 'age': 20}
    url = f'/api/async-examples/{async_example_orm.id}'
    response = await async_client.put(url, json=payload)
    assert response.status_code == status.HTTP_200_OK

    # Reload the fixture instance so it reflects what the endpoint wrote.
    await async_db.refresh(async_example_orm)

    query = select(Example).filter_by(id=async_example_orm.id)
    stored = (await async_db.execute(query)).scalar_one()
    # NOTE(review): this expects the response body to carry the row under a
    # 'data' key -- confirm against what the PUT endpoint actually returns.
    assert stored.name == response.json()['data']['name']
The key here is the async_db
and event_loop
fixtures; you also have to make sure your program's db session does not use a global commit.
Top comments (10)
Nice idea! How do you override the
get_async_session
("get_db") dependency used by FastAPI endpoints to query data with the new one from the tests? I don't see you using app.dependency_overrides anywhere.
I didn't override my database connection; I use a test database for testing.
Ahh okay, I see. That is also a viable option. In my case, I want to use a local SQLite in-memory DB for testing. I achieved this using the following code:
oh, I know what you mean.
I separate my environments to achieve what you did by using the
pytest-env
package. And my db connection is wrapped in another module.
So I didn't do
app.dependency_overrides[get_db]
to connect to my test database.
And a second point: any action deep inside, after await session.begin(), may call .commit, which breaks this approach.
Yes, but it's hidden in the database(or coroutine). I prefer explicitly calling 'commit' to let developers know what is happening.
Let me suggest you another solution with a wrapped transaction.
I didn't know there was an
async_scoped_session
. It looks much better and less redundant. Thank you!
What is the reason for using truncate? Shouldn't the rollback prevent the data from being saved?
In my code base there are many implicit/explicit commits, so some data may not be rolled back correctly; that is why I added an extra truncate step.