Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Axum Middleware to log the response body

Tags:

rust

rust-axum

I want to log the responses to my HTTP requests, so I looked at the examples in the axum GitHub repository and found the following.

...
.layer(axum::middleware::from_fn(print_request_response))
...

/// Middleware that logs the request body, runs the inner service, then logs
/// the response body.
///
/// The body type is the concrete `hyper::Body` rather than a generic `B`:
/// the request is rebuilt around a fresh `hyper::Body`, and `Next<B>::run`
/// only accepts a `Request<B>`, so with a generic `B` the call to
/// `next.run(req)` cannot type-check (error E0308).
///
/// # Errors
///
/// Propagates the `400 Bad Request` produced by `buffer_and_print` if either
/// body cannot be read.
async fn print_request_response(
    req: Request<hyper::Body>,
    next: Next<hyper::Body>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let (parts, body) = req.into_parts();
    let bytes = buffer_and_print("request", body).await?;
    let req = Request::from_parts(parts, hyper::Body::from(bytes));

    let res = next.run(req).await;

    let (parts, body) = res.into_parts();
    let bytes = buffer_and_print("response", body).await?;
    let res = Response::from_parts(parts, hyper::Body::from(bytes));

    Ok(res)
}
async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
{
    let bytes = match hyper::body::to_bytes(body).await {
        Ok(bytes) => bytes,
        Err(err) => {
            return Err((
                StatusCode::BAD_REQUEST,
                format!("failed to read {} body: {}", direction, err),
            ));
        }
    };

    if let Ok(body) = std::str::from_utf8(&bytes) {
        tracing::debug!("{} body = {:?}", direction, body);
    }

    Ok(bytes)
}

In the example no types were given, but the compiler immediately told me I need to specify types for `Request`, `Next` and the functions. I've been struggling to get it to work. Right now the problem is the following. At the line

let res = next.run(req).await;

I get this error:

error[E0308]: mismatched types
   --> src\core.rs:302:24
    |
294 | async fn print_request_response<B>(
    |                                 - this type parameter
...
302 |     let res = next.run(req).await;
    |                    --- ^^^ expected type parameter `B`, found struct `Body`
    |                    |
    |                    arguments to this function are incorrect
    |
    = note: expected struct `hyper::Request<B>`
               found struct `hyper::Request<Body>`

I understand the type mismatch, but according to the implementation, doesn't `next.run()` accept a generic type?

I tried different type parameters and changing the return type of

let req = Request::from_parts(parts, hyper::Body::from(bytes));

but it didn't work.

I also don't need this exact example to work; I just want the responses to my HTTP requests to be logged.

Edit: minimal reproducible example:

Cargo.toml

[package]
name = "test"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
axum = { version = "0.6.18", features = ["http2"] }
hyper = { version = "0.14", features = ["full"] }
tokio = { version = "1.0", features = ["full"] }
tower = { version = "0.4", features = ["util", "filter"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

main.rs

use std::net::SocketAddr;
use axum::{
    body::{Body, Bytes},
    http::StatusCode,
    middleware::{self, Next},
    response::{IntoResponse, Response},
    routing::post,
    Router,
};
use hyper::Request;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

/// Serves a single `POST /` route on 0.0.0.0:8080, wrapped in the
/// body-logging middleware.
#[tokio::main]
async fn main() {
    let routes = Router::new().route("/", post(|| async move { "Hello from `POST /`" }));
    let app = routes.layer(middleware::from_fn(print_request_response));

    let addr: SocketAddr = ([0, 0, 0, 0], 8080).into();

    // Chain `.http2_only(true)` onto the builder to accept only HTTP/2.
    let server = axum::Server::bind(&addr).serve(app.into_make_service());
    server.await.unwrap();
}

async fn print_request_response<B>(
    req: Request<B>,
    next: Next<B>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let (parts, body) = req.into_parts();
    let bytes = buffer_and_print("request", body).await?;
    let req = Request::from_parts(parts, Body::from(bytes));

    let res = next.run(req).await;

    let (parts, body) = res.into_parts();
    let bytes = buffer_and_print("response", body).await?;
    let res = Response::from_parts(parts, Body::from(bytes));

    Ok(res)
}

async fn buffer_and_print<B>(direction: &str, body: B) -> Result<Bytes, (StatusCode, String)>
{
    let bytes = match hyper::body::to_bytes(body).await {
        Ok(bytes) => bytes,
        Err(err) => {
            return Err((
                StatusCode::BAD_REQUEST,
                format!("failed to read {} body: {}", direction, err),
            ));
        }
    };

    if let Ok(body) = std::str::from_utf8(&bytes) {
        tracing::debug!("{} body = {:?}", direction, body);
    }

    Ok(bytes)
}
like image 791
Tennie Avatar asked Sep 16 '25 22:09

Tennie


1 Answers

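The solution that works for me now is below. There are two key changes:

- The middleware takes the concrete `Request<axum::body::Body>` and `Next<axum::body::Body>` instead of a generic `B`. `Next<B>::run` only accepts a `Request<B>`, so a request rebuilt around a fresh `Body` type-checks only when `B` is `Body` itself.
- `buffer_and_print` declares the `HttpBody<Data = Bytes>` and `Display` bounds that `hyper::body::to_bytes` and the error formatting need.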

use axum::{middleware, Router};
use axum::body::Bytes;
use axum::http::{Request, Response, StatusCode};
use axum::middleware::Next;
use axum::response::IntoResponse;
use axum::routing::{get, post};
use hyper::Body;
use log::info;
use tower::ServiceExt;

/// Axum middleware that logs request and response bodies.
///
/// Paths rejected by [`should_log`] (static assets and explicitly excluded
/// endpoints) are forwarded untouched, so their bodies keep streaming instead
/// of being buffered in memory without being printed. For every other path,
/// both bodies are fully buffered and logged via `buffer_and_print`, then
/// replayed to the inner service and the client.
///
/// # Errors
///
/// Returns `400 Bad Request` with a message if either body cannot be read.
pub async fn log_request_response(
    req: Request<axum::body::Body>,
    next: Next<axum::body::Body>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    // Owned copy: `req` is consumed by `into_parts` below, but the path is
    // still needed for the response log line.
    let path = req.uri().path().to_owned();

    if !should_log(&path) {
        return Ok(next.run(req).await);
    }

    let (req_parts, req_body) = req.into_parts();

    // Print request
    let bytes = buffer_and_print("request", &path, req_body, true).await?;
    let req = Request::from_parts(req_parts, hyper::Body::from(bytes));

    let res = next.run(req).await;
    let (mut res_parts, res_body) = res.into_parts();

    // Print response
    let bytes = buffer_and_print("response", &path, res_body, true).await?;

    // The body is now a single in-memory buffer; per the author, a leftover
    // `transfer-encoding: chunked` header from the original response causes
    // problems, so it is dropped.
    res_parts.headers.remove("transfer-encoding");

    Ok(Response::from_parts(res_parts, Body::from(bytes)).into_response())
}

/// Returns `false` for paths whose bodies should not be logged: static assets
/// (matched by extension) and explicitly excluded endpoints (matched by suffix).
fn should_log(path: &str) -> bool {
    // Don't log these extensions
    const SKIP_EXTENSIONS: [&str; 5] = [".js", ".html", ".css", ".png", ".jpeg"];
    // Want to skip logging these paths
    const SKIP_PATHS: [&str; 1] = ["/example/path"];

    !SKIP_EXTENSIONS
        .iter()
        .chain(SKIP_PATHS.iter())
        .any(|suffix| path.ends_with(suffix))
}

/// Consumes `body` completely and, when `log` is set, logs it as UTF-8 text.
///
/// Bodies longer than 2000 bytes are truncated and suffixed with `...`. The cut
/// is made at the nearest preceding character boundary, so multi-byte UTF-8 text
/// cannot cause a slicing panic. Empty and non-UTF-8 bodies are not logged.
///
/// Returns the buffered bytes so the caller can rebuild the request/response.
///
/// # Errors
///
/// Returns `400 Bad Request` with a description if reading the body fails.
async fn buffer_and_print<B>(direction: &str, path: &str, body: B, log: bool) -> Result<Bytes, (StatusCode, String)>
where
    B: axum::body::HttpBody<Data = Bytes>,
    B::Error: std::fmt::Display,
{
    /// Upper bound on how many body bytes end up in a single log line.
    const MAX_LOGGED_BYTES: usize = 2000;

    let bytes = hyper::body::to_bytes(body).await.map_err(|err| {
        (
            StatusCode::BAD_REQUEST,
            format!("failed to read {} body: {}", direction, err),
        )
    })?;

    if log {
        if let Ok(body) = std::str::from_utf8(&bytes) {
            if body.len() > MAX_LOGGED_BYTES {
                // Slicing `&str` at a raw byte index panics if the index lands
                // inside a multi-byte character; back off to the previous
                // boundary (index 0 is always one, so the loop terminates).
                let mut end = MAX_LOGGED_BYTES;
                while !body.is_char_boundary(end) {
                    end -= 1;
                }
                info!("{} for req: {} with body: {}...", direction, path, &body[..end]);
            } else if !body.is_empty() {
                info!("{} for req: {} with body: {}", direction, path, body);
            }
        }
    }

    Ok(bytes)
}

#[tokio::test]
async fn test_log_request_response() {
    // create a request to be passed to the middleware
    let req = Request::new(Body::from("Hello, Axum!"));

    // create a simple router to test the middleware
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        .layer(middleware::from_fn(log_request_response));

    // send the request through the middleware
    let res = app.clone().oneshot(req).await.unwrap();

    // make sure the response has a status code of 200
    assert_eq!(res.status(), StatusCode::OK);
}
like image 79
Tennie Avatar answered Sep 19 '25 09:09

Tennie