Simon Bittok

Implementing JWT Authentication in Rust with Axum.

This is Part 2 of the "Implementing JWT Authentication in Rust" series.

Previous: Part 1: Project Setup and Configuration

Quick Recap

In Part 1, we set up our Rust project with Axum, configured environment-based settings, and created a basic error handling system. Now we'll add logging capabilities.

Introduction

Picking up from where we left off in Part 1, we will focus on setting up the infrastructure layer: logging with the tracing ecosystem.

Step 1: Install dependencies.

We will use the tracing ecosystem because it is built for structured, span-aware diagnostics in asynchronous code.

cargo add tracing-subscriber -F "env-filter, serde, tracing, json"
cargo add tracing -F log
cargo add tracing-error
cargo add tower-http -F "cors, trace"

Dependencies Overview

  • tower-http: HTTP-specific middleware for Tower services, including TraceLayer (trace feature) and CORS (cors feature).
  • tracing: Structured, event-based diagnostics built around spans and events.
  • tracing-subscriber: Utilities for building and composing subscribers, including EnvFilter and the fmt layer used below.
  • tracing-error: Utilities for enriching errors with tracing diagnostics, such as the span traces captured by ErrorLayer.

Setting up Logging

Add the following to our development.yaml file:

log:
  level: debug # off, trace, debug, info, warn, error
  format: pretty # compact, full, json, pretty
  crates: # tracing targets use underscores, not hyphens
    - auth # name of the project
    - axum
    - tower_http
    - tower

Create a config/log.rs file and add the following contents.

use std::{
    env::VarError,
    error::Error as _,
    fmt::{self, Display},
    io::IsTerminal,
    str::FromStr,
};

use serde::{Deserialize, Serialize};
use tracing::Subscriber;
use tracing_error::ErrorLayer;
use tracing_subscriber::{
    EnvFilter, Layer, filter::Directive, fmt::Layer as FmtLayer, layer::SubscriberExt,
    registry::LookupSpan, util::SubscriberInitExt,
};

use crate::{Error, Result};

#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub enum Level {
    #[serde(rename = "off")]
    Off,
    #[serde(rename = "trace")]
    Trace,
    #[serde(rename = "debug")]
    Debug,
    #[serde(rename = "info")]
    #[default]
    Info,
    #[serde(rename = "warn")]
    Warn,
    #[serde(rename = "error")]
    Error,
}

impl Display for Level {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}",
            match self {
                Self::Off => "off",
                Self::Trace => "trace",
                Self::Debug => "debug",
                Self::Info => "info",
                Self::Warn => "warn",
                Self::Error => "error",
            }
        )
    }
}

#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub enum Format {
    #[serde(rename = "compact")]
    Compact,
    #[serde(rename = "full")]
    Full,
    #[serde(rename = "json")]
    Json,
    #[serde(rename = "pretty")]
    #[default]
    Pretty,
}

impl Display for Format {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "{}",
            match self {
                Self::Compact => "compact",
                Self::Full => "full",
                Self::Json => "json",
                Self::Pretty => "pretty",
            }
        )
    }
}

#[derive(Debug, Deserialize, Clone)]
pub struct Logger {
    level: Level,
    format: Format,
    crates: Vec<String>,
}

impl Logger {
    pub fn setup(&self) -> Result<()> {
        let env_filter_layer = self.env_filter()?;
        let registry = tracing_subscriber::registry()
            .with(env_filter_layer)
            .with(ErrorLayer::default());

        match self.format {
            Format::Compact => registry.with(self.compact_fmt_layer()).try_init()?,
            Format::Full => registry.with(self.base_fmt_layer()).try_init()?,
            Format::Json => registry.with(self.json_fmt_layer()).try_init()?,
            Format::Pretty => registry.with(self.pretty_fmt_layer()).try_init()?,
        }

        Ok(())
    }

    fn env_filter(&self) -> Result<EnvFilter> {
        let mut env_filter = match EnvFilter::try_from_default_env() {
            Ok(env_filter) => env_filter,
            Err(from_env_err) => {
                if let Some(err) = from_env_err.source() {
                    match err.downcast_ref::<VarError>() {
                        Some(VarError::NotPresent) => (),
                        Some(other) => return Err(Error::EnvFilter(other.clone()).into()), // Converts into crate::Report
                        _ => return Err(Error::FromEnv(from_env_err).into()),
                    }
                }

                if self.crates.is_empty() {
                    EnvFilter::try_new(format!("{}={}", env!("CARGO_PKG_NAME"), &self.level))?
                } else {
                    EnvFilter::try_new("")?
                }
            }
        };

        let directives = self.directives()?;

        for directive in directives {
            env_filter = env_filter.add_directive(directive);
        }

        Ok(env_filter)
    }

    fn base_fmt_layer<S>(&self) -> FmtLayer<S>
    where
        S: Subscriber + for<'a> LookupSpan<'a>,
    {
        FmtLayer::new()
            // Enable colours only when the stream we write to (stdout) is a terminal
            .with_ansi(std::io::stdout().is_terminal())
            // TODO: Implement other writers
            .with_writer(std::io::stdout)
    }

    fn pretty_fmt_layer<S>(&self) -> impl Layer<S>
    where
        S: Subscriber + for<'a> LookupSpan<'a>,
    {
        self.base_fmt_layer().pretty()
    }

    fn json_fmt_layer<S>(&self) -> impl Layer<S>
    where
        S: Subscriber + for<'a> LookupSpan<'a>,
    {
        self.base_fmt_layer().json()
    }

    fn compact_fmt_layer<S>(&self) -> impl Layer<S>
    where
        S: Subscriber + for<'a> LookupSpan<'a>,
    {
        self.base_fmt_layer()
            .compact()
            .with_target(false)
            .with_thread_ids(false)
            .with_thread_names(false)
            .with_file(false)
            .with_line_number(false)
    }

    pub fn level(&self) -> &Level {
        &self.level
    }

    pub fn format(&self) -> &Format {
        &self.format
    }

    pub fn directives(&self) -> Result<Vec<Directive>> {
        self.crates
            .iter()
            .map(|c| -> Result<Directive> {
                let str_directive = format!("{}={}", c, &self.level);
                Ok(Directive::from_str(&str_directive)?)
            })
            .collect()
    }
}

The log module provides three key components:

  • Level & Format enums: Define log verbosity and output styling
  • Logger struct: Configures the tracing subscriber based on your YAML settings
  • env_filter(): Starts from RUST_LOG when it is set, falls back to the YAML settings otherwise, and then adds one directive per listed crate.
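The branching at the start of env_filter() is easier to follow as a pure function. The sketch below is a std-only model, not the tracing-subscriber API: filters are plain strings, `rust_log` stands in for the result of reading RUST_LOG, and the name `base_filter` is hypothetical. It mirrors the three outcomes in the post's code: a set RUST_LOG is used as the starting filter, an absent RUST_LOG falls back to the YAML settings, and a non-Unicode RUST_LOG is an error.

```rust
use std::env::VarError;

/// Hypothetical, std-only model of the fallback logic in `Logger::env_filter()`.
/// `rust_log` stands in for reading the RUST_LOG environment variable.
fn base_filter(
    rust_log: Result<String, VarError>,
    crates_empty: bool,
    pkg_name: &str,
    level: &str,
) -> Result<String, VarError> {
    match rust_log {
        // RUST_LOG is set: use it as the starting filter.
        Ok(directives) => Ok(directives),
        // RUST_LOG is absent: fall back to the YAML configuration.
        Err(VarError::NotPresent) => {
            if crates_empty {
                // No crates listed: log only this package at the configured level.
                Ok(format!("{pkg_name}={level}"))
            } else {
                // Crates listed: start empty; per-crate directives are added afterwards.
                Ok(String::new())
            }
        }
        // RUST_LOG is set but not valid Unicode: surface the error to the caller.
        Err(other) => Err(other),
    }
}
```

In the real env_filter(), the per-crate directives from the YAML file are added on top of whichever base filter this step produces.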

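The crates field restricts logging to the listed tracing targets, each at the configured level, so output from every other dependency stays out of the way. If the list is empty, env_filter() falls back to logging only this package (CARGO_PKG_NAME). Targets are Rust module paths, so write crate names with underscores: tower_http, not tower-http.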
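Each entry in crates becomes one `target=level` directive. Because a crate published as tower-http emits events under the target tower_http, a directive spelled with a hyphen will not match it. The sketch below is a hypothetical, std-only helper (`crate_directives` is not part of the post's code) that builds the same strings as Logger::directives() and additionally normalizes hyphens to underscores.

```rust
/// Hypothetical helper: builds `target=level` directive strings from the
/// `crates` list, replacing hyphens with underscores so that a crate name
/// such as `tower-http` matches its tracing target `tower_http`.
fn crate_directives(crates: &[&str], level: &str) -> Vec<String> {
    crates
        .iter()
        .map(|name| format!("{}={}", name.replace('-', "_"), level))
        .collect()
}
```

Inside Logger::directives(), the same idea amounts to calling `c.replace('-', "_")` before formatting each directive.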

Next, add a log field to the Config struct, along with an accessor:

#[derive(Debug, Deserialize, Clone)]
pub struct Config {
    server: ServerConfig,
    log: Logger,
}

impl Config {
    // snip

    /// Borrow the logging configuration (returning `self.log` by value would not compile).
    pub fn log(&self) -> &Logger {
        &self.log
    }
}


Go to the Error enum and add the following variants.

#[derive(Debug, thiserror::Error)]
pub enum Error {
// rest of the content
    #[error(transparent)]
    DirectiveParseError(#[from] tracing_subscriber::filter::ParseError),
    #[error(transparent)]
    EnvFilter(#[from] std::env::VarError),
    #[error(transparent)]
    FromEnv(#[from] tracing_subscriber::filter::FromEnvError),
    #[error(transparent)]
    TryInit(#[from] tracing_subscriber::util::TryInitError),
}



Logging Middleware

Next, create a middlewares module directory (declare `pub mod trace;` inside it and `mod middlewares;` at the crate root) and add a trace.rs file with the following content.
This middleware captures HTTP request/response details. It uses tracing spans to track the lifecycle of each request, recording timing, status codes, and errors.

use std::{net::SocketAddr, time::Duration};

use axum::{
    body::Body,
    extract::ConnectInfo,
    http::{Request, Response},
};
use tower_http::classify::ServerErrorsFailureClass;
use tracing::{Span, field};

pub fn make_span_with(request: &Request<Body>) -> Span {
    tracing::error_span!(
        "<->",
        version = field::debug(request.version()),
        uri = field::display(request.uri()),
        method = field::display(request.method()),
        source = field::Empty,
        status = field::Empty,
        latency = field::Empty,
        error = field::Empty
    )
}

pub fn on_request(request: &Request<Body>, span: &Span) {
    span.record(
        "source",
        request
            .extensions()
            .get::<ConnectInfo<SocketAddr>>()
            .map_or_else(
                || field::display(String::from("<unknown>")),
                |connect_info| field::display(connect_info.ip().to_string()),
            ),
    );

    tracing::info!("Request");
}

pub fn on_response(response: &Response<Body>, latency: Duration, span: &Span) {
    span.record("status", field::display(response.status()));
    span.record(
        "latency",
        field::display(format!("{}µs", latency.as_micros())),
    );

    tracing::info!("Response");
}

pub fn on_failure(error: ServerErrorsFailureClass, latency: Duration, span: &Span) {
    span.record("error", field::display(error.to_string()));
    span.record(
        "latency",
        field::display(format!("{}µs", latency.as_micros())),
    );

    tracing::error!("Error on request");
}
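Two details in trace.rs are easy to get wrong: the unit attached to the latency value, and the fallback when no client address is available. In axum, ConnectInfo is only present in the request extensions when the app is served via `into_make_service_with_connect_info::<SocketAddr>()`; otherwise source is always recorded as the placeholder. The std-only sketch below isolates both as hypothetical helpers (`format_latency`, `source_ip`) that the callbacks could call.

```rust
use std::net::SocketAddr;
use std::time::Duration;

/// Hypothetical helper: render a latency with a unit that matches the number.
/// `as_micros()` returns whole microseconds, so the suffix is `µs`.
fn format_latency(latency: Duration) -> String {
    format!("{}µs", latency.as_micros())
}

/// Hypothetical helper: the client IP if connection info was captured,
/// otherwise a placeholder string.
fn source_ip(addr: Option<&SocketAddr>) -> String {
    addr.map_or_else(|| String::from("<unknown>"), |a| a.ip().to_string())
}
```

In on_request, the lookup would become `source_ip(request.extensions().get::<ConnectInfo<SocketAddr>>().map(|ci| &ci.0))`, since ConnectInfo is a tuple struct wrapping the address.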


Logging Layer

Axum attaches middleware with the .layer() method; Router::layer wraps every route added before the call. Head to the app.rs file, initialise the logger, and start logging HTTP requests and responses.

use std::io::IsTerminal;

use axum::{Router, routing::get};
use color_eyre::config::{HookBuilder, Theme};
use tokio::net::TcpListener;
use tower_http::trace::TraceLayer;

use crate::{Result, config::Config, middlewares::trace};

pub struct App;

impl App {
    pub async fn run() -> Result<()> {
        // Install color-eyre's report hook; building it without `.install()` has no effect
        HookBuilder::default()
            .theme(if std::io::stderr().is_terminal() {
                Theme::dark()
            } else {
                Theme::new()
            })
            .install()?;

        let config = Config::load()?;

        config.log().setup()?;

        let router = Router::new()
            .route("/hello", get(|| async { "Hello World!" }))
            .layer(
                TraceLayer::new_for_http()
                    .make_span_with(trace::make_span_with)
                    .on_request(trace::on_request)
                    .on_response(trace::on_response)
                    .on_failure(trace::on_failure),
            );

        let listener = TcpListener::bind(config.server().address()).await?;

        tracing::info!("Listening on {}", config.server().url());

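        // Serve with connection info so `ConnectInfo<SocketAddr>` is available to `on_request`
        axum::serve(
            listener,
            router.into_make_service_with_connect_info::<std::net::SocketAddr>(),
        )
        .await
        .map_err(Into::into)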
    }
}



HookBuilder configures color-eyre's error reports, using the dark theme when standard error is a terminal and a plain theme otherwise. config.log().setup() then installs the global tracing subscriber with your log level (info, debug, error, etc.), output format (json, full, etc.), and the crates whose spans and events you want to see.

The TraceLayer uses four callbacks:
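- make_span_with: creates a `<->` span per request with the HTTP version, URI, and method, leaving source, status, latency, and error empty to be filled in later
- on_request: records the client IP as source and logs a "Request" event
- on_response: records the status code and latency and logs a "Response" event
- on_failure: records the failure classification (for new_for_http, a 5xx status or an inner service error) and latency, then logs at error level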
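To see how the four callbacks fit together, here is a std-only toy model of one request passing through TraceLayer. It is not tower-http's implementation: `FakeSpan` and `trace_request` are hypothetical, fields live in a HashMap instead of a real tracing span, and a status of 500 or above stands in for the server-error classifier that new_for_http() uses. As we understand tower-http, on_response runs for every response and on_failure runs in addition when the response is classified as a failure.

```rust
use std::collections::HashMap;
use std::time::Instant;

/// Hypothetical stand-in for a tracing span: a bag of recorded fields.
#[derive(Default)]
struct FakeSpan {
    fields: HashMap<&'static str, String>,
}

impl FakeSpan {
    fn record(&mut self, key: &'static str, value: String) {
        self.fields.insert(key, value);
    }
}

/// Toy model of one request: make_span_with, on_request, the handler,
/// then on_response and (for server errors) on_failure.
fn trace_request(method: &str, uri: &str, handler: impl FnOnce() -> u16) -> FakeSpan {
    // make_span_with: fields known up front are set now; the others stay
    // absent until a later callback records them.
    let mut span = FakeSpan::default();
    span.record("method", method.to_string());
    span.record("uri", uri.to_string());

    // on_request: record where the request came from (no real socket here).
    span.record("source", "<unknown>".to_string());

    let start = Instant::now();
    let status = handler();
    let latency = start.elapsed();

    // on_response: called for every response.
    span.record("status", status.to_string());
    span.record("latency", format!("{}µs", latency.as_micros()));

    // on_failure: additionally called when the status counts as a server error.
    if status >= 500 {
        span.record("error", format!("server error {status}"));
    }
    span
}
```

A 200 response leaves the error field unrecorded, while a 500 fills in status, latency, and error on the same span.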

Restart your dev server and send an HTTP request to http://127.0.0.1:7150/hello. If everything is wired up correctly, each request produces a "Request" event and a "Response" event recorded within the request's `<->` span.

Feel free to experiment with the log levels and formats. For production environments, the info level with JSON output is recommended, ideally written to an output file (something we might implement later).

What's Next?

With logging in place, we can now track our application's behavior. In Part 3, we'll set up PostgreSQL for user storage and implement the database layer.

This Series

Part 1: Project Setup & Configuration
Part 2: Implementing Logging (You are here)
Next: Part 3: Database Setup with SQLx and PostgreSQL (Coming soon)
