I started using Axum a few weeks ago and, honestly, I'm a fan of the framework, so I'm writing this article to document what I have learned. In this article, we are going to build a small REST API using Axum as the web framework and Sqlx for SQL queries.
If you don't know what Axum is, here is what its page says:
Axum is a web application framework that focuses on ergonomics and modularity.
High-level features:
Route requests to handlers with a macro-free API.
Declaratively parse requests using extractors.
Simple and predictable error handling model.
Generate responses with minimal boilerplate.
Take full advantage of the tower and tower-http ecosystem of middleware, services, and utilities.
In particular the last point is what sets axum apart from other frameworks. axum doesn't have its own middleware system but instead uses tower::Service. This means axum gets timeouts, tracing, compression, authorization, and more, for free. It also enables you to share middleware with applications written using hyper or tonic.
Here is Axum's documentation.
About Sqlx:
SQLx is an async, pure Rust SQL crate featuring compile-time checked queries without a DSL.
Truly Asynchronous. Built from the ground up using async/await for maximum concurrency.
Compile-time checked queries (if you want). See SQLx is not an ORM.
Database Agnostic. Support for PostgreSQL, MySQL, SQLite, and MSSQL.
Pure Rust. The Postgres and MySQL/MariaDB drivers are written in pure Rust using zero unsafe code.
Runtime Agnostic. Works on different runtimes (async-std / tokio / actix) and TLS backends (native-tls, rustls).
The SQLite driver uses the libsqlite3 C library as SQLite is an embedded database (the only way we could be pure Rust for SQLite is by porting all of SQLite to Rust).
SQLx uses #![forbid(unsafe_code)] unless the SQLite feature is enabled. As the SQLite driver interacts with C, those interactions are unsafe.
Cross-platform. Being native Rust, SQLx will compile anywhere Rust is supported.
Built-in connection pooling with sqlx::Pool.
Row streaming. Data is read asynchronously from the database and decoded on-demand.
Automatic statement preparation and caching. When using the high-level query API (sqlx::query), statements are prepared and cached per connection.
Simple (unprepared) query execution including fetching results into the same Row types used by the high-level API. Supports batch execution and returning results from all statements.
Transport Layer Security (TLS) where supported (MySQL and PostgreSQL).
Asynchronous notifications using LISTEN and NOTIFY for PostgreSQL.
Nested transactions with support for save points.
Any database driver for changing the database driver at runtime. An AnyPool connects to the driver indicated by the URL scheme.
Here is the Sqlx documentation.
First, we generate our project folder.
cargo new axum_crud_api
Now we add the dependencies.
Cargo.toml
[package]
name = "axum_crud_api"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
axum = "0.5.9"
tokio = { version = "1.0", features = ["full"] }
serde = { version = "1.0.137", features = ["derive"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"]}
sqlx = { version = "0.5", features = ["runtime-tokio-native-tls", "json", "postgres"] }
anyhow = "1.0.58"
serde_json = "1.0.57"
tower-http = { version = "0.3.4", features = ["trace"] }
Let's start with an example. It is essentially the Hello World example from Axum's GitHub page with a few changes; here is the source code.
main.rs
use axum::{routing::get, Router};
use std::net::SocketAddr;

#[tokio::main]
async fn main() {
    // A single route: GET / is handled by `root`.
    let app = Router::new().route("/", get(root));

    let addr = SocketAddr::from(([127, 0, 0, 1], 8000));
    println!("listening on {}", addr);

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn root() -> &'static str {
    "Hello, World!"
}
Now we run the code
cargo run
It should print "listening on 127.0.0.1:8000" in our console, and if we copy that address and paste it into our browser we should see a page with the text "Hello, World!".
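The address in that message comes from the SocketAddr built in main. As a small std-only sketch (the helper name local_addr is mine, not part of the project), this is how the tuple of octets and port turns into the printed value:

```rust
use std::net::SocketAddr;

/// Builds a loopback address for the given port, the same way main.rs does.
fn local_addr(port: u16) -> SocketAddr {
    // A ([u8; 4], u16) tuple converts into an IPv4 socket address.
    SocketAddr::from(([127, 0, 0, 1], port))
}
```

Formatting local_addr(8000) with {} gives the address shown in the console, so changing the port in one place changes both the bind address and the log line.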
This is going to be our directory structure:
axum_crud_api/
---migrations/
---.env
---src/
------errors.rs
------main.rs
------controllers/
---------task.rs
------models/
---------task.rs
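One thing this layout leaves implicit: for the mod models; and mod controllers; declarations in main.rs to compile, each folder needs a module file, for example src/models/mod.rs containing pub mod task;. Those mod.rs files are my assumption, since they are not listed in the tree. A minimal single-file sketch of the same module shape, using inline modules and a simplified Task without derives:

```rust
// Inline equivalent of src/models/mod.rs plus src/models/task.rs.
// With separate files, `mod models;` in main.rs and `pub mod task;`
// in src/models/mod.rs produce the same paths.
mod models {
    pub mod task {
        // Simplified stand-in for the real Task model (no derives here).
        pub struct Task {
            pub id: i32,
            pub task: String,
        }
    }
}

// Other modules can now refer to the type by its full path.
fn describe(t: &models::task::Task) -> String {
    format!("#{}: {}", t.id, t.task)
}
```

The controllers folder would follow the same pattern, which is what makes paths like controllers::task::all_tasks resolve.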
Models
Let's create a folder named models to store our app's models and create a file named task.rs in it.
task.rs
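use serde::{Deserialize, Serialize};

#[derive(sqlx::FromRow, Deserialize, Serialize)]
pub struct Task {
    pub id: i32,
    pub task: String,
}

#[derive(sqlx::FromRow, Deserialize, Serialize)]
pub struct NewTask {
    pub task: String,
}

// The update_task controller refers to task::UpdateTask, which is not
// listed here originally; it is assumed to have the same shape as NewTask.
#[derive(sqlx::FromRow, Deserialize, Serialize)]
pub struct UpdateTask {
    pub task: String,
}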
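We create a file named .env in our project's root directory to store our database URL. There must be no spaces around the = sign, because main.rs splits the line on = and compares the key exactly:
DATABASE_URL=postgresql://user:password@localhost:port/database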
To create our database, we need to have sqlx-cli installed. Here are the installation instructions from its documentation:
# supports all databases supported by SQLx
$ cargo install sqlx-cli
# only for Postgres
$ cargo install sqlx-cli --no-default-features --features native-tls,postgres
# use vendored OpenSSL (build from source)
$ cargo install sqlx-cli --features openssl-vendored
# use Rustls rather than OpenSSL (be sure to add the features for the databases you intend to use!)
$ cargo install sqlx-cli --no-default-features --features rustls
After sqlx-cli is installed on our machine, we run the following command in our terminal to create our database:
sqlx database create
Then we run the following command in our terminal. It creates a new file in migrations/<timestamp>-<name>.sql, which is where we can add our schema:
sqlx migrate add task
In the generated migration file (migrations/<timestamp>-task.sql):
CREATE TABLE task (
    id SERIAL PRIMARY KEY,
    task varchar(255) NOT NULL
);
Then we run the following command in our terminal to run the migrations:
sqlx migrate run
If the migration was applied, it will show the following message in our terminal:
Applied <timestamp>task.sql
Now, we will change a few things in our main.rs file to connect our app to the database.
We are using Postgres, so we first need to import PgPoolOptions to handle the connection.
main.rs
use axum::{routing::get, Router};
use sqlx::postgres::PgPoolOptions;
use std::net::SocketAddr;
use std::fs;
use anyhow::Context;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Read the single KEY=value line from .env and trim stray whitespace
    // (such as the trailing newline) from both halves.
    let env = fs::read_to_string(".env").unwrap();
    let (key, database_url) = env.trim().split_once('=').unwrap();
    assert_eq!(key.trim(), "DATABASE_URL");
    let database_url = database_url.trim();

    tracing_subscriber::fmt::init();

    let pool = PgPoolOptions::new()
        .max_connections(50)
        .connect(database_url)
        .await
        .context("could not connect to database_url")?;

    let app = Router::new()
        .route("/hello", get(root));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("Listening on {}", addr);

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await?;

    Ok(())
}

// Same handler as in the Hello World example.
async fn root() -> &'static str {
    "Hello, World!"
}
We create a pool instance and set the maximum number of connections to 50. Then we pass our database URL to the connect method, which creates a new pool from PgPoolOptions and immediately opens at least one connection. For more details, here is the doc.
We use std::fs to read DATABASE_URL from our .env file and store its value in the database_url variable.
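Reading the file by hand is fragile if .env ever contains a comment, a blank line, or spaces around the = sign. A slightly more tolerant parser could look like the sketch below; database_url_from_env is a hypothetical helper of mine, not something the project or Sqlx provides.

```rust
/// Extracts the value of DATABASE_URL from the contents of a .env file.
/// Skips blank lines and `#` comments, and trims whitespace around the
/// key and the value.
fn database_url_from_env(contents: &str) -> Option<&str> {
    contents
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with('#'))
        // Split on the first `=` only, so a URL with query parameters survives.
        .filter_map(|line| line.split_once('='))
        .find(|(key, _)| key.trim() == "DATABASE_URL")
        .map(|(_, value)| value.trim())
}
```

In main it would be used on the string returned by fs::read_to_string(".env"), with the None case turned into an error instead of a panic.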
Controllers
We create a task.rs file inside the controllers folder.
task.rs
In this file, we are going to add the controllers that perform the CRUD operations.
To create our handlers, we need to import response::IntoResponse, http::StatusCode, Extension (to extract the shared state), and Json.
GET
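use axum::extract::Path;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::{Extension, Json};
use serde_json::{json, Value};
use sqlx::PgPool;

// Path, json, Value, and CustomError are used by the handlers added
// further down in this file.
use crate::errors::CustomError;
use crate::models::task;

pub async fn all_tasks(Extension(pool): Extension<PgPool>) -> impl IntoResponse {
    let sql = "SELECT * FROM task";
    // Note: unwrap() panics if the query fails.
    let tasks = sqlx::query_as::<_, task::Task>(sql)
        .fetch_all(&pool)
        .await
        .unwrap();

    (StatusCode::OK, Json(tasks))
}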
The all_tasks controller retrieves all the tasks in our database. It receives the PgPool through the Extension extractor and returns all the tasks in JSON format.
We use query_as to make a SQL query that is mapped to a concrete type using FromRow, in this case task::Task, and then call fetch_all, which executes the query and returns all the generated results collected into a Vec. More details here.
errors.rs
In this file, we are going to implement the IntoResponse trait for a custom error type, so that our controllers can return it as a response.
use axum::{http::StatusCode, response::IntoResponse, Json};
use serde_json::json;

pub enum CustomError {
    BadRequest,
    TaskNotFound,
    InternalServerError,
}

impl IntoResponse for CustomError {
    fn into_response(self) -> axum::response::Response {
        // Map each variant to an HTTP status and a human-readable message.
        let (status, error_message) = match self {
            Self::InternalServerError => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Internal Server Error",
            ),
            Self::BadRequest => (StatusCode::BAD_REQUEST, "Bad Request"),
            Self::TaskNotFound => (StatusCode::NOT_FOUND, "Task Not Found"),
        };
        (status, Json(json!({"error": error_message}))).into_response()
    }
}
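The controllers that follow convert every database failure by hand with map_err. One possible refinement, sketched here under assumptions, is a From conversion so that the ? operator picks the right variant. DbError is a hypothetical stand-in for sqlx::Error (which does have a RowNotFound variant), used so that the snippet compiles without any dependency; find_task and get_task are also made up for the illustration, and the status codes are written as plain numbers.

```rust
// Hypothetical stand-in for sqlx::Error, reduced to the two cases we care about.
#[derive(Debug)]
enum DbError {
    RowNotFound,
    Other(String),
}

#[derive(Debug, PartialEq)]
enum CustomError {
    BadRequest,
    TaskNotFound,
    InternalServerError,
}

impl CustomError {
    // Same status/message pairs that the IntoResponse impl produces.
    fn status_and_message(&self) -> (u16, &'static str) {
        match self {
            CustomError::BadRequest => (400, "Bad Request"),
            CustomError::TaskNotFound => (404, "Task Not Found"),
            CustomError::InternalServerError => (500, "Internal Server Error"),
        }
    }
}

// Only a missing row becomes a 404; anything else is a server-side failure.
impl From<DbError> for CustomError {
    fn from(err: DbError) -> Self {
        match err {
            DbError::RowNotFound => CustomError::TaskNotFound,
            DbError::Other(_) => CustomError::InternalServerError,
        }
    }
}

// Hypothetical lookup that fails the way a query could.
fn find_task(id: i32) -> Result<String, DbError> {
    match id {
        1 => Ok("write docs".to_string()),
        2 => Err(DbError::Other("connection reset".to_string())),
        _ => Err(DbError::RowNotFound),
    }
}

// `?` applies the From conversion automatically.
fn get_task(id: i32) -> Result<String, CustomError> {
    Ok(find_task(id)?)
}
```

With a conversion like this, a dropped connection would no longer be reported to the client as Task Not Found.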
GET by Id
pub async fn task(
    Path(id): Path<i32>,
    Extension(pool): Extension<PgPool>,
) -> Result<Json<task::Task>, CustomError> {
    let sql = "SELECT * FROM task where id=$1";
    // Any failure of the query is reported as Task Not Found.
    let task: task::Task = sqlx::query_as(sql)
        .bind(id)
        .fetch_one(&pool)
        .await
        .map_err(|_| CustomError::TaskNotFound)?;

    Ok(Json(task))
}
In the task controller we pass an id and use the Path extractor to extract it from the URL. We pass the id to the query, and if it is in the database the controller returns the task as JSON; if it is not, it returns a Task Not Found message.
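To make the Path extractor less magical, here is a rough std-only sketch of what extracting an i32 from /task/:id amounts to. The function name is hypothetical and this is not how Axum implements it internally; it only illustrates that a non-numeric segment never reaches the handler as a valid id.

```rust
/// Returns the numeric id from a path shaped like "/task/<id>", if any.
fn task_id_from_path(path: &str) -> Option<i32> {
    let segment = path.strip_prefix("/task/")?;
    // Reject an empty segment and nested paths such as "/task/1/extra".
    if segment.is_empty() || segment.contains('/') {
        return None;
    }
    segment.parse::<i32>().ok()
}
```

A request to /task/abc would therefore fail at the extraction step, before any SQL is executed.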
POST
pub async fn new_task(
    Extension(pool): Extension<PgPool>,
    // Json consumes the request body, so it is listed last
    // (axum 0.6 and later require this order).
    Json(task): Json<task::NewTask>,
) -> Result<(StatusCode, Json<task::NewTask>), CustomError> {
    if task.task.is_empty() {
        return Err(CustomError::BadRequest);
    }
    let sql = "INSERT INTO task (task) values ($1)";
    sqlx::query(sql)
        .bind(&task.task)
        .execute(&pool)
        .await
        .map_err(|_| CustomError::InternalServerError)?;

    Ok((StatusCode::CREATED, Json(task)))
}
The new_task controller has a Json extractor as a parameter. According to its doc, Json is an extractor that consumes the request body and deserializes it as JSON into some target type. In this code, the target type is NewTask.
We check whether the task field of the JSON is empty, and if it is, the function returns a Bad Request error message.
We use sqlx::query to make an SQL query: we pass it the statement stored in the sql variable and use the bind function to bind a value to it, in this case the task field of NewTask. If there is a problem with the query, the controller returns an Internal Server Error message.
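Note that is_empty() accepts a task made only of spaces. A slightly stricter check could look like the sketch below; validate_task and the 255 limit (taken from the varchar(255) column in the migration) are my own additions, not part of the original controller.

```rust
/// Returns the trimmed task text, or an error message describing the problem.
fn validate_task(text: &str) -> Result<&str, &'static str> {
    let trimmed = text.trim();
    if trimmed.is_empty() {
        return Err("task must not be empty");
    }
    // varchar(255) limits characters, so count chars rather than bytes.
    if trimmed.chars().count() > 255 {
        return Err("task must be at most 255 characters");
    }
    Ok(trimmed)
}
```

In new_task, an Err from this function would map to CustomError::BadRequest, which keeps over-long input from surfacing later as an Internal Server Error from the database.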
PUT
pub async fn update_task(
    Path(id): Path<i32>,
    Extension(pool): Extension<PgPool>,
    // Json consumes the request body, so it is listed last.
    Json(task): Json<task::UpdateTask>,
) -> Result<(StatusCode, Json<task::UpdateTask>), CustomError> {
    // Make sure the task exists before updating it.
    let _find: task::Task = sqlx::query_as("SELECT * FROM task where id=$1")
        .bind(id)
        .fetch_one(&pool)
        .await
        .map_err(|_| CustomError::TaskNotFound)?;

    sqlx::query("UPDATE task SET task=$1 WHERE id=$2")
        .bind(&task.task)
        .bind(id)
        .execute(&pool)
        .await
        // Surface a failed update instead of silently ignoring it.
        .map_err(|_| CustomError::InternalServerError)?;

    Ok((StatusCode::OK, Json(task)))
}
In update_task we pass an id through the path and extract it with the Path extractor, together with the JSON containing the fields we want to update, in this case only the task field. But first, we check that the id passed is in the database; if it's not, the controller returns a Task Not Found message.
Then we pass the SQL statement to the query function and pass the task and the id to the bind function. We are using Postgres, so the first bind call fills $1 and the second fills $2. The controller returns the updated JSON.
DELETE
pub async fn delete_task(
    Path(id): Path<i32>,
    Extension(pool): Extension<PgPool>,
) -> Result<(StatusCode, Json<Value>), CustomError> {
    // Make sure the task exists before deleting it.
    let _find: task::Task = sqlx::query_as("SELECT * FROM task where id=$1")
        .bind(id)
        .fetch_one(&pool)
        .await
        .map_err(|_| CustomError::TaskNotFound)?;

    sqlx::query("DELETE FROM task WHERE id=$1")
        .bind(id)
        .execute(&pool)
        .await
        // The row was found above, so a failure here is a server-side problem.
        .map_err(|_| CustomError::InternalServerError)?;

    Ok((StatusCode::OK, Json(json!({"msg": "Task Deleted"}))))
}
In delete_task we take the id of the task we want to delete from the path. After checking that the id is in the database, we pass the SQL statement to the query function and the id to the bind function, and return a message once the task is deleted.
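Both update_task and delete_task run a SELECT first just to learn whether the row exists. An alternative, sketched below, is to run the statement directly and look at how many rows it touched; in Sqlx the result of execute exposes this through rows_affected(). The helper name and the reduced error enum are hypothetical, and the count is passed in as a plain u64 so the snippet has no dependencies.

```rust
#[derive(Debug, PartialEq)]
enum CustomError {
    TaskNotFound,
}

/// Interprets the row count reported for an UPDATE or DELETE by primary key.
fn check_rows_affected(rows_affected: u64) -> Result<(), CustomError> {
    if rows_affected == 0 {
        // No row matched the id, so there was nothing to change.
        Err(CustomError::TaskNotFound)
    } else {
        Ok(())
    }
}
```

In a controller this would be called with the value of rows_affected() right after the execute call, which saves one round trip to the database per request.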
Now, let's update our main.rs to add the controllers.
main.rs
use axum::{
    extract::Extension,
    routing::{delete, get, post, put},
    Router,
};
use sqlx::postgres::PgPoolOptions;
use std::net::SocketAddr;
use std::fs;
use anyhow::Context;
use tower_http::trace::TraceLayer;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod errors;
mod models;
mod controllers;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Read the single KEY=value line from .env and trim stray whitespace.
    let env = fs::read_to_string(".env").unwrap();
    let (key, database_url) = env.trim().split_once('=').unwrap();
    assert_eq!(key.trim(), "DATABASE_URL");
    let database_url = database_url.trim();

    // Use RUST_LOG if it is set; otherwise log debug output for this crate
    // and for tower_http.
    tracing_subscriber::registry()
        .with(tracing_subscriber::EnvFilter::new(
            std::env::var("RUST_LOG")
                .unwrap_or_else(|_| "axum_crud_api=debug,tower_http=debug".into()),
        ))
        .with(tracing_subscriber::fmt::layer())
        .init();

    let pool = PgPoolOptions::new()
        .max_connections(50)
        .connect(database_url)
        .await
        .context("could not connect to database_url")?;

    let app = Router::new()
        .route("/hello", get(root))
        .route("/tasks", get(controllers::task::all_tasks))
        .route("/task", post(controllers::task::new_task))
        .route("/task/:id", get(controllers::task::task))
        .route("/task/:id", put(controllers::task::update_task))
        .route("/task/:id", delete(controllers::task::delete_task))
        // Share the pool with every handler and trace each request.
        .layer(Extension(pool))
        .layer(TraceLayer::new_for_http());

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    tracing::debug!("Listening on {}", addr);

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await?;

    Ok(())
}

async fn root() -> &'static str {
    "Hello, World!"
}
We add tower_http::trace::TraceLayer to get request logging. To do that, we pass TraceLayer::new_for_http() as an argument to the layer function of the Router instance.
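The filter given to EnvFilter::new is just a string of target=level directives, conventionally taken from the RUST_LOG environment variable. Here is a small std-only sketch of that fallback logic; the function name is hypothetical, and the default directive assumes the crate is named axum_crud_api as in Cargo.toml.

```rust
use std::env::VarError;

/// Picks the log filter: the value of RUST_LOG if present, otherwise a default
/// that enables debug output for this crate and for tower_http.
fn log_filter(rust_log: Result<String, VarError>) -> String {
    rust_log.unwrap_or_else(|_| "axum_crud_api=debug,tower_http=debug".to_string())
}
```

It would be called as log_filter(std::env::var("RUST_LOG")), so running the server with RUST_LOG=tower_http=trace switches to more verbose request logs without recompiling.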
Here is the complete source code.
Thank you for taking the time to read this article.
If you have any recommendations about other packages, architectures, how to improve my code, my English, or anything else, please leave a comment or contact me through Twitter or LinkedIn.