diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 313520ee09..082191424b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -2,7 +2,7 @@ name: CI env: CARGO_TERM_COLOR: always - MSRV: '1.66' + MSRV: '1.75' on: push: @@ -20,6 +20,8 @@ jobs: with: components: clippy, rustfmt - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Check run: cargo clippy --workspace --all-targets --all-features -- -D warnings - name: rustfmt @@ -31,6 +33,8 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: cargo doc env: RUSTDOCFLAGS: "-D rustdoc::all -A rustdoc::private-doc-tests" @@ -43,6 +47,8 @@ jobs: - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install cargo-hack run: | curl -LsSf https://github.com/taiki-e/cargo-hack/releases/latest/download/cargo-hack-x86_64-unknown-linux-gnu.tar.gz | tar xzf - -C ~/.cargo/bin @@ -56,13 +62,21 @@ jobs: crate: [axum, axum-core, axum-extra, axum-macros] steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@nightly + # Pinned version due to failing `cargo-public-api-crates`. + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2024-06-06 - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install cargo-public-api-crates run: | cargo install --git https://github.com/davidpdrsn/cargo-public-api-crates + - name: Build rustdoc + run: | + cargo rustdoc --all-features --manifest-path ${{ matrix.crate }}/Cargo.toml -- -Z unstable-options --output-format json - name: cargo public-api-crates check - run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml check + run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml --skip-build check test-versions: needs: check @@ -77,6 +91,8 @@ jobs: with: toolchain: ${{ matrix.rust }} - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Run tests run: cargo test --workspace --all-features --all-targets @@ -93,6 +109,8 @@ jobs: with: toolchain: ${{ steps.rust-toolchain.outputs.version }} - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Run nightly tests working-directory: axum-macros run: cargo test @@ -110,6 +128,8 @@ jobs: - name: "install Rust nightly" uses: dtolnay/rust-toolchain@nightly - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Select minimal version run: cargo +nightly update -Z minimal-versions - name: Fix up Cargo.lock @@ -142,6 +162,8 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Run doc tests run: cargo test --all-features --doc @@ -170,6 +192,8 @@ jobs: with: target: armv5te-unknown-linux-musleabi - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Check env: # Clang has native cross-compilation support @@ -194,6 +218,8 @@ jobs: with: target: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Check run: > cargo @@ -207,6 +233,8 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@beta - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install cargo-sort run: | cargo install cargo-sort diff --git a/Cargo.toml b/Cargo.toml index a68aaab16a..f9c9d027b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,3 +5,6 @@ default-members = ["axum", "axum-*"] # Example has been deleted, but README.md remains exclude = ["examples/async-graphql"] resolver = "2" + +[workspace.package] +rust-version = "1.75" diff --git a/ECOSYSTEM.md b/ECOSYSTEM.md index 9b3c691c27..6ec5248e3b 100644 --- a/ECOSYSTEM.md +++ b/ECOSYSTEM.md @@ -28,7 +28,6 @@ If your project isn't listed here and you would like it to be, please feel free - [aide](https://docs.rs/aide): Code-first Open API documentation generator with [axum integration](https://docs.rs/aide/latest/aide/axum/index.html). - [axum-typed-routing](https://docs.rs/axum-typed-routing/latest/axum_typed_routing/): Statically typed routing macros with OpenAPI generation using aide. - [axum-jsonschema](https://docs.rs/axum-jsonschema/): A `Json` extractor that does JSON schema validation of requests. -- [axum-sessions](https://docs.rs/axum-sessions): Cookie-based sessions for axum via async-session. - [axum-login](https://docs.rs/axum-login): Session-based user authentication for axum. - [axum-csrf-sync-pattern](https://crates.io/crates/axum-csrf-sync-pattern): A middleware implementing CSRF STP for AJAX backends and API endpoints. - [axum-otel-metrics](https://github.com/ttys3/axum-otel-metrics/): A axum OpenTelemetry Metrics middleware with prometheus exporter supported. @@ -48,6 +47,8 @@ If your project isn't listed here and you would like it to be, please feel free - [loco.rs](https://github.com/loco-rs/loco): A full stack Web and API productivity framework similar to Rails, based on Axum. - [axum-test](https://crates.io/crates/axum-test): High level library for writing Cargo tests that run against Axum. - [axum-messages](https://github.com/maxcountryman/axum-messages): One-time notification messages for Axum. +- [spring-rs](https://github.com/spring-rs/spring-rs): spring-rs is a microservice framework written in rust inspired by java's spring-boot, based on axum +- [zino](https://github.com/zino-rs/zino): Zino is a next-generation framework for composable applications which provides full integrations with axum. ## Project showcase @@ -61,6 +62,7 @@ If your project isn't listed here and you would like it to be, please feel free - [realworld-axum-sqlx](https://github.com/launchbadge/realworld-axum-sqlx): A Rust implementation of the [Realworld] demo app spec using Axum and [SQLx]. See https://github.com/davidpdrsn/realworld-axum-sqlx for a fork with up to date dependencies. - [Rustapi](https://github.com/ndelvalle/rustapi): RESTful API template using MongoDB +- [axum-postgres-template](https://github.com/koskeller/axum-postgres-template): Production-ready Axum + PostgreSQL application template - [RUSTfulapi](https://github.com/robatipoor/rustfulapi): Reusable template for building REST Web Services in Rust. Uses Axum HTTP web framework and SeaORM. - [Jotsy](https://github.com/ohsayan/jotsy): Self-hosted notes app powered by Skytable, Axum and Tokio - [Svix](https://www.svix.com) ([repository](https://github.com/svix/svix-webhooks)): Enterprise-ready webhook service @@ -86,6 +88,7 @@ If your project isn't listed here and you would like it to be, please feel free - [randoku](https://github.com/stchris/randoku): A tiny web service which generates random numbers and shuffles lists randomly - [sero](https://github.com/clowzed/sero): Host static sites with custom subdomains as surge.sh does. But with full control and cool new features. (axum, sea-orm, postgresql) - [Hatsu](https://github.com/importantimport/hatsu): 🩵 Self-hosted & Fully-automated ActivityPub Bridge for Static Sites. +- [Mini RPS](https://github.com/marcodpt/minirps): Mini reverse proxy server, HTTPS, CORS, static file hosting and template engine (minijinja). [Realworld]: https://github.com/gothinkster/realworld [SQLx]: https://github.com/launchbadge/sqlx @@ -101,6 +104,7 @@ If your project isn't listed here and you would like it to be, please feel free - [Introduction to axum]: YouTube playlist - [Rust Axum Full Course]: YouTube video - [Deploying Axum projects with Shuttle] +- [API Development with Rust](https://rust-api.dev/docs/front-matter/preface/): REST APIs based on Axum [axum-tutorial]: https://github.com/programatik29/axum-tutorial [axum-tutorial-website]: https://programatik29.github.io/axum-tutorial/ diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index ef0d1a6793..dfa3fbf247 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -7,7 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.4.5 + +- **fixed:** Compile errors from the internal `__log_rejection` macro under + certain Cargo feature combinations between axum crates ([#2933]) + +[#2933]: https://github.com/tokio-rs/axum/pull/2933 + +# 0.4.4 + +- **added:** Derive `Clone` and `Copy` for `AppendHeaders` ([#2776]) +- **added:** `must_use` attribute on `AppendHeaders` ([#2846]) +- **added:** `must_use` attribute on `ErrorResponse` ([#2846]) +- **added:** `must_use` attribute on `IntoResponse::into_response` ([#2846]) +- **added:** `must_use` attribute on `IntoResponseParts` trait methods ([#2846]) +- **added:** Implement `Copy` for `DefaultBodyLimit` ([#2875]) +- **added**: `DefaultBodyLimit::max` and `DefaultBodyLimit::disable` are now + allowed in const context ([#2875]) + +[#2776]: https://github.com/tokio-rs/axum/pull/2776 +[#2846]: https://github.com/tokio-rs/axum/pull/2846 +[#2875]: https://github.com/tokio-rs/axum/pull/2875 # 0.4.3 (13. January, 2024) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index f9e7869283..69c5aa18f3 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -2,14 +2,14 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "Core types and traits for axum" edition = "2021" -rust-version = "1.57" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" name = "axum-core" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.4.3" # remember to also bump the version that axum and axum-extra depend on +version = "0.4.5" # remember to also bump the version that axum and axum-extra depend on [features] tracing = ["dep:tracing"] @@ -18,8 +18,7 @@ tracing = ["dep:tracing"] __private_docs = ["dep:tower-http"] [dependencies] -async-trait = "0.1.67" -bytes = "1.0" +bytes = "1.2" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "1.0.0" http-body = "1.0.0" @@ -32,7 +31,7 @@ tower-layer = "0.3" tower-service = "0.3" # optional dependencies -tower-http = { version = "0.5.0", optional = true, features = ["limit"] } +tower-http = { version = "0.6.0", optional = true, features = ["limit"] } tracing = { version = "0.1.37", default-features = false, optional = true } [dev-dependencies] @@ -41,7 +40,7 @@ axum-extra = { path = "../axum-extra", features = ["typed-header"] } futures-util = { version = "0.3", default-features = false, features = ["alloc"] } hyper = "1.0.0" tokio = { version = "1.25.0", features = ["macros"] } -tower-http = { version = "0.5.0", features = ["limit"] } +tower-http = { version = "0.6.0", features = ["limit"] } [package.metadata.cargo-public-api-crates] allowed = [ @@ -55,6 +54,9 @@ allowed = [ "http_body", ] +[package.metadata.cargo-machete] +ignored = ["tower-http"] # See __private_docs feature + [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/axum-core/README.md b/axum-core/README.md index 01ff4e5105..600ec33791 100644 --- a/axum-core/README.md +++ b/axum-core/README.md @@ -14,7 +14,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum-core's MSRV is 1.56. +axum-core's MSRV is 1.75. ## Getting Help diff --git a/axum-core/src/ext_traits/mod.rs b/axum-core/src/ext_traits/mod.rs index 02595fbeac..951a12d70c 100644 --- a/axum-core/src/ext_traits/mod.rs +++ b/axum-core/src/ext_traits/mod.rs @@ -6,13 +6,11 @@ mod tests { use std::convert::Infallible; use crate::extract::{FromRef, FromRequestParts}; - use async_trait::async_trait; use http::request::Parts; #[derive(Debug, Default, Clone, Copy)] pub(crate) struct State(pub(crate) S); - #[async_trait] impl FromRequestParts for State where InnerState: FromRef, @@ -30,9 +28,9 @@ mod tests { } // some extractor that requires the state, such as `SignedCookieJar` + #[allow(dead_code)] pub(crate) struct RequiresState(pub(crate) String); - #[async_trait] impl FromRequestParts for RequiresState where S: Send + Sync, diff --git a/axum-core/src/ext_traits/request.rs b/axum-core/src/ext_traits/request.rs index 5b7aee783a..1123fdd3d6 100644 --- a/axum-core/src/ext_traits/request.rs +++ b/axum-core/src/ext_traits/request.rs @@ -1,6 +1,6 @@ use crate::body::Body; use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request}; -use futures_util::future::BoxFuture; +use std::future::Future; mod sealed { pub trait Sealed {} @@ -20,7 +20,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// extract::{Request, FromRequest}, /// body::Body, /// http::{header::CONTENT_TYPE, StatusCode}, @@ -30,7 +29,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// struct FormOrJson(T); /// - /// #[async_trait] /// impl FromRequest for FormOrJson /// where /// Json: FromRequest<()>, @@ -67,7 +65,7 @@ pub trait RequestExt: sealed::Sealed + Sized { /// } /// } /// ``` - fn extract(self) -> BoxFuture<'static, Result> + fn extract(self) -> impl Future> + Send where E: FromRequest<(), M> + 'static, M: 'static; @@ -83,7 +81,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// body::Body, /// extract::{Request, FromRef, FromRequest}, /// RequestExt, @@ -93,7 +90,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// requires_state: RequiresState, /// } /// - /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// String: FromRef, @@ -111,7 +107,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// // some extractor that consumes the request body and requires state /// struct RequiresState { /* ... */ } /// - /// #[async_trait] /// impl FromRequest for RequiresState /// where /// String: FromRef, @@ -124,7 +119,10 @@ pub trait RequestExt: sealed::Sealed + Sized { /// # } /// } /// ``` - fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> + fn extract_with_state( + self, + state: &S, + ) -> impl Future> + Send where E: FromRequest + 'static, S: Send + Sync; @@ -137,7 +135,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// extract::{Path, Request, FromRequest}, /// response::{IntoResponse, Response}, /// body::Body, @@ -154,7 +151,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// payload: T, /// } /// - /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// S: Send + Sync, @@ -179,7 +175,7 @@ pub trait RequestExt: sealed::Sealed + Sized { /// } /// } /// ``` - fn extract_parts(&mut self) -> BoxFuture<'_, Result> + fn extract_parts(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static; @@ -191,7 +187,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// extract::{Request, FromRef, FromRequest, FromRequestParts}, /// http::request::Parts, /// response::{IntoResponse, Response}, @@ -204,7 +199,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// payload: T, /// } /// - /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// String: FromRef, @@ -234,7 +228,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// struct RequiresState {} /// - /// #[async_trait] /// impl FromRequestParts for RequiresState /// where /// String: FromRef, @@ -250,7 +243,7 @@ pub trait RequestExt: sealed::Sealed + Sized { fn extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> impl Future> + Send + 'a where E: FromRequestParts + 'static, S: Send + Sync; @@ -267,7 +260,7 @@ pub trait RequestExt: sealed::Sealed + Sized { } impl RequestExt for Request { - fn extract(self) -> BoxFuture<'static, Result> + fn extract(self) -> impl Future> + Send where E: FromRequest<(), M> + 'static, M: 'static, @@ -275,7 +268,10 @@ impl RequestExt for Request { self.extract_with_state(&()) } - fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> + fn extract_with_state( + self, + state: &S, + ) -> impl Future> + Send where E: FromRequest + 'static, S: Send + Sync, @@ -283,17 +279,17 @@ impl RequestExt for Request { E::from_request(self, state) } - fn extract_parts(&mut self) -> BoxFuture<'_, Result> + fn extract_parts(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static, { self.extract_parts_with_state(&()) } - fn extract_parts_with_state<'a, E, S>( + async fn extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> Result where E: FromRequestParts + 'static, S: Send + Sync, @@ -306,17 +302,15 @@ impl RequestExt for Request { *req.extensions_mut() = std::mem::take(self.extensions_mut()); let (mut parts, ()) = req.into_parts(); - Box::pin(async move { - let result = E::from_request_parts(&mut parts, state).await; + let result = E::from_request_parts(&mut parts, state).await; - *self.version_mut() = parts.version; - *self.method_mut() = parts.method.clone(); - *self.uri_mut() = parts.uri.clone(); - *self.headers_mut() = std::mem::take(&mut parts.headers); - *self.extensions_mut() = std::mem::take(&mut parts.extensions); + *self.version_mut() = parts.version; + *self.method_mut() = parts.method.clone(); + *self.uri_mut() = parts.uri.clone(); + *self.headers_mut() = std::mem::take(&mut parts.headers); + *self.extensions_mut() = std::mem::take(&mut parts.extensions); - result - }) + result } fn with_limited_body(self) -> Request { @@ -345,7 +339,6 @@ mod tests { ext_traits::tests::{RequiresState, State}, extract::FromRef, }; - use async_trait::async_trait; use http::Method; #[tokio::test] @@ -414,7 +407,6 @@ mod tests { body: String, } - #[async_trait] impl FromRequest for WorksForCustomExtractor where S: Send + Sync, diff --git a/axum-core/src/ext_traits/request_parts.rs b/axum-core/src/ext_traits/request_parts.rs index e7063f4d8b..9e1a3d1c16 100644 --- a/axum-core/src/ext_traits/request_parts.rs +++ b/axum-core/src/ext_traits/request_parts.rs @@ -1,6 +1,6 @@ use crate::extract::FromRequestParts; -use futures_util::future::BoxFuture; use http::request::Parts; +use std::future::Future; mod sealed { pub trait Sealed {} @@ -21,7 +21,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// response::{Response, IntoResponse}, /// http::request::Parts, /// RequestPartsExt, - /// async_trait, /// }; /// use std::collections::HashMap; /// @@ -30,7 +29,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// query_params: HashMap, /// } /// - /// #[async_trait] /// impl FromRequestParts for MyExtractor /// where /// S: Send + Sync, @@ -54,7 +52,7 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// } /// } /// ``` - fn extract(&mut self) -> BoxFuture<'_, Result> + fn extract(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static; @@ -70,14 +68,12 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// response::{Response, IntoResponse}, /// http::request::Parts, /// RequestPartsExt, - /// async_trait, /// }; /// /// struct MyExtractor { /// requires_state: RequiresState, /// } /// - /// #[async_trait] /// impl FromRequestParts for MyExtractor /// where /// String: FromRef, @@ -97,7 +93,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// struct RequiresState { /* ... */ } /// /// // some extractor that requires a `String` in the state - /// #[async_trait] /// impl FromRequestParts for RequiresState /// where /// String: FromRef, @@ -113,14 +108,14 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { fn extract_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> impl Future> + Send + 'a where E: FromRequestParts + 'static, S: Send + Sync; } impl RequestPartsExt for Parts { - fn extract(&mut self) -> BoxFuture<'_, Result> + fn extract(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static, { @@ -130,7 +125,7 @@ impl RequestPartsExt for Parts { fn extract_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> impl Future> + Send + 'a where E: FromRequestParts + 'static, S: Send + Sync, @@ -148,7 +143,6 @@ mod tests { ext_traits::tests::{RequiresState, State}, extract::FromRef, }; - use async_trait::async_trait; use http::{Method, Request}; #[tokio::test] @@ -181,7 +175,6 @@ mod tests { from_state: String, } - #[async_trait] impl FromRequestParts for WorksForCustomExtractor where S: Send + Sync, diff --git a/axum-core/src/extract/default_body_limit.rs b/axum-core/src/extract/default_body_limit.rs index 2ec82febc6..a045d1cd3f 100644 --- a/axum-core/src/extract/default_body_limit.rs +++ b/axum-core/src/extract/default_body_limit.rs @@ -72,7 +72,7 @@ use tower_layer::Layer; /// [`RequestBodyLimit`]: tower_http::limit::RequestBodyLimit /// [`RequestExt::with_limited_body`]: crate::RequestExt::with_limited_body /// [`RequestExt::into_limited_body`]: crate::RequestExt::into_limited_body -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] #[must_use] pub struct DefaultBodyLimit { kind: DefaultBodyLimitKind, @@ -116,7 +116,7 @@ impl DefaultBodyLimit { /// [`Bytes`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html - pub fn disable() -> Self { + pub const fn disable() -> Self { Self { kind: DefaultBodyLimitKind::Disable, } @@ -149,7 +149,7 @@ impl DefaultBodyLimit { /// [`Bytes::from_request`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html - pub fn max(limit: usize) -> Self { + pub const fn max(limit: usize) -> Self { Self { kind: DefaultBodyLimitKind::Limit(limit), } diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index f59c1eadde..1baa893555 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -5,9 +5,9 @@ //! [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html use crate::{body::Body, response::IntoResponse}; -use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; +use std::future::Future; pub mod rejection; @@ -42,7 +42,6 @@ mod private { /// See [`axum::extract`] for more general docs about extractors. /// /// [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html -#[async_trait] #[rustversion::attr( since(1.78), diagnostic::on_unimplemented( @@ -55,7 +54,10 @@ pub trait FromRequestParts: Sized { type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result; + fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> impl Future> + Send; } /// Types that can be created from requests. @@ -69,7 +71,6 @@ pub trait FromRequestParts: Sized { /// See [`axum::extract`] for more general docs about extractors. /// /// [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html -#[async_trait] #[rustversion::attr( since(1.78), diagnostic::on_unimplemented( @@ -82,10 +83,12 @@ pub trait FromRequest: Sized { type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: Request, state: &S) -> Result; + fn from_request( + req: Request, + state: &S, + ) -> impl Future> + Send; } -#[async_trait] impl FromRequest for T where S: Send + Sync, @@ -99,7 +102,6 @@ where } } -#[async_trait] impl FromRequestParts for Option where T: FromRequestParts, @@ -115,7 +117,6 @@ where } } -#[async_trait] impl FromRequest for Option where T: FromRequest, @@ -128,7 +129,6 @@ where } } -#[async_trait] impl FromRequestParts for Result where T: FromRequestParts, @@ -141,7 +141,6 @@ where } } -#[async_trait] impl FromRequest for Result where T: FromRequest, diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 73f54db793..695f7e1e9e 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,12 +1,10 @@ use super::{rejection::*, FromRequest, FromRequestParts, Request}; use crate::{body::Body, RequestExt}; -use async_trait::async_trait; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version}; use http_body_util::BodyExt; use std::convert::Infallible; -#[async_trait] impl FromRequest for Request where S: Send + Sync, @@ -18,7 +16,6 @@ where } } -#[async_trait] impl FromRequestParts for Method where S: Send + Sync, @@ -30,7 +27,6 @@ where } } -#[async_trait] impl FromRequestParts for Uri where S: Send + Sync, @@ -42,7 +38,6 @@ where } } -#[async_trait] impl FromRequestParts for Version where S: Send + Sync, @@ -59,7 +54,6 @@ where /// Prefer using [`TypedHeader`] to extract only the headers you need. /// /// [`TypedHeader`]: https://docs.rs/axum/0.7/axum/extract/struct.TypedHeader.html -#[async_trait] impl FromRequestParts for HeaderMap where S: Send + Sync, @@ -71,7 +65,36 @@ where } } -#[async_trait] +impl FromRequest for BytesMut +where + S: Send + Sync, +{ + type Rejection = BytesRejection; + + async fn from_request(req: Request, _: &S) -> Result { + let mut body = req.into_limited_body(); + let mut bytes = BytesMut::new(); + body_to_bytes_mut(&mut body, &mut bytes).await?; + Ok(bytes) + } +} + +async fn body_to_bytes_mut(body: &mut Body, bytes: &mut BytesMut) -> Result<(), BytesRejection> { + while let Some(frame) = body + .frame() + .await + .transpose() + .map_err(FailedToBufferBody::from_err)? + { + let Ok(data) = frame.into_data() else { + return Ok(()); + }; + bytes.put(data); + } + + Ok(()) +} + impl FromRequest for Bytes where S: Send + Sync, @@ -90,7 +113,6 @@ where } } -#[async_trait] impl FromRequest for String where S: Send + Sync, @@ -106,15 +128,12 @@ where } })?; - let string = std::str::from_utf8(&bytes) - .map_err(InvalidUtf8::from_err)? - .to_owned(); + let string = String::from_utf8(bytes.into()).map_err(InvalidUtf8::from_err)?; Ok(string) } } -#[async_trait] impl FromRequestParts for Parts where S: Send + Sync, @@ -126,7 +145,6 @@ where } } -#[async_trait] impl FromRequestParts for Extensions where S: Send + Sync, @@ -138,7 +156,6 @@ where } } -#[async_trait] impl FromRequest for Body where S: Send + Sync, diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 021b9616df..cbd91a7fb3 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -1,10 +1,8 @@ use super::{FromRequest, FromRequestParts, Request}; use crate::response::{IntoResponse, Response}; -use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; -#[async_trait] impl FromRequestParts for () where S: Send + Sync, @@ -20,7 +18,6 @@ macro_rules! impl_from_request { ( [$($ty:ident),*], $last:ident ) => { - #[async_trait] #[allow(non_snake_case, unused_mut, unused_variables)] impl FromRequestParts for ($($ty,)* $last,) where @@ -46,7 +43,6 @@ macro_rules! impl_from_request { // This impl must not be generic over M, otherwise it would conflict with the blanket // implementation of `FromRequest` for `T: FromRequestParts`. - #[async_trait] #[allow(non_snake_case, unused_mut, unused_variables)] impl FromRequest for ($($ty,)* $last,) where diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs index 994b522c07..134c566b30 100644 --- a/axum-core/src/lib.rs +++ b/axum-core/src/lib.rs @@ -21,7 +21,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -51,6 +50,11 @@ #[macro_use] pub(crate) mod macros; +#[doc(hidden)] // macro helpers +pub mod __private { + #[cfg(feature = "tracing")] + pub use tracing; +} mod error; mod ext_traits; diff --git a/axum-core/src/macros.rs b/axum-core/src/macros.rs index 69358a2940..aa99ba402e 100644 --- a/axum-core/src/macros.rs +++ b/axum-core/src/macros.rs @@ -1,4 +1,5 @@ /// Private API. +#[cfg(feature = "tracing")] #[doc(hidden)] #[macro_export] macro_rules! __log_rejection { @@ -7,20 +8,30 @@ macro_rules! __log_rejection { body_text = $body_text:expr, status = $status:expr, ) => { - #[cfg(feature = "tracing")] { - tracing::event!( + $crate::__private::tracing::event!( target: "axum::rejection", - tracing::Level::TRACE, + $crate::__private::tracing::Level::TRACE, status = $status.as_u16(), body = $body_text, - rejection_type = std::any::type_name::<$ty>(), + rejection_type = ::std::any::type_name::<$ty>(), "rejecting request", ); } }; } +#[cfg(not(feature = "tracing"))] +#[doc(hidden)] +#[macro_export] +macro_rules! __log_rejection { + ( + rejection_type = $ty:ident, + body_text = $body_text:expr, + status = $status:expr, + ) => {}; +} + /// Private API. #[doc(hidden)] #[macro_export] diff --git a/axum-core/src/response/append_headers.rs b/axum-core/src/response/append_headers.rs index e4ac4812f9..aa8f2dbdfb 100644 --- a/axum-core/src/response/append_headers.rs +++ b/axum-core/src/response/append_headers.rs @@ -29,7 +29,7 @@ use std::fmt; /// ) /// } /// ``` -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[must_use] pub struct AppendHeaders(pub I); diff --git a/axum-core/src/response/into_response.rs b/axum-core/src/response/into_response.rs index 679b0cbb74..915b55eff8 100644 --- a/axum-core/src/response/into_response.rs +++ b/axum-core/src/response/into_response.rs @@ -111,6 +111,7 @@ use std::{ /// ``` pub trait IntoResponse { /// Create a response. + #[must_use] fn into_response(self) -> Response; } diff --git a/axum-core/src/response/into_response_parts.rs b/axum-core/src/response/into_response_parts.rs index 72b61bc75b..955648238d 100644 --- a/axum-core/src/response/into_response_parts.rs +++ b/axum-core/src/response/into_response_parts.rs @@ -105,21 +105,25 @@ pub struct ResponseParts { impl ResponseParts { /// Gets a reference to the response headers. + #[must_use] pub fn headers(&self) -> &HeaderMap { self.res.headers() } /// Gets a mutable reference to the response headers. + #[must_use] pub fn headers_mut(&mut self) -> &mut HeaderMap { self.res.headers_mut() } /// Gets a reference to the response extensions. + #[must_use] pub fn extensions(&self) -> &Extensions { self.res.extensions() } /// Gets a mutable reference to the response extensions. + #[must_use] pub fn extensions_mut(&mut self) -> &mut Extensions { self.res.extensions_mut() } diff --git a/axum-core/src/response/mod.rs b/axum-core/src/response/mod.rs index 6b66c60e71..dd6728b1c2 100644 --- a/axum-core/src/response/mod.rs +++ b/axum-core/src/response/mod.rs @@ -117,6 +117,7 @@ where /// /// See [`Result`] for more details. #[derive(Debug)] +#[must_use] pub struct ErrorResponse(Response); impl From for ErrorResponse diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 4d49aab547..a8d1f5ce37 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,7 +7,17 @@ and this project adheres to [Semantic Versioning]. # Unreleased -- None. +- **breaking:** Update to prost 0.13. Used for the `Protobuf` extractor ([#2829]) +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2829]: https://github.com/tokio-rs/axum/pull/2829 +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.9.4 + +- **added:** The `response::Attachment` type ([#2789]) + +[#2789]: https://github.com/tokio-rs/axum/pull/2789 # 0.9.3 (24. March, 2024) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index c618861d96..fe620ce3dd 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -2,19 +2,20 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "Extra utilities for axum" edition = "2021" -rust-version = "1.66" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" name = "axum-extra" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.9.3" +version = "0.9.4" [features] -default = ["tracing"] +default = ["tracing", "multipart"] async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"] +attachment = ["dep:tracing"] cookie = ["dep:cookie"] cookie-private = ["cookie", "cookie?/private"] cookie-signed = ["cookie", "cookie?/signed"] @@ -30,16 +31,16 @@ json-lines = [ "tokio-stream?/io-util", "dep:tokio", ] -multipart = ["dep:multer"] +multipart = ["dep:multer", "dep:fastrand"] protobuf = ["dep:prost"] query = ["dep:serde_html_form"] -tracing = ["dep:tracing", "axum-core/tracing"] +tracing = ["axum-core/tracing", "axum/tracing"] typed-header = ["dep:headers"] typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"] [dependencies] -axum = { path = "../axum", version = "0.7.2", default-features = false } -axum-core = { path = "../axum-core", version = "0.4.3" } +axum = { path = "../axum", version = "0.7.7", default-features = false } +axum-core = { path = "../axum-core", version = "0.4.5" } bytes = "1.1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "1.0.0" @@ -48,18 +49,19 @@ http-body-util = "0.1.0" mime = "0.3" pin-project-lite = "0.2" serde = "1.0" -tower = { version = "0.4", default_features = false, features = ["util"] } +tower = { version = "0.5.1", default-features = false, features = ["util"] } tower-layer = "0.3" tower-service = "0.3" # optional dependencies -axum-macros = { path = "../axum-macros", version = "0.4.1", optional = true } +axum-macros = { path = "../axum-macros", version = "0.4.2", optional = true } cookie = { package = "cookie", version = "0.18.0", features = ["percent-encode"], optional = true } +fastrand = { version = "2.1.0", optional = true } form_urlencoded = { version = "1.1.0", optional = true } headers = { version = "0.4.0", optional = true } multer = { version = "3.0.0", optional = true } percent-encoding = { version = "2.1", optional = true } -prost = { version = "0.12", optional = true } +prost = { version = "0.13", optional = true } serde_html_form = { version = "0.2.0", optional = true } serde_json = { version = "1.0.71", optional = true } serde_path_to_error = { version = "0.1.8", optional = true } @@ -75,8 +77,8 @@ reqwest = { version = "0.12", default-features = false, features = ["json", "str serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.71" tokio = { version = "1.14", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5.0", features = ["map-response-body", "timeout"] } +tower = { version = "0.5.1", features = ["util"] } +tower-http = { version = "0.6.0", features = ["map-response-body", "timeout"] } [package.metadata.docs.rs] all-features = true diff --git a/axum-extra/README.md b/axum-extra/README.md index 16b96cc8c9..7d3e904e9c 100644 --- a/axum-extra/README.md +++ b/axum-extra/README.md @@ -14,7 +14,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum-extra's MSRV is 1.66. +axum-extra's MSRV is 1.75. ## Getting Help diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 2742debb85..9fa1f82f3f 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -7,7 +7,6 @@ //! use axum::{ //! body::Bytes, //! Router, -//! async_trait, //! routing::get, //! extract::FromRequestParts, //! }; @@ -15,7 +14,6 @@ //! // extractors for checking permissions //! struct AdminPermissions {} //! -//! #[async_trait] //! impl FromRequestParts for AdminPermissions //! where //! S: Send + Sync, @@ -29,7 +27,6 @@ //! //! struct User {} //! -//! #[async_trait] //! impl FromRequestParts for User //! where //! S: Send + Sync, @@ -96,7 +93,6 @@ use std::task::{Context, Poll}; use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, Response}, }; @@ -236,7 +232,6 @@ macro_rules! impl_traits_for_either { [$($ident:ident),* $(,)?], $last:ident $(,)? ) => { - #[async_trait] impl FromRequestParts for $either<$($ident),*, $last> where $($ident: FromRequestParts),*, @@ -247,12 +242,12 @@ macro_rules! impl_traits_for_either { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { $( - if let Ok(value) = FromRequestParts::from_request_parts(parts, state).await { + if let Ok(value) = <$ident as FromRequestParts>::from_request_parts(parts, state).await { return Ok(Self::$ident(value)); } )* - FromRequestParts::from_request_parts(parts, state).await.map(Self::$last) + <$last as FromRequestParts>::from_request_parts(parts, state).await.map(Self::$last) } } diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index f9714eb014..64b4c3056f 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -1,7 +1,4 @@ -use axum::{ - async_trait, - extract::{Extension, FromRequestParts}, -}; +use axum::extract::{Extension, FromRequestParts}; use http::request::Parts; /// Cache results of other extractors. @@ -19,7 +16,6 @@ use http::request::Parts; /// ```rust /// use axum_extra::extract::Cached; /// use axum::{ -/// async_trait, /// extract::FromRequestParts, /// response::{IntoResponse, Response}, /// http::{StatusCode, request::Parts}, @@ -28,7 +24,6 @@ use http::request::Parts; /// #[derive(Clone)] /// struct Session { /* ... */ } /// -/// #[async_trait] /// impl FromRequestParts for Session /// where /// S: Send + Sync, @@ -43,7 +38,6 @@ use http::request::Parts; /// /// struct CurrentUser { /* ... */ } /// -/// #[async_trait] /// impl FromRequestParts for CurrentUser /// where /// S: Send + Sync, @@ -86,7 +80,6 @@ pub struct Cached(pub T); #[derive(Clone)] struct CachedEntry(T); -#[async_trait] impl FromRequestParts for Cached where S: Send + Sync, @@ -125,7 +118,6 @@ mod tests { #[derive(Clone, Debug, PartialEq, Eq)] struct Extractor(Instant); - #[async_trait] impl FromRequestParts for Extractor where S: Send + Sync, diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index efd2dcdf86..50fa6031ac 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -3,7 +3,6 @@ //! See [`CookieJar`], [`SignedCookieJar`], and [`PrivateCookieJar`] for more details. use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; @@ -90,7 +89,6 @@ pub struct CookieJar { jar: cookie::CookieJar, } -#[async_trait] impl FromRequestParts for CookieJar where S: Send + Sync, diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 911b0ef2ec..3a7d0beee6 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -1,6 +1,5 @@ use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ - async_trait, extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; @@ -122,7 +121,6 @@ impl fmt::Debug for PrivateCookieJar { } } -#[async_trait] impl FromRequestParts for PrivateCookieJar where S: Send + Sync, diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index b65df79f95..87ba5444b5 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -1,6 +1,5 @@ use super::{cookies_from_request, set_cookies}; use axum::{ - async_trait, extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; @@ -139,7 +138,6 @@ impl fmt::Debug for SignedCookieJar { } } -#[async_trait] impl FromRequestParts for SignedCookieJar where S: Send + Sync, diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index 453e782372..a7ca9305aa 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::{rejection::RawFormRejection, FromRequest, RawForm, Request}, response::{IntoResponse, Response}, Error, RequestExt, @@ -44,7 +43,6 @@ pub struct Form(pub T); axum_core::__impl_deref!(Form); -#[async_trait] impl FromRequest for Form where T: DeserializeOwned, diff --git a/axum-extra/src/extract/json_deserializer.rs b/axum-extra/src/extract/json_deserializer.rs index b138c50f3a..4a84a72418 100644 --- a/axum-extra/src/extract/json_deserializer.rs +++ b/axum-extra/src/extract/json_deserializer.rs @@ -1,4 +1,3 @@ -use axum::async_trait; use axum::extract::{FromRequest, Request}; use axum_core::__composite_rejection as composite_rejection; use axum_core::__define_rejection as define_rejection; @@ -23,8 +22,7 @@ use std::marker::PhantomData; /// Additionally, a `JsonRejection` error will be returned, when calling `deserialize` if: /// /// - The body doesn't contain syntactically valid JSON. -/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target -/// type. +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type. /// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`). /// /// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the @@ -85,7 +83,6 @@ pub struct JsonDeserializer { _marker: PhantomData, } -#[async_trait] impl FromRequest for JsonDeserializer where T: Deserialize<'static>, diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index 70f0486628..cbbdfd425f 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -3,7 +3,6 @@ //! See [`Multipart`] for more details. use axum::{ - async_trait, body::{Body, Bytes}, extract::FromRequest, response::{IntoResponse, Response}, @@ -90,7 +89,6 @@ pub struct Multipart { inner: multer::Multipart<'static>, } -#[async_trait] impl FromRequest for Multipart where S: Send + Sync, diff --git a/axum-extra/src/extract/optional_path.rs b/axum-extra/src/extract/optional_path.rs index 236b9836a4..0d41a66cd6 100644 --- a/axum-extra/src/extract/optional_path.rs +++ b/axum-extra/src/extract/optional_path.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path}, RequestPartsExt, }; @@ -35,7 +34,6 @@ use serde::de::DeserializeOwned; #[derive(Debug)] pub struct OptionalPath(pub Option); -#[async_trait] impl FromRequestParts for OptionalPath where T: DeserializeOwned + Send + 'static, diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index 7822f108a2..695ea9576b 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, Response}, Error, @@ -82,7 +81,6 @@ use std::fmt; #[derive(Debug, Clone, Copy, Default)] pub struct Query(pub T); -#[async_trait] impl FromRequestParts for Query where T: DeserializeOwned, @@ -187,7 +185,6 @@ impl std::error::Error for QueryRejection { #[derive(Debug, Clone, Copy, Default)] pub struct OptionalQuery(pub Option); -#[async_trait] impl FromRequestParts for OptionalQuery where T: DeserializeOwned, diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index 000672fe47..c093f6fa47 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -1,4 +1,3 @@ -use axum::async_trait; use axum::extract::{FromRequest, FromRequestParts, Request}; use axum::response::IntoResponse; use http::request::Parts; @@ -110,7 +109,6 @@ impl DerefMut for WithRejection { } } -#[async_trait] impl FromRequest for WithRejection where S: Send + Sync, @@ -125,7 +123,6 @@ where } } -#[async_trait] impl FromRequestParts for WithRejection where S: Send + Sync, @@ -169,7 +166,6 @@ mod tests { struct TestExtractor; struct TestRejection; - #[async_trait] impl FromRequestParts for TestExtractor where S: Send + Sync, diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 2438889930..571ab67707 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -47,7 +47,6 @@ pub trait HandlerCallWithExtractors: Sized { /// use axum_extra::handler::HandlerCallWithExtractors; /// use axum::{ /// Router, - /// async_trait, /// routing::get, /// extract::FromRequestParts, /// }; @@ -68,7 +67,6 @@ pub trait HandlerCallWithExtractors: Sized { /// // extractors for checking permissions /// struct AdminPermissions {} /// - /// #[async_trait] /// impl FromRequestParts for AdminPermissions /// where /// S: Send + Sync, @@ -82,7 +80,6 @@ pub trait HandlerCallWithExtractors: Sized { /// /// struct User {} /// - /// #[async_trait] /// impl FromRequestParts for User /// where /// S: Send + Sync, @@ -168,7 +165,7 @@ pub struct IntoHandler { impl Handler for IntoHandler where - H: HandlerCallWithExtractors + Clone + Send + 'static, + H: HandlerCallWithExtractors + Clone + Send + Sync + 'static, T: FromRequest + Send + 'static, T::Rejection: Send, S: Send + Sync + 'static, diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index e4a1dc67e7..f15ccc70b0 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -54,8 +54,8 @@ where impl Handler<(M, Lt, Rt), S> for Or where - L: HandlerCallWithExtractors + Clone + Send + 'static, - R: HandlerCallWithExtractors + Clone + Send + 'static, + L: HandlerCallWithExtractors + Clone + Send + Sync + 'static, + R: HandlerCallWithExtractors + Clone + Send + Sync + 'static, Lt: FromRequestParts + Send + 'static, Rt: FromRequest + Send + 'static, Lt::Rejection: Send, diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index ec955e796f..7c513f96cd 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -1,7 +1,6 @@ //! Newline delimited JSON extractor and response. use axum::{ - async_trait, body::Body, extract::{FromRequest, Request}, response::{IntoResponse, Response}, @@ -99,7 +98,6 @@ impl JsonLines { } } -#[async_trait] impl FromRequest for JsonLines where T: DeserializeOwned, diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index 2ddda783e4..02dd6e697a 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -40,7 +40,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index 30e048e273..d563807403 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -1,7 +1,6 @@ //! Protocol Buffer extractor and response. use axum::{ - async_trait, extract::{rejection::BytesRejection, FromRequest, Request}, response::{IntoResponse, Response}, }; @@ -90,7 +89,6 @@ use prost::Message; #[must_use] pub struct Protobuf(pub T); -#[async_trait] impl FromRequest for Protobuf where T: Message + Default, diff --git a/axum-extra/src/response/attachment.rs b/axum-extra/src/response/attachment.rs new file mode 100644 index 0000000000..2063d30f05 --- /dev/null +++ b/axum-extra/src/response/attachment.rs @@ -0,0 +1,103 @@ +use axum::response::IntoResponse; +use http::{header, HeaderMap, HeaderValue}; +use tracing::error; + +/// A file attachment response. +/// +/// This type will set the `Content-Disposition` header to `attachment`. In response a webbrowser +/// will offer to download the file instead of displaying it directly. +/// +/// Use the `filename` and `content_type` methods to set the filename or content-type of the +/// attachment. If these values are not set they will not be sent. +/// +/// +/// # Example +/// +/// ```rust +/// use axum::{http::StatusCode, routing::get, Router}; +/// use axum_extra::response::Attachment; +/// +/// async fn cargo_toml() -> Result, (StatusCode, String)> { +/// let file_contents = tokio::fs::read_to_string("Cargo.toml") +/// .await +/// .map_err(|err| (StatusCode::NOT_FOUND, format!("File not found: {err}")))?; +/// Ok(Attachment::new(file_contents) +/// .filename("Cargo.toml") +/// .content_type("text/x-toml")) +/// } +/// +/// let app = Router::new().route("/Cargo.toml", get(cargo_toml)); +/// let _: Router = app; +/// ``` +/// +/// # Note +/// +/// If you use axum with hyper, hyper will set the `Content-Length` if it is known. +#[derive(Debug)] +#[must_use] +pub struct Attachment { + inner: T, + filename: Option, + content_type: Option, +} + +impl Attachment { + /// Creates a new [`Attachment`]. + pub fn new(inner: T) -> Self { + Self { + inner, + filename: None, + content_type: None, + } + } + + /// Sets the filename of the [`Attachment`]. + /// + /// This updates the `Content-Disposition` header to add a filename. + pub fn filename>(mut self, value: H) -> Self { + self.filename = if let Ok(filename) = value.try_into() { + Some(filename) + } else { + error!("Attachment filename contains invalid characters"); + None + }; + self + } + + /// Sets the content-type of the [`Attachment`] + pub fn content_type>(mut self, value: H) -> Self { + if let Ok(content_type) = value.try_into() { + self.content_type = Some(content_type); + } else { + error!("Attachment content-type contains invalid characters"); + } + self + } +} + +impl IntoResponse for Attachment +where + T: IntoResponse, +{ + fn into_response(self) -> axum::response::Response { + let mut headers = HeaderMap::new(); + + if let Some(content_type) = self.content_type { + headers.append(header::CONTENT_TYPE, content_type); + } + + let content_disposition = if let Some(filename) = self.filename { + let mut bytes = b"attachment; filename=\"".to_vec(); + bytes.extend_from_slice(filename.as_bytes()); + bytes.push(b'\"'); + + HeaderValue::from_bytes(&bytes).expect("This was a HeaderValue so this can not fail") + } else { + HeaderValue::from_static("attachment") + }; + + headers.append(header::CONTENT_DISPOSITION, content_disposition); + + (headers, self.inner).into_response() + } +} diff --git a/axum-extra/src/response/mod.rs b/axum-extra/src/response/mod.rs index dda382cf02..29b2d0e915 100644 --- a/axum-extra/src/response/mod.rs +++ b/axum-extra/src/response/mod.rs @@ -3,6 +3,12 @@ #[cfg(feature = "erased-json")] mod erased_json; +#[cfg(feature = "attachment")] +mod attachment; + +#[cfg(feature = "multipart")] +pub mod multiple; + #[cfg(feature = "erased-json")] pub use erased_json::ErasedJson; @@ -10,6 +16,9 @@ pub use erased_json::ErasedJson; #[doc(no_inline)] pub use crate::json_lines::JsonLines; +#[cfg(feature = "attachment")] +pub use attachment::Attachment; + macro_rules! mime_response { ( $(#[$m:meta])* @@ -57,14 +66,6 @@ macro_rules! mime_response { }; } -mime_response! { - /// A HTML response. - /// - /// Will automatically get `Content-Type: text/html; charset=utf-8`. - Html, - TEXT_HTML_UTF_8, -} - mime_response! { /// A JavaScript response. /// diff --git a/axum-extra/src/response/multiple.rs b/axum-extra/src/response/multiple.rs new file mode 100644 index 0000000000..1fdbd8e765 --- /dev/null +++ b/axum-extra/src/response/multiple.rs @@ -0,0 +1,296 @@ +//! Generate forms to use in responses. + +use axum::response::{IntoResponse, Response}; +use fastrand; +use http::{header, HeaderMap, StatusCode}; +use mime::Mime; + +/// Create multipart forms to be used in API responses. +/// +/// This struct implements [`IntoResponse`], and so it can be returned from a handler. +#[derive(Debug)] +pub struct MultipartForm { + parts: Vec, +} + +impl MultipartForm { + /// Initialize a new multipart form with the provided vector of parts. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// let parts: Vec = vec![Part::text("foo".to_string(), "abc"), Part::text("bar".to_string(), "def")]; + /// let form = MultipartForm::with_parts(parts); + /// ``` + #[deprecated] + pub fn with_parts(parts: Vec) -> Self { + MultipartForm { parts } + } +} + +impl IntoResponse for MultipartForm { + fn into_response(self) -> Response { + // see RFC5758 for details + let boundary = generate_boundary(); + let mut headers = HeaderMap::new(); + let mime_type: Mime = match format!("multipart/form-data; boundary={}", boundary).parse() { + Ok(m) => m, + // Realistically this should never happen unless the boundary generation code + // is modified, and that will be caught by unit tests + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Invalid multipart boundary generated", + ) + .into_response() + } + }; + // The use of unwrap is safe here because mime types are inherently string representable + headers.insert(header::CONTENT_TYPE, mime_type.to_string().parse().unwrap()); + let mut serialized_form: Vec = Vec::new(); + for part in self.parts { + // for each part, the boundary is preceded by two dashes + serialized_form.extend_from_slice(format!("--{}\r\n", boundary).as_bytes()); + serialized_form.extend_from_slice(&part.serialize()); + } + serialized_form.extend_from_slice(format!("--{}--", boundary).as_bytes()); + (headers, serialized_form).into_response() + } +} + +// Valid settings for that header are: "base64", "quoted-printable", "8bit", "7bit", and "binary". +/// A single part of a multipart form as defined by +/// +/// and RFC5758. +#[derive(Debug)] +pub struct Part { + // Every part is expected to contain: + // - a [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition + // header, where `Content-Disposition` is set to `form-data`, with a parameter of `name` that is set to + // the name of the field in the form. In the below example, the name of the field is `user`: + // ``` + // Content-Disposition: form-data; name="user" + // ``` + // If the field contains a file, then the `filename` parameter may be set to the name of the file. + // Handling for non-ascii field names is not done here, support for non-ascii characters may be encoded using + // methodology described in RFC 2047. + // - (optionally) a `Content-Type` header, which if not set, defaults to `text/plain`. + // If the field contains a file, then the file should be identified with that file's MIME type (eg: `image/gif`). + // If the `MIME` type is not known or specified, then the MIME type should be set to `application/octet-stream`. + /// The name of the part in question + name: String, + /// If the part should be treated as a file, the filename that should be attached that part + filename: Option, + /// The `Content-Type` header. While not strictly required, it is always set here + mime_type: Mime, + /// The content/body of the part + contents: Vec, +} + +impl Part { + /// Create a new part with `Content-Type` of `text/plain` with the supplied name and contents. + /// + /// This form will not have a defined file name. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "foo", + /// // and a value of "abc" + /// let parts: Vec = vec![Part::text("foo".to_string(), "abc")]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn text(name: String, contents: &str) -> Self { + Self { + name, + filename: None, + mime_type: mime::TEXT_PLAIN_UTF_8, + contents: contents.as_bytes().to_vec(), + } + } + + /// Create a new part containing a generic file, with a `Content-Type` of `application/octet-stream` + /// using the provided file name, field name, and contents. + /// + /// If the MIME type of the file is known, consider using `Part::raw_part`. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "foo", + /// // with a file name of "foo.txt", and with the specified contents + /// let parts: Vec = vec![Part::file("foo", "foo.txt", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d])]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn file(field_name: &str, file_name: &str, contents: Vec) -> Self { + Self { + name: field_name.to_owned(), + filename: Some(file_name.to_owned()), + // If the `MIME` type is not known or specified, then the MIME type should be set to `application/octet-stream`. + // See RFC2388 section 3 for specifics. + mime_type: mime::APPLICATION_OCTET_STREAM, + contents, + } + } + + /// Create a new part with more fine-grained control over the semantics of that part. + /// + /// The caller is assumed to have set a valid MIME type. + /// + /// This function will return an error if the provided MIME type is not valid. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "part_name", + /// // with a MIME type of "application/json", and the supplied contents. + /// let parts: Vec = vec![Part::raw_part("part_name", "application/json", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d], None).expect("MIME type must be valid")]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn raw_part( + name: &str, + mime_type: &str, + contents: Vec, + filename: Option<&str>, + ) -> Result { + let mime_type = mime_type.parse().map_err(|_| "Invalid MIME type")?; + Ok(Self { + name: name.to_owned(), + filename: filename.map(|f| f.to_owned()), + mime_type, + contents, + }) + } + + /// Serialize this part into a chunk that can be easily inserted into a larger form + pub(super) fn serialize(&self) -> Vec { + // A part is serialized in this general format: + // // the filename is optional + // Content-Disposition: form-data; name="FIELD_NAME"; filename="FILENAME"\r\n + // // the mime type (not strictly required by the spec, but always sent here) + // Content-Type: mime/type\r\n + // // a blank line, then the contents of the file start + // \r\n + // CONTENTS\r\n + + // Format what we can as a string, then handle the rest at a byte level + let mut serialized_part = format!("Content-Disposition: form-data; name=\"{}\"", self.name); + // specify a filename if one was set + if let Some(filename) = &self.filename { + serialized_part += &format!("; filename=\"{}\"", filename); + } + serialized_part += "\r\n"; + // specify the MIME type + serialized_part += &format!("Content-Type: {}\r\n", self.mime_type); + serialized_part += "\r\n"; + let mut part_bytes = serialized_part.as_bytes().to_vec(); + part_bytes.extend_from_slice(&self.contents); + part_bytes.extend_from_slice(b"\r\n"); + + part_bytes + } +} + +impl FromIterator for MultipartForm { + fn from_iter>(iter: T) -> Self { + Self { + parts: iter.into_iter().collect(), + } + } +} + +/// A boundary is defined as a user defined (arbitrary) value that does not occur in any of the data. +/// +/// Because the specification does not clearly define a methodology for generating boundaries, this implementation +/// follow's Reqwest's, and generates a boundary in the format of `XXXXXXXX-XXXXXXXX-XXXXXXXX-XXXXXXXX` where `XXXXXXXX` +/// is a hexadecimal representation of a pseudo randomly generated u64. +fn generate_boundary() -> String { + let a = fastrand::u64(0..u64::MAX); + let b = fastrand::u64(0..u64::MAX); + let c = fastrand::u64(0..u64::MAX); + let d = fastrand::u64(0..u64::MAX); + format!("{a:016x}-{b:016x}-{c:016x}-{d:016x}") +} + +#[cfg(test)] +mod tests { + use super::{generate_boundary, MultipartForm, Part}; + use axum::{body::Body, http}; + use axum::{routing::get, Router}; + use http::{Request, Response}; + use http_body_util::BodyExt; + use mime::Mime; + use tower::ServiceExt; + + #[tokio::test] + async fn process_form() -> Result<(), Box> { + // create a boilerplate handle that returns a form + async fn handle() -> MultipartForm { + let parts: Vec = vec![ + Part::text("part1".to_owned(), "basictext"), + Part::file( + "part2", + "file.txt", + vec![0x68, 0x69, 0x20, 0x6d, 0x6f, 0x6d], + ), + Part::raw_part("part3", "text/plain", b"rawpart".to_vec(), None).unwrap(), + ]; + MultipartForm::from_iter(parts) + } + + // make a request to that handle + let app = Router::new().route("/", get(handle)); + let response: Response<_> = app + .oneshot(Request::builder().uri("/").body(Body::empty())?) + .await?; + // content_type header + let ct_header = response.headers().get("content-type").unwrap().to_str()?; + let boundary = ct_header.split("boundary=").nth(1).unwrap().to_owned(); + let body: &[u8] = &response.into_body().collect().await?.to_bytes(); + assert_eq!( + std::str::from_utf8(body)?, + &format!( + "--{boundary}\r\n\ + Content-Disposition: form-data; name=\"part1\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\n\ + \r\n\ + basictext\r\n\ + --{boundary}\r\n\ + Content-Disposition: form-data; name=\"part2\"; filename=\"file.txt\"\r\n\ + Content-Type: application/octet-stream\r\n\ + \r\n\ + hi mom\r\n\ + --{boundary}\r\n\ + Content-Disposition: form-data; name=\"part3\"\r\n\ + Content-Type: text/plain\r\n\ + \r\n\ + rawpart\r\n\ + --{boundary}--", + boundary = boundary + ) + ); + + Ok(()) + } + + #[test] + fn valid_boundary_generation() { + for _ in 0..256 { + let boundary = generate_boundary(); + let mime_type: Result = + format!("multipart/form-data; boundary={}", boundary).parse(); + assert!( + mime_type.is_ok(), + "The generated boundary was unable to be parsed into a valid mime type." + ); + } + } +} diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 9f9d18cb6e..9d9aa0cbe6 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -165,7 +165,7 @@ pub trait RouterExt: sealed::Sealed { /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`]. fn route_service_with_tsr(self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized; @@ -268,7 +268,7 @@ where #[track_caller] fn route_service_with_tsr(mut self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized, diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index ff2d6eed7b..02c5be672c 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -85,12 +85,12 @@ use serde::Serialize; /// /// - A `TypedPath` implementation. /// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`], -/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must -/// also implement [`serde::Deserialize`], unless it's a unit struct. +/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must +/// also implement [`serde::Deserialize`], unless it's a unit struct. /// - A [`Display`] implementation that interpolates the captures. This can be used to, among other -/// things, create links to known paths and have them verified statically. Note that the -/// [`Display`] implementation for each field must return something that's compatible with its -/// [`Deserialize`] implementation. +/// things, create links to known paths and have them verified statically. Note that the +/// [`Display`] implementation for each field must return something that's compatible with its +/// [`Deserialize`] implementation. /// /// Additionally the macro will verify the captures in the path matches the fields of the struct. /// For example this fails to compile since the struct doesn't have a `team_id` field: diff --git a/axum-extra/src/typed_header.rs b/axum-extra/src/typed_header.rs index da40f2c031..ef94c3779c 100644 --- a/axum-extra/src/typed_header.rs +++ b/axum-extra/src/typed_header.rs @@ -1,12 +1,11 @@ //! Extractor and response for typed headers. use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use headers::{Header, HeaderMapExt}; -use http::request::Parts; +use http::{request::Parts, StatusCode}; use std::convert::Infallible; /// Extractor and response that works with typed header values from [`headers`]. @@ -55,7 +54,6 @@ use std::convert::Infallible; #[must_use] pub struct TypedHeader(pub T); -#[async_trait] impl FromRequestParts for TypedHeader where T: Header, @@ -156,7 +154,10 @@ impl TypedHeaderRejectionReason { impl IntoResponse for TypedHeaderRejection { fn into_response(self) -> Response { - (http::StatusCode::BAD_REQUEST, self.to_string()).into_response() + let status = StatusCode::BAD_REQUEST; + let body = self.to_string(); + axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,); + (status, body).into_response() } } diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 36a8ca18c1..ff132d8528 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -7,7 +7,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.4.2 + +- **added:** Add `#[debug_middleware]` ([#1993], [#2725]) + +[#1993]: https://github.com/tokio-rs/axum/pull/1993 +[#2725]: https://github.com/tokio-rs/axum/pull/2725 # 0.4.1 (13. January, 2024) diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index ded3c183ab..543b520c50 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -2,14 +2,14 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "Macros for axum" edition = "2021" -rust-version = "1.66" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["axum"] license = "MIT" name = "axum-macros" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.4.1" # remember to also bump the version that axum and axum-extra depends on +version = "0.4.2" # remember to also bump the version that axum and axum-extra depends on [features] default = [] @@ -19,7 +19,6 @@ __private = ["syn/visit-mut"] proc-macro = true [dependencies] -heck = "0.4" proc-macro2 = "1.0" quote = "1.0" syn = { version = "2.0", features = [ diff --git a/axum-macros/README.md b/axum-macros/README.md index 8fcde01ed2..c3967b19ae 100644 --- a/axum-macros/README.md +++ b/axum-macros/README.md @@ -14,7 +14,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum-macros's MSRV is 1.66. +axum-macros's MSRV is 1.75. ## Getting Help diff --git a/axum-macros/rust-toolchain b/axum-macros/rust-toolchain index b83876b519..eca143c73f 100644 --- a/axum-macros/rust-toolchain +++ b/axum-macros/rust-toolchain @@ -1 +1 @@ -nightly-2024-03-13 +nightly-2024-06-22 diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 90e2ae9262..3a37c17ab3 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::{collections::HashSet, fmt}; use crate::{ attr_parsing::{parse_assignment_attribute, second}, @@ -8,13 +8,13 @@ use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type}; -pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { +pub(crate) fn expand(attr: Attrs, item_fn: ItemFn, kind: FunctionKind) -> TokenStream { let Attrs { state_ty } = attr; let mut state_ty = state_ty.map(second); - let check_extractor_count = check_extractor_count(&item_fn); - let check_path_extractor = check_path_extractor(&item_fn); + let check_extractor_count = check_extractor_count(&item_fn, kind); + let check_path_extractor = check_path_extractor(&item_fn, kind); let check_output_tuples = check_output_tuples(&item_fn); let check_output_impls_into_response = if check_output_tuples.is_empty() { check_output_impls_into_response(&item_fn) @@ -37,8 +37,10 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { err = Some( syn::Error::new( Span::call_site(), - "can't infer state type, please add set it explicitly, as in \ - `#[debug_handler(state = MyStateType)]`", + format!( + "can't infer state type, please add set it explicitly, as in \ + `#[axum_macros::debug_{kind}(state = MyStateType)]`" + ), ) .into_compile_error(), ); @@ -48,16 +50,16 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { err.unwrap_or_else(|| { let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); - let check_future_send = check_future_send(&item_fn); + let check_future_send = check_future_send(&item_fn, kind); - if let Some(check_input_order) = check_input_order(&item_fn) { + if let Some(check_input_order) = check_input_order(&item_fn, kind) { quote! { #check_input_order #check_future_send } } else { let check_inputs_impls_from_request = - check_inputs_impls_from_request(&item_fn, state_ty); + check_inputs_impls_from_request(&item_fn, state_ty, kind); quote! { #check_inputs_impls_from_request @@ -68,17 +70,45 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { } else { syn::Error::new_spanned( &item_fn.sig.generics, - "`#[axum_macros::debug_handler]` doesn't support generic functions", + format!("`#[axum_macros::debug_{kind}]` doesn't support generic functions"), ) .into_compile_error() }; + let middleware_takes_next_as_last_arg = + matches!(kind, FunctionKind::Middleware).then(|| next_is_last_input(&item_fn)); + quote! { #item_fn #check_extractor_count #check_path_extractor #check_output_impls_into_response #check_inputs_and_future_send + #middleware_takes_next_as_last_arg + } +} + +#[derive(Clone, Copy)] +pub(crate) enum FunctionKind { + Handler, + Middleware, +} + +impl fmt::Display for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionKind::Handler => f.write_str("handler"), + FunctionKind::Middleware => f.write_str("middleware"), + } + } +} + +impl FunctionKind { + fn name_uppercase_plural(&self) -> &'static str { + match self { + FunctionKind::Handler => "Handlers", + FunctionKind::Middleware => "Middleware", + } } } @@ -110,25 +140,36 @@ impl Parse for Attrs { } } -fn check_extractor_count(item_fn: &ItemFn) -> Option { +fn check_extractor_count(item_fn: &ItemFn, kind: FunctionKind) -> Option { let max_extractors = 16; - if item_fn.sig.inputs.len() <= max_extractors { + let inputs = item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)) + .count(); + if inputs <= max_extractors { None } else { let error_message = format!( - "Handlers cannot take more than {max_extractors} arguments. \ + "{} cannot take more than {max_extractors} arguments. \ Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", + kind.name_uppercase_plural(), ); let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error(); Some(error) } } -fn extractor_idents(item_fn: &ItemFn) -> impl Iterator { +fn extractor_idents( + item_fn: &ItemFn, + kind: FunctionKind, +) -> impl Iterator { item_fn .sig .inputs .iter() + .filter(move |arg| skip_next_arg(arg, kind)) .enumerate() .filter_map(|(idx, fn_arg)| match fn_arg { FnArg::Receiver(_) => None, @@ -146,8 +187,8 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator TokenStream { - let path_extractors = extractor_idents(item_fn) +fn check_path_extractor(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { + let path_extractors = extractor_idents(item_fn, kind) .filter(|(_, _, ident)| *ident == "Path") .collect::>(); @@ -179,113 +220,122 @@ fn is_self_pat_type(typed: &syn::PatType) -> bool { ident == "self" } -fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStream { +fn check_inputs_impls_from_request( + item_fn: &ItemFn, + state_ty: Type, + kind: FunctionKind, +) -> TokenStream { let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); - WithPosition::new(&item_fn.sig.inputs) - .enumerate() - .map(|(idx, arg)| { - let must_impl_from_request_parts = match &arg { - Position::First(_) | Position::Middle(_) => true, - Position::Last(_) | Position::Only(_) => false, - }; + WithPosition::new( + item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)), + ) + .enumerate() + .map(|(idx, arg)| { + let must_impl_from_request_parts = match &arg { + Position::First(_) | Position::Middle(_) => true, + Position::Last(_) | Position::Only(_) => false, + }; - let arg = arg.into_inner(); + let arg = arg.into_inner(); - let (span, ty) = match arg { - FnArg::Receiver(receiver) => { - if receiver.reference.is_some() { - return syn::Error::new_spanned( - receiver, - "Handlers must only take owned values", - ) - .into_compile_error(); - } - - let span = receiver.span(); - (span, syn::parse_quote!(Self)) + let (span, ty) = match arg { + FnArg::Receiver(receiver) => { + if receiver.reference.is_some() { + return syn::Error::new_spanned( + receiver, + "Handlers must only take owned values", + ) + .into_compile_error(); } - FnArg::Typed(typed) => { - let ty = &typed.ty; - let span = ty.span(); - if is_self_pat_type(typed) { - (span, syn::parse_quote!(Self)) - } else { - (span, ty.clone()) - } - } - }; + let span = receiver.span(); + (span, syn::parse_quote!(Self)) + } + FnArg::Typed(typed) => { + let ty = &typed.ty; + let span = ty.span(); - let consumes_request = request_consuming_type_name(&ty).is_some(); + if is_self_pat_type(typed) { + (span, syn::parse_quote!(Self)) + } else { + (span, ty.clone()) + } + } + }; - let check_fn = format_ident!( - "__axum_macros_check_{}_{}_from_request_check", - item_fn.sig.ident, - idx, - span = span, - ); + let consumes_request = request_consuming_type_name(&ty).is_some(); - let call_check_fn = format_ident!( - "__axum_macros_check_{}_{}_from_request_call_check", - item_fn.sig.ident, - idx, - span = span, - ); + let check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_check", + item_fn.sig.ident, + idx, + span = span, + ); - let call_check_fn_body = if takes_self { - quote_spanned! {span=> - Self::#check_fn(); - } - } else { - quote_spanned! {span=> - #check_fn(); - } - }; + let call_check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_call_check", + item_fn.sig.ident, + idx, + span = span, + ); - let check_fn_generics = if must_impl_from_request_parts || consumes_request { - quote! {} - } else { - quote! { } - }; + let call_check_fn_body = if takes_self { + quote_spanned! {span=> + Self::#check_fn(); + } + } else { + quote_spanned! {span=> + #check_fn(); + } + }; - let from_request_bound = if must_impl_from_request_parts { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequestParts<#state_ty> + Send - } - } else if consumes_request { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequest<#state_ty> + Send - } - } else { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequest<#state_ty, M> + Send - } - }; + let check_fn_generics = if must_impl_from_request_parts || consumes_request { + quote! {} + } else { + quote! { } + }; + let from_request_bound = if must_impl_from_request_parts { quote_spanned! {span=> - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - fn #check_fn #check_fn_generics() - where - #from_request_bound, - {} + #ty: ::axum::extract::FromRequestParts<#state_ty> + Send + } + } else if consumes_request { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequest<#state_ty> + Send + } + } else { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequest<#state_ty, M> + Send + } + }; - // we have to call the function to actually trigger a compile error - // since the function is generic, just defining it is not enough - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - fn #call_check_fn() { - #call_check_fn_body - } + quote_spanned! {span=> + #[allow(warnings)] + #[doc(hidden)] + fn #check_fn #check_fn_generics() + where + #from_request_bound, + {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[doc(hidden)] + fn #call_check_fn() + { + #call_check_fn_body } - }) - .collect::() + } + }) + .collect::() } fn check_output_tuples(item_fn: &ItemFn) -> TokenStream { @@ -445,11 +495,19 @@ fn check_into_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStr } } -fn check_input_order(item_fn: &ItemFn) -> Option { +fn check_input_order(item_fn: &ItemFn, kind: FunctionKind) -> Option { + let number_of_inputs = item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)) + .count(); + let types_that_consume_the_request = item_fn .sig .inputs .iter() + .filter(|arg| skip_next_arg(arg, kind)) .enumerate() .filter_map(|(idx, arg)| { let ty = match arg { @@ -469,7 +527,7 @@ fn check_input_order(item_fn: &ItemFn) -> Option { // exactly one type that consumes the request if types_that_consume_the_request.len() == 1 { // and that is not the last - if types_that_consume_the_request[0].0 != item_fn.sig.inputs.len() - 1 { + if types_that_consume_the_request[0].0 != number_of_inputs - 1 { let (_idx, type_name, span) = &types_that_consume_the_request[0]; let error = format!( "`{type_name}` consumes the request body and thus must be \ @@ -653,13 +711,13 @@ fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { } } -fn check_future_send(item_fn: &ItemFn) -> TokenStream { +fn check_future_send(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { if item_fn.sig.asyncness.is_none() { match &item_fn.sig.output { syn::ReturnType::Default => { return syn::Error::new_spanned( item_fn.sig.fn_token, - "Handlers must be `async fn`s", + format!("{} must be `async fn`s", kind.name_uppercase_plural()), ) .into_compile_error(); } @@ -763,7 +821,69 @@ fn state_types_from_args(item_fn: &ItemFn) -> HashSet { crate::infer_state_types(types).collect() } +fn next_is_last_input(item_fn: &ItemFn) -> TokenStream { + let next_args = item_fn + .sig + .inputs + .iter() + .enumerate() + .filter(|(_, arg)| !skip_next_arg(arg, FunctionKind::Middleware)) + .collect::>(); + + if next_args.is_empty() { + return quote! { + compile_error!( + "Middleware functions must take `axum::middleware::Next` as the last argument", + ); + }; + } + + if next_args.len() == 1 { + let (idx, arg) = &next_args[0]; + if *idx != item_fn.sig.inputs.len() - 1 { + return quote_spanned! {arg.span()=> + compile_error!("`axum::middleware::Next` must the last argument"); + }; + } + } + + if next_args.len() >= 2 { + return quote! { + compile_error!( + "Middleware functions can only take one argument of type `axum::middleware::Next`", + ); + }; + } + + quote! {} +} + +fn skip_next_arg(arg: &FnArg, kind: FunctionKind) -> bool { + match kind { + FunctionKind::Handler => true, + FunctionKind::Middleware => match arg { + FnArg::Receiver(_) => true, + FnArg::Typed(pat_type) => { + if let Type::Path(type_path) = &*pat_type.ty { + type_path + .path + .segments + .last() + .map_or(true, |path_segment| path_segment.ident != "Next") + } else { + true + } + } + }, + } +} + #[test] -fn ui() { +fn ui_debug_handler() { crate::run_ui_tests("debug_handler"); } + +#[test] +fn ui_debug_middleware() { + crate::run_ui_tests("debug_middleware"); +} diff --git a/axum-macros/src/from_ref.rs b/axum-macros/src/from_ref.rs index 2ab69eb54d..1a27765a4f 100644 --- a/axum-macros/src/from_ref.rs +++ b/axum-macros/src/from_ref.rs @@ -54,7 +54,7 @@ fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream { }; quote_spanned! {span=> - #[allow(clippy::clone_on_copy)] + #[allow(clippy::clone_on_copy, clippy::clone_on_ref_ptr)] impl ::axum::extract::FromRef<#state> for #field_ty { fn from_ref(state: &#state) -> Self { #body diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 191f14452a..a6c95ab7ff 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -180,7 +180,7 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { variants, } = item; - let generics_error = format!("`#[derive({tr})] on enums don't support generics"); + let generics_error = format!("`#[derive({tr})]` on enums don't support generics"); if !generics.params.is_empty() { return Err(syn::Error::new_spanned(generics, generics_error)); @@ -373,7 +373,6 @@ fn impl_struct_by_extracting_each_field( Ok(match tr { Trait::FromRequest => quote! { - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where @@ -390,7 +389,6 @@ fn impl_struct_by_extracting_each_field( } }, Trait::FromRequestParts => quote! { - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where @@ -435,7 +433,7 @@ fn extract_fields( } } - fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { + fn into_inner(via: &Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { if let Some((_, path)) = via { let span = path.span(); quote_spanned! {span=> @@ -448,6 +446,23 @@ fn extract_fields( } } + fn into_outer( + via: &Option<(attr::kw::via, syn::Path)>, + ty_span: Span, + field_ty: &Type, + ) -> TokenStream { + if let Some((_, path)) = via { + let span = path.span(); + quote_spanned! {span=> + #path<#field_ty> + } + } else { + quote_spanned! {ty_span=> + #field_ty + } + } + } + let mut fields_iter = fields.iter(); let last = match tr { @@ -464,16 +479,17 @@ fn extract_fields( let member = member(field, index); let ty_span = field.ty.span(); - let into_inner = into_inner(via, ty_span); + let into_inner = into_inner(&via, ty_span); if peel_option(&field.ty).is_some() { + let field_ty = into_outer(&via, ty_span, peel_option(&field.ty).unwrap()); let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( &mut parts, state, ) @@ -488,7 +504,7 @@ fn extract_fields( Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( parts, state, ) @@ -501,13 +517,14 @@ fn extract_fields( }; Ok(tokens) } else if peel_result_ok(&field.ty).is_some() { + let field_ty = into_outer(&via,ty_span, peel_result_ok(&field.ty).unwrap()); let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( &mut parts, state, ) @@ -521,7 +538,7 @@ fn extract_fields( Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( parts, state, ) @@ -533,6 +550,7 @@ fn extract_fields( }; Ok(tokens) } else { + let field_ty = into_outer(&via,ty_span,&field.ty); let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { @@ -545,7 +563,7 @@ fn extract_fields( #member: { let (mut parts, body) = req.into_parts(); let value = - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( &mut parts, state, ) @@ -560,7 +578,7 @@ fn extract_fields( Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( parts, state, ) @@ -582,26 +600,29 @@ fn extract_fields( let member = member(field, fields.len() - 1); let ty_span = field.ty.span(); - let into_inner = into_inner(via, ty_span); + let into_inner = into_inner(&via, ty_span); let item = if peel_option(&field.ty).is_some() { + let field_ty = into_outer(&via, ty_span, peel_option(&field.ty).unwrap()); quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequest::from_request(req, state) + <#field_ty as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .ok() .map(#into_inner) }, } } else if peel_result_ok(&field.ty).is_some() { + let field_ty = into_outer(&via, ty_span, peel_result_ok(&field.ty).unwrap()); quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequest::from_request(req, state) + <#field_ty as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(#into_inner) }, } } else { + let field_ty = into_outer(&via, ty_span, &field.ty); let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { @@ -610,7 +631,7 @@ fn extract_fields( quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequest::from_request(req, state) + <#field_ty as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(#into_inner) .map_err(#map_err)? @@ -807,7 +828,6 @@ fn impl_struct_by_extracting_all_at_once( let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics where @@ -821,7 +841,7 @@ fn impl_struct_by_extracting_all_at_once( req: ::axum::http::Request<::axum::body::Body>, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequest::from_request(req, state) + <#via_path<#via_type_generics> as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(|#via_path(value)| #value_to_self) .map_err(#map_err) @@ -831,7 +851,6 @@ fn impl_struct_by_extracting_all_at_once( } Trait::FromRequestParts => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics where @@ -845,7 +864,7 @@ fn impl_struct_by_extracting_all_at_once( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequestParts::from_request_parts(parts, state) + <#via_path<#via_type_generics> as ::axum::extract::FromRequestParts<_>>::from_request_parts(parts, state) .await .map(|#via_path(value)| #value_to_self) .map_err(#map_err) @@ -920,7 +939,6 @@ fn impl_enum_by_extracting_all_at_once( let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where @@ -932,7 +950,7 @@ fn impl_enum_by_extracting_all_at_once( req: ::axum::http::Request<::axum::body::Body>, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequest::from_request(req, state) + <#path::<#ident> as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(|#path(inner)| inner) .map_err(#map_err) @@ -942,7 +960,6 @@ fn impl_enum_by_extracting_all_at_once( } Trait::FromRequestParts => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where @@ -954,7 +971,7 @@ fn impl_enum_by_extracting_all_at_once( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequestParts::from_request_parts(parts, state) + <#path::<#ident> as ::axum::extract::FromRequestParts<_>>::from_request_parts(parts, state) .await .map(|#path(inner)| inner) .map_err(#map_err) diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index f35e7dae76..f5aeaab748 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -15,7 +15,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -44,6 +43,7 @@ #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] +use debug_handler::FunctionKind; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse::Parse, Type}; @@ -246,7 +246,6 @@ use from_request::Trait::{FromRequest, FromRequestParts}; /// /// struct MyInnerType; /// -/// #[axum::async_trait] /// impl FromRequestParts for MyInnerType { /// // ... /// # type Rejection = (); @@ -398,7 +397,7 @@ use from_request::Trait::{FromRequest, FromRequestParts}; /// /// # Known limitations /// -/// Generics are only supported on tuple structs with exactly on field. Thus this doesn't work +/// Generics are only supported on tuple structs with exactly one field. Thus this doesn't work /// /// ```compile_fail /// #[derive(axum_macros::FromRequest)] @@ -464,7 +463,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { expand_with(item, |item| from_request::expand(item, FromRequestParts)) } -/// Generates better error messages when applied handler functions. +/// Generates better error messages when applied to handler functions. /// /// While using [`axum`], you can get long error messages for simple mistakes. For example: /// @@ -515,17 +514,15 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// /// As the error message says, handler function needs to be async. /// -/// ``` +/// ```no_run /// use axum::{routing::get, Router, debug_handler}; /// /// #[tokio::main] /// async fn main() { -/// # async { /// let app = Router::new().route("/", get(handler)); /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, app).await.unwrap(); -/// # }; /// } /// /// #[debug_handler] @@ -618,7 +615,65 @@ pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream { return input; #[cfg(debug_assertions)] - return expand_attr_with(_attr, input, debug_handler::expand); + return expand_attr_with(_attr, input, |attrs, item_fn| { + debug_handler::expand(attrs, item_fn, FunctionKind::Handler) + }); +} + +/// Generates better error messages when applied to middleware functions. +/// +/// This works similarly to [`#[debug_handler]`](macro@debug_handler) except for middleware using +/// [`axum::middleware::from_fn`]. +/// +/// # Example +/// +/// ```no_run +/// use axum::{ +/// routing::get, +/// extract::Request, +/// response::Response, +/// Router, +/// middleware::{self, Next}, +/// debug_middleware, +/// }; +/// +/// #[tokio::main] +/// async fn main() { +/// let app = Router::new() +/// .route("/", get(|| async {})) +/// .layer(middleware::from_fn(my_middleware)); +/// +/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); +/// axum::serve(listener, app).await.unwrap(); +/// } +/// +/// // if this wasn't a valid middleware function #[debug_middleware] would +/// // improve compile error +/// #[debug_middleware] +/// async fn my_middleware( +/// request: Request, +/// next: Next, +/// ) -> Response { +/// next.run(request).await +/// } +/// ``` +/// +/// # Performance +/// +/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`) +/// +/// [`axum`]: https://docs.rs/axum/latest +/// [`axum::middleware::from_fn`]: https://docs.rs/axum/0.7/axum/middleware/fn.from_fn.html +/// [`debug_middleware`]: macro@debug_middleware +#[proc_macro_attribute] +pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream { + #[cfg(not(debug_assertions))] + return input; + + #[cfg(debug_assertions)] + return expand_attr_with(_attr, input, |attrs, item_fn| { + debug_handler::expand(attrs, item_fn, FunctionKind::Middleware) + }); } /// Private API: Do no use this! diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index baaf7f9fa2..fa272252be 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -133,7 +133,6 @@ fn expand_named_fields( let map_err_rejection = map_err_rejection(&rejection); let from_request_impl = quote! { - #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where @@ -238,7 +237,6 @@ fn expand_unnamed_fields( let map_err_rejection = map_err_rejection(&rejection); let from_request_impl = quote! { - #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where @@ -322,7 +320,6 @@ fn expand_unit_fields( }; let from_request_impl = quote! { - #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index eedccb2317..f5687df0e8 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -6,15 +6,15 @@ error[E0277]: the trait bound `bool: FromRequest<(), axum_core::extract::private | = note: Function argument is not a valid axum extractor. See `https://docs.rs/axum/0.7/axum/extract/index.html` for details - = help: the following other types implement trait `FromRequest`: - axum::body::Bytes - Body - Form - Json - axum::http::Request - RawForm - String - Option + = help: the following other types implement trait `FromRequestParts`: + `()` implements `FromRequestParts` + `(T1, T2)` implements `FromRequestParts` + `(T1, T2, T3)` implements `FromRequestParts` + `(T1, T2, T3, T4)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7, T8)` implements `FromRequestParts` and $N others = note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>` note: required by a bound in `__axum_macros_check_handler_0_from_request_check` diff --git a/axum-macros/tests/debug_handler/fail/extension_not_clone.rs b/axum-macros/tests/debug_handler/fail/extension_not_clone.rs new file mode 100644 index 0000000000..6bed79e195 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/extension_not_clone.rs @@ -0,0 +1,9 @@ +use axum::extract::Extension; +use axum_macros::debug_handler; + +struct NonCloneType; + +#[debug_handler] +async fn test_extension_non_clone(_: Extension) {} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/extension_not_clone.stderr b/axum-macros/tests/debug_handler/fail/extension_not_clone.stderr new file mode 100644 index 0000000000..81bec91835 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/extension_not_clone.stderr @@ -0,0 +1,28 @@ +error[E0277]: the trait bound `NonCloneType: Clone` is not satisfied + --> tests/debug_handler/fail/extension_not_clone.rs:7:38 + | +7 | async fn test_extension_non_clone(_: Extension) {} + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `Clone` is not implemented for `NonCloneType`, which is required by `Extension: FromRequest<(), _>` + | + = help: the following other types implement trait `FromRequest`: + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) + (T1, T2, T3, T4, T5, T6, T7, T8, T9) + and $N others + = note: required for `Extension` to implement `FromRequestParts<()>` + = note: required for `Extension` to implement `FromRequest<(), axum_core::extract::private::ViaParts>` +note: required by a bound in `__axum_macros_check_test_extension_non_clone_0_from_request_check` + --> tests/debug_handler/fail/extension_not_clone.rs:7:38 + | +7 | async fn test_extension_non_clone(_: Extension) {} + | ^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `__axum_macros_check_test_extension_non_clone_0_from_request_check` +help: consider annotating `NonCloneType` with `#[derive(Clone)]` + | +4 + #[derive(Clone)] +5 | struct NonCloneType; + | diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index 21ae99d6b8..eb17c1df52 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -1,12 +1,8 @@ -use axum::{ - async_trait, - extract::{Request, FromRequest}, -}; +use axum::extract::{FromRequest, Request}; use axum_macros::debug_handler; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr index 595786bf4e..0610a22a3b 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_mut.rs:23:22 + --> tests/debug_handler/fail/extract_self_mut.rs:19:22 | -23 | async fn handler(&mut self) {} +19 | async fn handler(&mut self) {} | ^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index 8e32811994..d70c5f2318 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -1,12 +1,8 @@ -use axum::{ - async_trait, - extract::{Request, FromRequest}, -}; +use axum::extract::{FromRequest, Request}; use axum_macros::debug_handler; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr index 4c0b4950c7..d475c5092f 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_ref.rs:23:22 + --> tests/debug_handler/fail/extract_self_ref.rs:19:22 | -23 | async fn handler(&self) {} +19 | async fn handler(&self) {} | ^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr b/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr index 14b3ae83fa..afda86b65d 100644 --- a/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr +++ b/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr @@ -4,20 +4,25 @@ error[E0277]: the trait bound `for<'de> Struct: serde::de::Deserialize<'de>` is 7 | async fn handler(_foo: Json) {} | ^^^^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `Struct`, which is required by `Json: FromRequest<()>` | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `Struct` type + = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `serde::de::Deserialize<'de>`: - bool - char - isize - i8 - i16 - i32 - i64 - i128 + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) and $N others = note: required for `Struct` to implement `serde::de::DeserializeOwned` = note: required for `Json` to implement `FromRequest<()>` = help: see issue #48214 - = help: add `#![feature(trivial_bounds)]` to the crate attributes to enable +help: add `#![feature(trivial_bounds)]` to the crate attributes to enable + | +1 + #![feature(trivial_bounds)] + | error[E0277]: the trait bound `for<'de> Struct: serde::de::Deserialize<'de>` is not satisfied --> tests/debug_handler/fail/json_not_deserialize.rs:7:24 @@ -25,15 +30,17 @@ error[E0277]: the trait bound `for<'de> Struct: serde::de::Deserialize<'de>` is 7 | async fn handler(_foo: Json) {} | ^^^^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `Struct`, which is required by `Json: FromRequest<()>` | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `Struct` type + = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `serde::de::Deserialize<'de>`: - bool - char - isize - i8 - i16 - i32 - i64 - i128 + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) and $N others = note: required for `Struct` to implement `serde::de::DeserializeOwned` = note: required for `Json` to implement `FromRequest<()>` diff --git a/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr b/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr index a6ffb0ed1c..8909373553 100644 --- a/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr +++ b/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr @@ -5,14 +5,14 @@ error[E0277]: the trait bound `NotIntoResponse: IntoResponse` is not satisfied | ^^^^^^^^^^^^^^^^^ the trait `IntoResponse` is not implemented for `NotIntoResponse` | = help: the following other types implement trait `IntoResponse`: - Box - Box<[u8]> - axum::body::Bytes - Body - axum::extract::rejection::FailedToBufferBody - axum::extract::rejection::LengthLimitError - axum::extract::rejection::UnknownBodyError - axum::extract::rejection::InvalidUtf8 + &'static [u8; N] + &'static [u8] + &'static str + () + (R,) + (Response<()>, R) + (Response<()>, T1, R) + (Response<()>, T1, T2, R) and $N others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/single_wrong_return_tuple.rs:6:23 diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr index 812c37edac..77597b3358 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr +++ b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr @@ -11,17 +11,20 @@ error[E0277]: the trait bound `CustomIntoResponse: IntoResponseParts` is not sat | ^^^^^^^^^^^^^^^^^^ the trait `IntoResponseParts` is not implemented for `CustomIntoResponse` | = help: the following other types implement trait `IntoResponseParts`: - AppendHeaders - HeaderMap - Extension - Extensions - Option - [(K, V); N] () - (T1,) + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) and $N others = help: see issue #48214 - = help: add `#![feature(trivial_bounds)]` to the crate attributes to enable +help: add `#![feature(trivial_bounds)]` to the crate attributes to enable + | +3 + #![feature(trivial_bounds)] + | error[E0277]: the trait bound `CustomIntoResponse: IntoResponseParts` is not satisfied --> tests/debug_handler/fail/wrong_return_tuple.rs:24:5 @@ -30,14 +33,14 @@ error[E0277]: the trait bound `CustomIntoResponse: IntoResponseParts` is not sat | ^^^^^^^^^^^^^^^^^^ the trait `IntoResponseParts` is not implemented for `CustomIntoResponse` | = help: the following other types implement trait `IntoResponseParts`: - AppendHeaders - HeaderMap - Extension - Extensions - Option - [(K, V); N] () - (T1,) + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) and $N others note: required by a bound in `__axum_macros_check_custom_type_into_response_parts_1_check` --> tests/debug_handler/fail/wrong_return_tuple.rs:24:5 diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr index cc718aae0c..c305e7e781 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr +++ b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr @@ -5,14 +5,14 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied | ^^^^ the trait `IntoResponse` is not implemented for `bool` | = help: the following other types implement trait `IntoResponse`: - Box - Box<[u8]> - axum::body::Bytes - Body - axum::extract::rejection::FailedToBufferBody - axum::extract::rejection::LengthLimitError - axum::extract::rejection::UnknownBodyError - axum::extract::rejection::InvalidUtf8 + &'static [u8; N] + &'static [u8] + &'static str + () + (R,) + (Response<()>, R) + (Response<()>, T1, R) + (Response<()>, T1, T2, R) and $N others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/wrong_return_type.rs:4:23 diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 782fc9301c..f23c9b627c 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -1,4 +1,4 @@ -use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::IntoResponse}; +use axum::{extract::FromRequestParts, http::request::Parts, response::IntoResponse}; use axum_macros::debug_handler; fn main() {} @@ -115,7 +115,6 @@ impl A { } } -#[async_trait] impl FromRequestParts for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index 9b72284502..3fbcc4e03b 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -1,12 +1,8 @@ -use axum::{ - async_trait, - extract::{Request, FromRequest}, -}; +use axum::extract::{FromRequest, Request}; use axum_macros::debug_handler; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, @@ -18,7 +14,6 @@ where } } -#[async_trait] impl FromRequest for Box where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/pass/set_state.rs b/axum-macros/tests/debug_handler/pass/set_state.rs index 60a7a3304e..72bba5aede 100644 --- a/axum-macros/tests/debug_handler/pass/set_state.rs +++ b/axum-macros/tests/debug_handler/pass/set_state.rs @@ -1,6 +1,5 @@ +use axum::extract::{FromRef, FromRequest, Request}; use axum_macros::debug_handler; -use axum::extract::{Request, FromRef, FromRequest}; -use axum::async_trait; #[debug_handler(state = AppState)] async fn handler(_: A) {} @@ -10,7 +9,6 @@ struct AppState; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs new file mode 100644 index 0000000000..12092e857b --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs @@ -0,0 +1,13 @@ +use axum::{ + debug_middleware, + extract::Request, + response::{IntoResponse, Response}, +}; + +#[debug_middleware] +async fn my_middleware(request: Request) -> Response { + let _ = request; + ().into_response() +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr new file mode 100644 index 0000000000..2474a4ebb4 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr @@ -0,0 +1,7 @@ +error: Middleware functions must take `axum::middleware::Next` as the last argument + --> tests/debug_middleware/fail/doesnt_take_next.rs:7:1 + | +7 | #[debug_middleware] + | ^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_middleware/fail/next_not_last.rs b/axum-macros/tests/debug_middleware/fail/next_not_last.rs new file mode 100644 index 0000000000..0108c85433 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/next_not_last.rs @@ -0,0 +1,13 @@ +use axum::{ + extract::Request, + response::Response, + middleware::Next, + debug_middleware, +}; + +#[debug_middleware] +async fn my_middleware(next: Next, request: Request) -> Response { + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/next_not_last.stderr b/axum-macros/tests/debug_middleware/fail/next_not_last.stderr new file mode 100644 index 0000000000..8f08bed72d --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/next_not_last.stderr @@ -0,0 +1,5 @@ +error: `axum::middleware::Next` must the last argument + --> tests/debug_middleware/fail/next_not_last.rs:9:24 + | +9 | async fn my_middleware(next: Next, request: Request) -> Response { + | ^^^^^^^^^^ diff --git a/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs b/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs new file mode 100644 index 0000000000..995a97bda6 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs @@ -0,0 +1,9 @@ +use axum::{debug_middleware, extract::Request, middleware::Next, response::Response}; + +#[debug_middleware] +async fn my_middleware(request: Request, next: Next, next2: Next) -> Response { + let _ = next2; + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr b/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr new file mode 100644 index 0000000000..596f55817f --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr @@ -0,0 +1,7 @@ +error: Middleware functions can only take one argument of type `axum::middleware::Next` + --> tests/debug_middleware/fail/takes_next_twice.rs:3:1 + | +3 | #[debug_middleware] + | ^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_middleware/pass/basic.rs b/axum-macros/tests/debug_middleware/pass/basic.rs new file mode 100644 index 0000000000..605cacfd40 --- /dev/null +++ b/axum-macros/tests/debug_middleware/pass/basic.rs @@ -0,0 +1,13 @@ +use axum::{ + extract::Request, + response::Response, + middleware::Next, + debug_middleware, +}; + +#[debug_middleware] +async fn my_middleware(request: Request, next: Next) -> Response { + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/from_request/fail/generic_without_via.stderr b/axum-macros/tests/from_request/fail/generic_without_via.stderr index 60c16256f2..daabab098d 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via.stderr @@ -14,8 +14,8 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {fo | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr index 9da7b93a2f..66f90281ca 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr @@ -14,8 +14,8 @@ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {fo | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr index b220ab4eb9..e70248f3a6 100644 --- a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr +++ b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr @@ -14,8 +14,8 @@ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future {hand | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | @@ -36,8 +36,8 @@ error[E0277]: the trait bound `fn(Result) -> impl Futu | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `MethodRouter::::post` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr index c7c9b201fe..d2401803dd 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -1,18 +1,18 @@ -error[E0277]: the trait bound `String: FromRequestParts` is not satisfied +error[E0277]: the trait bound `String: FromRequestParts<_>` is not satisfied --> tests/from_request/fail/parts_extracting_body.rs:5:11 | 5 | body: String, - | ^^^^^^ the trait `FromRequestParts` is not implemented for `String` + | ^^^^^^ the trait `FromRequestParts<_>` is not implemented for `String` | = note: Function argument is not a valid axum extractor. See `https://docs.rs/axum/0.7/axum/extract/index.html` for details = help: the following other types implement trait `FromRequestParts`: - > - > - as FromRequestParts> - > - > - > - > - > + `()` implements `FromRequestParts` + `(T1, T2)` implements `FromRequestParts` + `(T1, T2, T3)` implements `FromRequestParts` + `(T1, T2, T3, T4)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7, T8)` implements `FromRequestParts` and $N others diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index 25e399b4e0..736006edad 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -1,6 +1,6 @@ use axum::{ - async_trait, - extract::{Request, rejection::ExtensionRejection, FromRequest}, + body::Body, + extract::{rejection::ExtensionRejection, FromRequest, Request}, http::StatusCode, response::{IntoResponse, Response}, routing::get, @@ -26,7 +26,6 @@ struct MyExtractor { struct OtherExtractor; -#[async_trait] impl FromRequest for OtherExtractor where S: Send + Sync, diff --git a/axum-macros/tests/from_request/pass/override_rejection_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_parts.rs index 8ef9cb22db..7cc27de24c 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_parts.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_parts.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::{rejection::ExtensionRejection, FromRequestParts}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, @@ -26,7 +25,6 @@ struct MyExtractor { struct OtherExtractor; -#[async_trait] impl FromRequestParts for OtherExtractor where S: Send + Sync, diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 8513e2aeaf..ed2c9d7571 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -1,13 +1,42 @@ +error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is not satisfied + --> tests/typed_path/fail/not_deserialize.rs:3:10 + | +3 | #[derive(TypedPath)] + | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` + | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `MyPath` type + = note: for types from other crates check whether the crate offers a `serde` feature flag + = help: the following other types implement trait `serde::de::Deserialize<'de>`: + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + and $N others + = note: required for `MyPath` to implement `serde::de::DeserializeOwned` + = note: required for `axum::extract::Path` to implement `FromRequestParts` + error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied --> tests/typed_path/fail/not_deserialize.rs:3:10 | 3 | #[derive(TypedPath)] | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` | - = help: the trait `FromRequestParts` is implemented for `axum::extract::Path` + = help: the following other types implement trait `serde::de::Deserialize<'de>`: + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + and $N others = note: required for `MyPath` to implement `serde::de::DeserializeOwned` = note: required for `axum::extract::Path` to implement `FromRequestParts` - = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied --> tests/typed_path/fail/not_deserialize.rs:3:10 @@ -15,7 +44,16 @@ error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satis 3 | #[derive(TypedPath)] | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` | - = help: the trait `FromRequestParts` is implemented for `axum::extract::Path` + = help: the following other types implement trait `serde::de::Deserialize<'de>`: + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + and $N others = note: required for `MyPath` to implement `serde::de::DeserializeOwned` = note: required for `axum::extract::Path` to implement `FromRequestParts` - = note: this error originates in the attribute macro `::axum::async_trait` (in Nightly builds, run with -Z macro-backtrace for more info) + = note: this error originates in the derive macro `TypedPath` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 316f2150d4..47b69e2c83 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,10 +7,34 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased +- **breaking:** Require `Sync` for all handlers and services added to `Router` + and `MethodRouter` ([#2473]) +- **breaking:** The tuple and tuple_struct `Path` extractor deserializers now check that the number of parameters matches the tuple length exactly ([#2931]) +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2473]: https://github.com/tokio-rs/axum/pull/2473 +[#2931]: https://github.com/tokio-rs/axum/pull/2931 +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.7.7 + +- **change**: Remove manual tables of content from the documentation, since + rustdoc now generates tables of content in the sidebar ([#2921]) + +[#2921]: https://github.com/tokio-rs/axum/pull/2921 + +# 0.7.6 + - **change:** Avoid cloning `Arc` during deserialization of `Path` - **added:** `axum::serve::Serve::tcp_nodelay` and `axum::serve::WithGracefulShutdown::tcp_nodelay` ([#2653]) +- **added:** `Router::has_routes` function ([#2790]) +- **change:** Update tokio-tungstenite to 0.23 ([#2841]) +- **added:** `Serve::local_addr` and `WithGracefulShutdown::local_addr` functions ([#2881]) [#2653]: https://github.com/tokio-rs/axum/pull/2653 +[#2790]: https://github.com/tokio-rs/axum/pull/2790 +[#2841]: https://github.com/tokio-rs/axum/pull/2841 +[#2881]: https://github.com/tokio-rs/axum/pull/2881 # 0.7.5 (24. March, 2024) @@ -18,7 +42,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 a `Router` or `MethodRouter` ([#2586]) - **fixed:** `h2` is no longer pulled as a dependency unless the `http2` feature is enabled ([#2605]) +- **added:** Add `#[debug_middleware]` ([#1993], [#2725]) +[#1993]: https://github.com/tokio-rs/axum/pull/1993 +[#2725]: https://github.com/tokio-rs/axum/pull/2725 [#2586]: https://github.com/tokio-rs/axum/pull/2586 [#2605]: https://github.com/tokio-rs/axum/pull/2605 @@ -31,7 +58,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#2201]: https://github.com/tokio-rs/axum/pull/2201 [#2483]: https://github.com/tokio-rs/axum/pull/2483 -[#2201]: https://github.com/tokio-rs/axum/pull/2201 [#2484]: https://github.com/tokio-rs/axum/pull/2484 # 0.7.3 (29. December, 2023) @@ -571,7 +597,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ```rust use axum::{Json, http::HeaderMap}; - // This wont compile on 0.6 because both `Json` and `String` need to consume + // This won't compile on 0.6 because both `Json` and `String` need to consume // the request body. You can use either `Json` or `String`, but not both. async fn handler_1( json: Json, @@ -602,7 +628,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ```rust struct MyExtractor { /* ... */ } - #[async_trait] impl FromRequest for MyExtractor where B: Send, @@ -621,13 +646,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 use axum::{ extract::{FromRequest, FromRequestParts}, http::{StatusCode, Request, request::Parts}, - async_trait, }; struct MyExtractor { /* ... */ } // implement `FromRequestParts` if you don't need to consume the request body - #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, @@ -640,7 +663,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 } // implement `FromRequest` if you do need to consume the request body - #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, @@ -1157,7 +1179,7 @@ Yanked, as it didn't compile in release mode. ```rust use axum::{Json, http::HeaderMap}; - // This wont compile on 0.6 because both `Json` and `String` need to consume + // This won't compile on 0.6 because both `Json` and `String` need to consume // the request body. You can use either `Json` or `String`, but not both. async fn handler_1( json: Json, @@ -1188,7 +1210,6 @@ Yanked, as it didn't compile in release mode. ```rust struct MyExtractor { /* ... */ } - #[async_trait] impl FromRequest for MyExtractor where B: Send, @@ -1207,13 +1228,11 @@ Yanked, as it didn't compile in release mode. use axum::{ extract::{FromRequest, FromRequestParts}, http::{StatusCode, Request, request::Parts}, - async_trait, }; struct MyExtractor { /* ... */ } // implement `FromRequestParts` if you don't need to consume the request body - #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, @@ -1226,7 +1245,6 @@ Yanked, as it didn't compile in release mode. } // implement `FromRequest` if you do need to consume the request body - #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, diff --git a/axum/Cargo.toml b/axum/Cargo.toml index c81d9aacc7..6b8d70445d 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "axum" -version = "0.7.5" +version = "0.7.7" categories = ["asynchronous", "network-programming", "web-programming::http-server"] description = "Web framework that focuses on ergonomics and modularity" edition = "2021" -rust-version = "1.66" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" @@ -41,8 +41,7 @@ ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"] __private_docs = ["tower/full", "dep:tower-http"] [dependencies] -async-trait = "0.1.67" -axum-core = { path = "../axum-core", version = "0.4.3" } +axum-core = { path = "../axum-core", version = "0.4.5" } bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "1.0.0" @@ -57,26 +56,26 @@ pin-project-lite = "0.2.7" rustversion = "1.0.9" serde = "1.0" sync_wrapper = "1.0.0" -tower = { version = "0.4.13", default-features = false, features = ["util"] } +tower = { version = "0.5.1", default-features = false, features = ["util"] } tower-layer = "0.3.2" tower-service = "0.3" # optional dependencies -axum-macros = { path = "../axum-macros", version = "0.4.1", optional = true } -base64 = { version = "0.21.0", optional = true } +axum-macros = { path = "../axum-macros", version = "0.4.2", optional = true } +base64 = { version = "0.22.1", optional = true } hyper = { version = "1.1.0", optional = true } -hyper-util = { version = "0.1.3", features = ["tokio", "server"], optional = true } +hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true } multer = { version = "3.0.0", optional = true } serde_json = { version = "1.0", features = ["raw_value"], optional = true } serde_path_to_error = { version = "0.1.8", optional = true } serde_urlencoded = { version = "0.7", optional = true } sha1 = { version = "0.10", optional = true } tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true } -tokio-tungstenite = { version = "0.21", optional = true } +tokio-tungstenite = { version = "0.24.0", optional = true } tracing = { version = "0.1", default-features = false, optional = true } [dependencies.tower-http] -version = "0.5.0" +version = "0.6.0" optional = true features = [ # all tower-http features except (de)?compression-zstd which doesn't @@ -121,7 +120,7 @@ serde_json = "1.0" time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" -tokio-tungstenite = "0.21" +tokio-tungstenite = "0.24.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json"] } uuid = { version = "1.0", features = ["serde", "v4"] } @@ -132,7 +131,7 @@ rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies.tower] package = "tower" -version = "0.4.10" +version = "0.5.1" features = [ "util", "timeout", @@ -143,7 +142,7 @@ features = [ ] [dev-dependencies.tower-http] -version = "0.5.0" +version = "0.6.0" features = [ # all tower-http features except (de)?compression-zstd which doesn't # build on `--target armv5te-unknown-linux-musleabi` @@ -199,7 +198,6 @@ allowed = [ "tower_service", # >=1.0 - "async_trait", "bytes", "http", "http_body", diff --git a/axum/README.md b/axum/README.md index dc7f1a95dd..344484a8ec 100644 --- a/axum/README.md +++ b/axum/README.md @@ -4,7 +4,7 @@ [![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum/actions/workflows/CI.yml) [![Crates.io](https://img.shields.io/crates/v/axum)](https://crates.io/crates/axum) -[![Documentation](https://docs.rs/axum/badge.svg)](https://docs.rs/axum) +[![Documentation](https://docs.rs/axum/badge.svg)][docs] More information about this crate can be found in the [crate documentation][docs]. @@ -23,6 +23,13 @@ In particular the last point is what sets `axum` apart from other frameworks. authorization, and more, for free. It also enables you to share middleware with applications written using [`hyper`] or [`tonic`]. +## ⚠ Breaking changes ⚠ + +We are currently working towards axum 0.8 so the `main` branch contains breaking +changes. See the [`0.7.x`] branch for what's released to crates.io. + +[`0.7.x`]: https://github.com/tokio-rs/axum/tree/v0.7.x + ## Usage example ```rust @@ -104,7 +111,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum's MSRV is 1.66. +axum's MSRV is 1.75. ## Examples diff --git a/axum/clippy.toml b/axum/clippy.toml deleted file mode 100644 index 291e8cd5f4..0000000000 --- a/axum/clippy.toml +++ /dev/null @@ -1,3 +0,0 @@ -disallowed-types = [ - { path = "std::sync::Mutex", reason = "Use our internal AxumMutex instead" }, -] diff --git a/axum/src/box_clone_service.rs b/axum/src/box_clone_service.rs new file mode 100644 index 0000000000..25c0b205b8 --- /dev/null +++ b/axum/src/box_clone_service.rs @@ -0,0 +1,80 @@ +use futures_util::future::BoxFuture; +use std::{ + fmt, + task::{Context, Poll}, +}; +use tower::ServiceExt; +use tower_service::Service; + +/// Like `tower::BoxCloneService` but `Sync` +pub(crate) struct BoxCloneService( + Box< + dyn CloneService>> + + Send + + Sync, + >, +); + +impl BoxCloneService { + pub(crate) fn new(inner: S) -> Self + where + S: Service + Clone + Send + Sync + 'static, + S::Future: Send + 'static, + { + let inner = inner.map_future(|f| Box::pin(f) as _); + BoxCloneService(Box::new(inner)) + } +} + +impl Service for BoxCloneService { + type Response = U; + type Error = E; + type Future = BoxFuture<'static, Result>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + #[inline] + fn call(&mut self, request: T) -> Self::Future { + self.0.call(request) + } +} + +impl Clone for BoxCloneService { + fn clone(&self) -> Self { + Self(self.0.clone_box()) + } +} + +trait CloneService: Service { + fn clone_box( + &self, + ) -> Box< + dyn CloneService + + Send + + Sync, + >; +} + +impl CloneService for T +where + T: Service + Send + Sync + Clone + 'static, +{ + fn clone_box( + &self, + ) -> Box< + dyn CloneService + + Send + + Sync, + > { + Box::new(self.clone()) + } +} + +impl fmt::Debug for BoxCloneService { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("BoxCloneService").finish() + } +} diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index f541a9fa30..0d65b2d38f 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -1,7 +1,6 @@ use std::{convert::Infallible, fmt}; use crate::extract::Request; -use crate::util::AxumMutex; use tower::Service; use crate::{ @@ -10,7 +9,7 @@ use crate::{ Router, }; -pub(crate) struct BoxedIntoRoute(AxumMutex>>); +pub(crate) struct BoxedIntoRoute(Box>); impl BoxedIntoRoute where @@ -21,10 +20,10 @@ where H: Handler, T: 'static, { - Self(AxumMutex::new(Box::new(MakeErasedHandler { + Self(Box::new(MakeErasedHandler { handler, into_route: |handler, state| Route::new(Handler::with_state(handler, state)), - }))) + })) } } @@ -33,23 +32,23 @@ impl BoxedIntoRoute { where S: 'static, E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, E2: 'static, { - BoxedIntoRoute(AxumMutex::new(Box::new(Map { - inner: self.0.into_inner().unwrap(), + BoxedIntoRoute(Box::new(Map { + inner: self.0, layer: Box::new(f), - }))) + })) } pub(crate) fn into_route(self, state: S) -> Route { - self.0.into_inner().unwrap().into_route(state) + self.0.into_route(state) } } impl Clone for BoxedIntoRoute { fn clone(&self) -> Self { - Self(AxumMutex::new(self.0.lock().unwrap().clone_box())) + Self(self.0.clone_box()) } } @@ -59,7 +58,7 @@ impl fmt::Debug for BoxedIntoRoute { } } -pub(crate) trait ErasedIntoRoute: Send { +pub(crate) trait ErasedIntoRoute: Send + Sync { fn clone_box(&self) -> Box>; fn into_route(self: Box, state: S) -> Route; @@ -75,7 +74,7 @@ pub(crate) struct MakeErasedHandler { impl ErasedIntoRoute for MakeErasedHandler where - H: Clone + Send + 'static, + H: Clone + Send + Sync + 'static, S: 'static, { fn clone_box(&self) -> Box> { @@ -103,6 +102,7 @@ where } } +#[allow(dead_code)] pub(crate) struct MakeErasedRouter { pub(crate) router: Router, pub(crate) into_route: fn(Router, S) -> Route, @@ -164,13 +164,13 @@ where } } -pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send { +pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send + Sync { fn clone_box(&self) -> Box>; } impl LayerFn for F where - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) diff --git a/axum/src/docs/error_handling.md b/axum/src/docs/error_handling.md index 6993b29ad0..7d7e14ee05 100644 --- a/axum/src/docs/error_handling.md +++ b/axum/src/docs/error_handling.md @@ -1,12 +1,5 @@ Error handling model and utilities -# Table of contents - -- [axum's error handling model](#axums-error-handling-model) -- [Routing to fallible services](#routing-to-fallible-services) -- [Applying fallible middleware](#applying-fallible-middleware) -- [Running extractors for error handling](#running-extractors-for-error-handling) - # axum's error handling model axum is based on [`tower::Service`] which bundles errors through its associated diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 807d7895e7..244528d6a8 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -1,26 +1,10 @@ Types and traits for extracting data from requests. -# Table of contents - -- [Intro](#intro) -- [Common extractors](#common-extractors) -- [Applying multiple extractors](#applying-multiple-extractors) -- [The order of extractors](#the-order-of-extractors) -- [Optional extractors](#optional-extractors) -- [Customizing extractor responses](#customizing-extractor-responses) -- [Accessing inner errors](#accessing-inner-errors) -- [Defining custom extractors](#defining-custom-extractors) -- [Accessing other extractors in `FromRequest` or `FromRequestParts` implementations](#accessing-other-extractors-in-fromrequest-or-fromrequestparts-implementations) -- [Request body limits](#request-body-limits) -- [Wrapping extractors](#wrapping-extractors) -- [Logging rejections](#logging-rejections) - # Intro A handler function is an async function that takes any number of "extractors" as arguments. An extractor is a type that implements -[`FromRequest`](crate::extract::FromRequest) -or [`FromRequestParts`](crate::extract::FromRequestParts). +[`FromRequest`] or [`FromRequestParts`]. For example, [`Json`] is an extractor that consumes the request body and deserializes it as JSON into some target type: @@ -281,10 +265,15 @@ let app = Router::new().route("/users", post(create_user)); # let _: Router = app; ``` +Another option is to make use of the optional extractors in [axum-extra] that +either returns `None` if there are no query parameters in the request URI, +or returns `Some(T)` if deserialization was successful. +If the deserialization was not successful, the request is rejected. + # Customizing extractor responses If an extractor fails it will return a response with the error and your -handler will not be called. To customize the error response you have a two +handler will not be called. To customize the error response you have two options: 1. Use `Result` as your extractor like shown in ["Optional @@ -420,7 +409,6 @@ request body: ```rust,no_run use axum::{ - async_trait, extract::FromRequestParts, routing::get, Router, @@ -433,7 +421,6 @@ use axum::{ struct ExtractUserAgent(HeaderValue); -#[async_trait] impl FromRequestParts for ExtractUserAgent where S: Send + Sync, @@ -463,7 +450,6 @@ If your extractor needs to consume the request body you must implement [`FromReq ```rust,no_run use axum::{ - async_trait, extract::{Request, FromRequest}, response::{Response, IntoResponse}, body::{Bytes, Body}, @@ -477,7 +463,6 @@ use axum::{ struct ValidatedBody(Bytes); -#[async_trait] impl FromRequest for ValidatedBody where Bytes: FromRequest, @@ -517,7 +502,6 @@ use axum::{ extract::{FromRequest, Request, FromRequestParts}, http::request::Parts, body::Body, - async_trait, }; use std::convert::Infallible; @@ -525,7 +509,6 @@ use std::convert::Infallible; struct MyExtractor; // `MyExtractor` implements both `FromRequest` -#[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, @@ -539,7 +522,6 @@ where } // and `FromRequestParts` -#[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, @@ -573,7 +555,6 @@ in your implementation. ```rust use axum::{ - async_trait, extract::{Extension, FromRequestParts}, http::{StatusCode, HeaderMap, request::Parts}, response::{IntoResponse, Response}, @@ -590,7 +571,6 @@ struct AuthenticatedUser { // ... } -#[async_trait] impl FromRequestParts for AuthenticatedUser where S: Send + Sync, @@ -644,7 +624,6 @@ use axum::{ routing::get, extract::{Request, FromRequest, FromRequestParts}, http::{HeaderMap, request::Parts}, - async_trait, }; use std::time::{Instant, Duration}; @@ -655,7 +634,6 @@ struct Timing { } // we must implement both `FromRequestParts` -#[async_trait] impl FromRequestParts for Timing where S: Send + Sync, @@ -675,7 +653,6 @@ where } // and `FromRequest` -#[async_trait] impl FromRequest for Timing where S: Send + Sync, @@ -710,6 +687,7 @@ logs, enable the `tracing` feature for axum (enabled by default) and the `axum::rejection=trace` tracing target, for example with `RUST_LOG=info,axum::rejection=trace cargo run`. +[axum-extra]: https://docs.rs/axum-extra/latest/axum_extra/extract/index.html [`body::Body`]: crate::body::Body [`Bytes`]: crate::body::Bytes [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 1529ef0365..3a90237236 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -1,15 +1,3 @@ -# Table of contents - -- [Intro](#intro) -- [Applying middleware](#applying-middleware) -- [Commonly used middleware](#commonly-used-middleware) -- [Ordering](#ordering) -- [Writing middleware](#writing-middleware) -- [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure) -- [Accessing state in middleware](#accessing-state-in-middleware) -- [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers) -- [Rewriting request URI in middleware](#rewriting-request-uri-in-middleware) - # Intro axum is unique in that it doesn't have its own bespoke middleware system and @@ -352,11 +340,11 @@ readiness inside the response future returned by `Service::call`. This works well when your services don't care about backpressure and are always ready anyway. -axum expects that all services used in your app wont care about +axum expects that all services used in your app won't care about backpressure and so it uses the latter strategy. However that means you should avoid routing to a service (or using a middleware) that _does_ care -about backpressure. At the very least you should [load shed] so requests are -dropped quickly and don't keep piling up. +about backpressure. At the very least you should [load shed][tower::load_shed] +so requests are dropped quickly and don't keep piling up. It also means that if `poll_ready` returns an error then that error will be returned in the response future from `call` and _not_ from `poll_ready`. In @@ -388,8 +376,7 @@ let app = ServiceBuilder::new() ``` However when applying middleware around your whole application in this way -you have to take care that errors are still being handled with -appropriately. +you have to take care that errors are still being handled appropriately. Also note that handlers created from async functions don't care about backpressure and are always ready. So if you're not using any Tower diff --git a/axum/src/docs/response.md b/axum/src/docs/response.md index a5761c34ed..c0974fb640 100644 --- a/axum/src/docs/response.md +++ b/axum/src/docs/response.md @@ -1,11 +1,5 @@ Types and traits for generating responses. -# Table of contents - -- [Building responses](#building-responses) -- [Returning different response types](#returning-different-response-types) -- [Regarding `impl IntoResponse`](#regarding-impl-intoresponse) - # Building responses Anything that implements [`IntoResponse`] can be returned from a handler. axum @@ -166,7 +160,7 @@ In general you can return tuples like: This means you cannot accidentally override the status or body as [`IntoResponseParts`] only allows setting headers and extensions. -Use [`Response`](crate::response::Response) for more low level control: +Use [`Response`] for more low level control: ```rust,no_run use axum::{ diff --git a/axum/src/docs/routing/fallback.md b/axum/src/docs/routing/fallback.md index 27fb76a59e..a864b7a45d 100644 --- a/axum/src/docs/routing/fallback.md +++ b/axum/src/docs/routing/fallback.md @@ -23,7 +23,11 @@ async fn fallback(uri: Uri) -> (StatusCode, String) { Fallbacks only apply to routes that aren't matched by anything in the router. If a handler is matched by a request but returns 404 the -fallback is not called. +fallback is not called. Note that this applies to [`MethodRouter`]s too: if the +request hits a valid path but the [`MethodRouter`] does not have an appropriate +method handler installed, the fallback is not called (use +[`MethodRouter::fallback`] for this purpose instead). + # Handling all requests without other routes diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index 8e315f5474..bb5b2ea6cb 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -82,6 +82,11 @@ let app = Router::new() # let _: Router = app; ``` +Additionally, while the wildcard route `/foo/*rest` will not match the +paths `/foo` or `/foo/`, a nested router at `/foo` will match the path `/foo` +(but not `/foo/`), and a nested router at `/foo/` will match the path `/foo/` +(but not `/foo`). + # Fallbacks If a nested router doesn't have its own fallback then it will inherit the @@ -181,7 +186,7 @@ router. # Panics - If the route overlaps with another route. See [`Router::route`] -for more details. + for more details. - If the route contains a wildcard (`*`). - If `path` is empty. diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md index 0d9853341f..01be9152ed 100644 --- a/axum/src/docs/routing/route.md +++ b/axum/src/docs/routing/route.md @@ -5,8 +5,7 @@ can be either static, a capture, or a wildcard. `method_router` is the [`MethodRouter`] that should receive the request if the path matches `path`. `method_router` will commonly be a handler wrapped in a method -router like [`get`](crate::routing::get). See [`handler`](crate::handler) for -more details on handlers. +router like [`get`]. See [`handler`](crate::handler) for more details on handlers. # Static paths diff --git a/axum/src/docs/routing/route_layer.md b/axum/src/docs/routing/route_layer.md index bc7b219742..9cce3ea79e 100644 --- a/axum/src/docs/routing/route_layer.md +++ b/axum/src/docs/routing/route_layer.md @@ -11,6 +11,10 @@ the request matches a route. This is useful for middleware that return early (such as authorization) which might otherwise convert a `404 Not Found` into a `401 Unauthorized`. +This function will panic if no routes have been declared yet on the router, +since the new layer will have no effect, and this is typically a bug. +In generic code, you can test if that is the case first, by calling [`Router::has_routes`]. + # Example ```rust diff --git a/axum/src/docs/routing/with_state.md b/axum/src/docs/routing/with_state.md index bece920fe0..973a87e01c 100644 --- a/axum/src/docs/routing/with_state.md +++ b/axum/src/docs/routing/with_state.md @@ -20,7 +20,7 @@ axum::serve(listener, routes).await.unwrap(); # Returning routers with states from functions -When returning `Router`s from functions it is generally recommend not set the +When returning `Router`s from functions, it is generally recommended not to set the state directly: ```rust @@ -171,7 +171,7 @@ work: # #[derive(Clone)] # struct AppState {} # -// This wont work because we're returning a `Router` +// This won't work because we're returning a `Router` // i.e. we're saying we're still missing an `AppState` fn routes(state: AppState) -> Router { Router::new() diff --git a/axum/src/extension.rs b/axum/src/extension.rs index e4d170fb6d..9485443232 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -1,5 +1,4 @@ use crate::{extract::rejection::*, response::IntoResponseParts}; -use async_trait::async_trait; use axum_core::{ extract::FromRequestParts, response::{IntoResponse, Response, ResponseParts}, @@ -70,7 +69,6 @@ use tower_service::Service; #[must_use] pub struct Extension(pub T); -#[async_trait] impl FromRequestParts for Extension where T: Clone + Send + Sync + 'static, diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index f77db6dd44..3d8f9a0163 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -7,7 +7,6 @@ use crate::extension::AddExtension; use super::{Extension, FromRequestParts}; -use async_trait::async_trait; use http::request::Parts; use std::{ convert::Infallible, @@ -139,7 +138,6 @@ opaque_future! { #[derive(Clone, Copy, Debug)] pub struct ConnectInfo(pub T); -#[async_trait] impl FromRequestParts for ConnectInfo where S: Send + Sync, diff --git a/axum/src/extract/host.rs b/axum/src/extract/host.rs index f1d179a545..62f0dc7880 100644 --- a/axum/src/extract/host.rs +++ b/axum/src/extract/host.rs @@ -2,7 +2,6 @@ use super::{ rejection::{FailedToResolveHost, HostRejection}, FromRequestParts, }; -use async_trait::async_trait; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, @@ -23,7 +22,6 @@ const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; #[derive(Debug, Clone)] pub struct Host(pub String); -#[async_trait] impl FromRequestParts for Host where S: Send + Sync, diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index d51d36c2fe..8fdd8e35a9 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,6 +1,5 @@ use super::{rejection::*, FromRequestParts}; use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; -use async_trait::async_trait; use http::request::Parts; use std::{collections::HashMap, sync::Arc}; @@ -63,7 +62,6 @@ impl MatchedPath { } } -#[async_trait] impl FromRequestParts for MatchedPath where S: Send + Sync, diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 7a303a4759..38ebaf9be2 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -4,7 +4,6 @@ use super::{FromRequest, Request}; use crate::body::Bytes; -use async_trait::async_trait; use axum_core::{ __composite_rejection as composite_rejection, __define_rejection as define_rejection, response::{IntoResponse, Response}, @@ -65,7 +64,6 @@ pub struct Multipart { inner: multer::Multipart<'static>, } -#[async_trait] impl FromRequest for Multipart where S: Send + Sync, diff --git a/axum/src/extract/nested_path.rs b/axum/src/extract/nested_path.rs index 72712a4e9a..61966a076a 100644 --- a/axum/src/extract/nested_path.rs +++ b/axum/src/extract/nested_path.rs @@ -4,7 +4,6 @@ use std::{ }; use crate::extract::Request; -use async_trait::async_trait; use axum_core::extract::FromRequestParts; use http::request::Parts; use tower_layer::{layer_fn, Layer}; @@ -47,7 +46,6 @@ impl NestedPath { } } -#[async_trait] impl FromRequestParts for NestedPath where S: Send + Sync, diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index 8ba8a431e9..2d0b04f5ad 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -140,7 +140,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { where V: Visitor<'de>, { - if self.url_params.len() < len { + if self.url_params.len() != len { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(len)); @@ -160,7 +160,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { where V: Visitor<'de>, { - if self.url_params.len() < len { + if self.url_params.len() != len { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(len)); @@ -773,20 +773,6 @@ mod tests { ); } - #[test] - fn test_parse_tuple_ignoring_additional_fields() { - let url_params = create_url_params(vec![ - ("a", "abc"), - ("b", "true"), - ("c", "1"), - ("d", "false"), - ]); - assert_eq!( - <(&str, bool, u32)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), - ("abc", true, 1) - ); - } - #[test] fn test_parse_map() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); @@ -813,6 +799,18 @@ mod tests { }; } + #[test] + fn test_parse_tuple_too_many_fields() { + test_parse_error!( + vec![("a", "abc"), ("b", "true"), ("c", "1"), ("d", "false"),], + (&str, bool, u32), + ErrorKind::WrongNumberOfParameters { + got: 4, + expected: 3, + } + ); + } + #[test] fn test_wrong_number_of_parameters_error() { test_parse_error!( diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 07acf0884a..a6779a5644 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -8,7 +8,6 @@ use crate::{ routing::url_params::UrlParams, util::PercentDecodedStr, }; -use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use http::{request::Parts, StatusCode}; use serde::de::DeserializeOwned; @@ -145,7 +144,6 @@ pub struct Path(pub T); axum_core::__impl_deref!(Path); -#[async_trait] impl FromRequestParts for Path where T: DeserializeOwned + Send, @@ -446,7 +444,6 @@ impl std::error::Error for FailedToDeserializePathParams {} #[derive(Debug)] pub struct RawPathParams(Vec<(Arc, PercentDecodedStr)>); -#[async_trait] impl FromRequestParts for RawPathParams where S: Send + Sync, @@ -750,6 +747,33 @@ mod tests { ); } + #[crate::test] + async fn tuple_param_matches_exactly() { + #[allow(dead_code)] + #[derive(Deserialize)] + struct Tuple(String, String); + + let app = Router::new() + .route("/foo/:a/:b/:c", get(|_: Path<(String, String)>| async {})) + .route("/bar/:a/:b/:c", get(|_: Path| async {})); + + let client = TestClient::new(app); + + let res = client.get("/foo/a/b/c").await; + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + res.text().await, + "Wrong number of path arguments for `Path`. Expected 2 but got 3", + ); + + let res = client.get("/bar/a/b/c").await; + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + res.text().await, + "Wrong number of path arguments for `Path`. Expected 2 but got 3", + ); + } + #[crate::test] async fn deserialize_into_vec_of_tuples() { let app = Router::new().route( diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index a331b68ca5..371612b71a 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -1,5 +1,4 @@ use super::{rejection::*, FromRequestParts}; -use async_trait::async_trait; use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; @@ -51,7 +50,6 @@ use serde::de::DeserializeOwned; #[derive(Debug, Clone, Copy, Default)] pub struct Query(pub T); -#[async_trait] impl FromRequestParts for Query where T: DeserializeOwned, diff --git a/axum/src/extract/raw_form.rs b/axum/src/extract/raw_form.rs index a4e0d6c57c..29cb4c6dd3 100644 --- a/axum/src/extract/raw_form.rs +++ b/axum/src/extract/raw_form.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use axum_core::extract::{FromRequest, Request}; use bytes::Bytes; use http::Method; @@ -30,7 +29,6 @@ use super::{ #[derive(Debug)] pub struct RawForm(pub Bytes); -#[async_trait] impl FromRequest for RawForm where S: Send + Sync, diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index d8c56f84a4..c792960a1b 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -1,5 +1,4 @@ use super::FromRequestParts; -use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; @@ -25,7 +24,6 @@ use std::convert::Infallible; #[derive(Debug)] pub struct RawQuery(pub Option); -#[async_trait] impl FromRequestParts for RawQuery where S: Send + Sync, diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index da1718795e..6d9adc672c 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -1,5 +1,4 @@ use super::{Extension, FromRequestParts}; -use async_trait::async_trait; use http::{request::Parts, Uri}; use std::convert::Infallible; @@ -70,7 +69,6 @@ use std::convert::Infallible; pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] -#[async_trait] impl FromRequestParts for OriginalUri where S: Send + Sync, diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index fb401c00d8..e72c2e11e5 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use axum_core::extract::{FromRef, FromRequestParts}; use http::request::Parts; use std::{ @@ -219,13 +218,11 @@ use std::{ /// ```rust /// use axum_core::extract::{FromRequestParts, FromRef}; /// use http::request::Parts; -/// use async_trait::async_trait; /// use std::convert::Infallible; /// /// // the extractor your library provides /// struct MyLibraryExtractor; /// -/// #[async_trait] /// impl FromRequestParts for MyLibraryExtractor /// where /// // keep `S` generic but require that it can produce a `MyLibraryState` @@ -344,7 +341,6 @@ use std::{ #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); -#[async_trait] impl FromRequestParts for State where InnerState: FromRef, diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index cfd10c5642..5a18d1900c 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -93,7 +93,6 @@ use self::rejection::*; use super::FromRequestParts; use crate::{body::Bytes, response::Response, Error}; -use async_trait::async_trait; use axum_core::body::Body; use futures_util::{ sink::{Sink, SinkExt}, @@ -381,7 +380,6 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade { fn call(self, _error: Error) {} } -#[async_trait] impl FromRequestParts for WebSocketUpgrade where S: Send + Sync, @@ -783,8 +781,9 @@ pub mod close_code { pub const PROTOCOL: u16 = 1002; /// Indicates that an endpoint is terminating the connection because it has received a type of - /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if - /// it receives a binary message). + /// data that it cannot accept. + /// + /// For example, an endpoint MAY send this if it understands only text data, but receives a binary message. pub const UNSUPPORTED: u16 = 1003; /// Indicates that no status code was included in a closing frame. @@ -794,12 +793,15 @@ pub mod close_code { pub const ABNORMAL: u16 = 1006; /// Indicates that an endpoint is terminating the connection because it has received data - /// within a message that was not consistent with the type of the message (e.g., non-UTF-8 - /// RFC3629 data within a text message). + /// within a message that was not consistent with the type of the message. + /// + /// For example, an endpoint received non-UTF-8 RFC3629 data within a text message. pub const INVALID: u16 = 1007; /// Indicates that an endpoint is terminating the connection because it has received a message - /// that violates its policy. This is a generic status code that can be returned when there is + /// that violates its policy. + /// + /// This is a generic status code that can be returned when there is /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to /// hide specific details about the policy. pub const POLICY: u16 = 1008; @@ -808,10 +810,13 @@ pub mod close_code { /// that is too big for it to process. pub const SIZE: u16 = 1009; - /// Indicates that an endpoint (client) is terminating the connection because it has expected - /// the server to negotiate one or more extension, but the server didn't return them in the - /// response message of the WebSocket handshake. The list of extensions that are needed should - /// be given as the reason for closing. Note that this status code is not used by the server, + /// Indicates that an endpoint (client) is terminating the connection because the server + /// did not respond to extension negotiation correctly. + /// + /// Specifically, the client has expected the server to negotiate one or more extension(s), + /// but the server didn't return them in the response message of the WebSocket handshake. + /// The list of extensions that are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, /// because it can fail the WebSocket handshake instead. pub const EXTENSION: u16 = 1010; diff --git a/axum/src/form.rs b/axum/src/form.rs index 966517a124..f754c4c1b8 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -1,6 +1,5 @@ use crate::extract::Request; use crate::extract::{rejection::*, FromRequest, RawForm}; -use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use axum_core::RequestExt; use http::header::CONTENT_TYPE; @@ -72,7 +71,6 @@ use serde::Serialize; #[must_use] pub struct Form(pub T); -#[async_trait] impl FromRequest for Form where T: DeserializeOwned, diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 982776145c..8485390bbf 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -10,8 +10,8 @@ //! // Handler that immediately returns an empty `200 OK` response. //! async fn unit_handler() {} //! -//! // Handler that immediately returns an empty `200 OK` response with a plain -//! // text body. +//! // Handler that immediately returns a `200 OK` response with a plain text +//! // body. //! async fn string_handler() -> String { //! "Hello, World!".to_string() //! } @@ -131,7 +131,7 @@ pub use self::service::HandlerService; note = "Consider using `#[axum::debug_handler]` to improve the error message" ) )] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sync + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; @@ -192,7 +192,7 @@ pub trait Handler: Clone + Send + Sized + 'static { impl Handler<((),), S> for F where - F: FnOnce() -> Fut + Clone + Send + 'static, + F: FnOnce() -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, Res: IntoResponse, { @@ -210,7 +210,7 @@ macro_rules! impl_handler { #[allow(non_snake_case, unused_mut)] impl Handler<(M, $($ty,)* $last,), S> for F where - F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static, + F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, S: Send + Sync + 'static, Res: IntoResponse, @@ -257,7 +257,7 @@ mod private { impl Handler for T where - T: IntoResponse + Clone + Send + 'static, + T: IntoResponse + Clone + Send + Sync + 'static, { type Future = std::future::Ready; @@ -302,7 +302,7 @@ where impl Handler for Layered where - L: Layer> + Clone + Send + 'static, + L: Layer> + Clone + Send + Sync + 'static, H: Handler, L::Service: Service + Clone + Send + 'static, >::Response: IntoResponse, @@ -328,6 +328,8 @@ where ) -> _, > = svc.oneshot(req).map(|result| match result { Ok(res) => res.into_response(), + + #[allow(unreachable_patterns)] Err(err) => match err {}, }); diff --git a/axum/src/json.rs b/axum/src/json.rs index bbe4008e68..a2dfdc2eeb 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -1,6 +1,5 @@ use crate::extract::Request; use crate::extract::{rejection::*, FromRequest}; -use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use bytes::{BufMut, Bytes, BytesMut}; use http::{ @@ -17,8 +16,7 @@ use serde::{de::DeserializeOwned, Serialize}; /// /// - The request doesn't have a `Content-Type: application/json` (or similar) header. /// - The body doesn't contain syntactically valid JSON. -/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target -/// type. +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type. /// - Buffering the request body fails. /// /// ⚠️ Since parsing JSON requires consuming the request body, the `Json` extractor must be @@ -92,7 +90,6 @@ use serde::{de::DeserializeOwned, Serialize}; #[must_use] pub struct Json(pub T); -#[async_trait] impl FromRequest for Json where T: DeserializeOwned, diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 1a3032f681..fcc929a6ab 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -1,22 +1,5 @@ //! axum is a web application framework that focuses on ergonomics and modularity. //! -//! # Table of contents -//! -//! - [High-level features](#high-level-features) -//! - [Compatibility](#compatibility) -//! - [Example](#example) -//! - [Routing](#routing) -//! - [Handlers](#handlers) -//! - [Extractors](#extractors) -//! - [Responses](#responses) -//! - [Error handling](#error-handling) -//! - [Middleware](#middleware) -//! - [Sharing state with handlers](#sharing-state-with-handlers) -//! - [Building integrations for axum](#building-integrations-for-axum) -//! - [Required dependencies](#required-dependencies) -//! - [Examples](#examples) -//! - [Feature flags](#feature-flags) -//! //! # High-level features //! //! - Route requests to handlers with a macro-free API. @@ -293,6 +276,67 @@ //! The downside to this approach is that it's a little more verbose than using //! [`State`] or extensions. //! +//! ## Using [tokio's `task_local` macro](https://docs.rs/tokio/1/tokio/macro.task_local.html): +//! +//! This allows to share state with `IntoResponse` implementations. +//! +//! ```rust,no_run +//! use axum::{ +//! extract::Request, +//! http::{header, StatusCode}, +//! middleware::{self, Next}, +//! response::{IntoResponse, Response}, +//! routing::get, +//! Router, +//! }; +//! use tokio::task_local; +//! +//! #[derive(Clone)] +//! struct CurrentUser { +//! name: String, +//! } +//! task_local! { +//! pub static USER: CurrentUser; +//! } +//! +//! async fn auth(req: Request, next: Next) -> Result { +//! let auth_header = req +//! .headers() +//! .get(header::AUTHORIZATION) +//! .and_then(|header| header.to_str().ok()) +//! .ok_or(StatusCode::UNAUTHORIZED)?; +//! if let Some(current_user) = authorize_current_user(auth_header).await { +//! // State is setup here in the middleware +//! Ok(USER.scope(current_user, next.run(req)).await) +//! } else { +//! Err(StatusCode::UNAUTHORIZED) +//! } +//! } +//! async fn authorize_current_user(auth_token: &str) -> Option { +//! Some(CurrentUser { +//! name: auth_token.to_string(), +//! }) +//! } +//! +//! struct UserResponse; +//! +//! impl IntoResponse for UserResponse { +//! fn into_response(self) -> Response { +//! // State is accessed here in the IntoResponse implementation +//! let current_user = USER.with(|u| u.clone()); +//! (StatusCode::OK, current_user.name).into_response() +//! } +//! } +//! +//! async fn handler() -> UserResponse { +//! UserResponse +//! } +//! +//! let app: Router = Router::new() +//! .route("/", get(handler)) +//! .route_layer(middleware::from_fn(auth)); +//! ``` +//! //! # Building integrations for axum //! //! Libraries authors that want to provide [`FromRequest`], [`FromRequestParts`], or @@ -388,7 +432,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -420,6 +463,7 @@ #[macro_use] pub(crate) mod macros; +mod box_clone_service; mod boxed; mod extension; #[cfg(feature = "form")] @@ -442,8 +486,6 @@ pub mod serve; #[cfg(test)] mod test_helpers; -#[doc(no_inline)] -pub use async_trait::async_trait; #[doc(no_inline)] pub use http; @@ -463,7 +505,7 @@ pub use self::form::Form; pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt}; #[cfg(feature = "macros")] -pub use axum_macros::debug_handler; +pub use axum_macros::{debug_handler, debug_middleware}; #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[doc(inline)] diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index 63ef85800a..cda0d97798 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -26,7 +26,7 @@ use tower_service::Service; /// without repeating it in the function signature. /// /// Note that if the extractor consumes the request body, as `String` or -/// [`Bytes`] does, an empty body will be left in its place. Thus wont be +/// [`Bytes`] does, an empty body will be left in its place. Thus won't be /// accessible to subsequent extractors or handlers. /// /// # Example @@ -39,12 +39,10 @@ use tower_service::Service; /// Router, /// http::{header, StatusCode, request::Parts}, /// }; -/// use async_trait::async_trait; /// /// // An extractor that performs authorization. /// struct RequireAuth; /// -/// #[async_trait] /// impl FromRequestParts for RequireAuth /// where /// S: Send + Sync, @@ -303,7 +301,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{async_trait, handler::Handler, routing::get, test_helpers::*, Router}; + use crate::{handler::Handler, routing::get, test_helpers::*, Router}; use axum_core::extract::FromRef; use http::{header, request::Parts, StatusCode}; use tower_http::limit::RequestBodyLimitLayer; @@ -315,7 +313,6 @@ mod tests { struct RequireAuth; - #[async_trait::async_trait] impl FromRequestParts for RequireAuth where S: Send + Sync, @@ -367,7 +364,6 @@ mod tests { fn works_with_request_body_limit() { struct MyExtractor; - #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index e4c44c74f5..d0d5046ceb 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -1,3 +1,4 @@ +use crate::box_clone_service::BoxCloneService; use crate::response::{IntoResponse, Response}; use axum_core::extract::{FromRequest, FromRequestParts, Request}; use futures_util::future::BoxFuture; @@ -10,7 +11,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{util::BoxCloneService, ServiceBuilder}; +use tower::ServiceBuilder; use tower_layer::Layer; use tower_service::Service; @@ -19,9 +20,10 @@ use tower_service::Service; /// `from_fn` requires the function given to /// /// 1. Be an `async fn`. -/// 2. Take one or more [extractors] as the first arguments. -/// 3. Take [`Next`](Next) as the final argument. -/// 4. Return something that implements [`IntoResponse`]. +/// 2. Take zero or more [`FromRequestParts`] extractors. +/// 3. Take exactly one [`FromRequest`] extractor as the second to last argument. +/// 4. Take [`Next`](Next) as the last argument. +/// 5. Return something that implements [`IntoResponse`]. /// /// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`]. /// @@ -112,6 +114,8 @@ pub fn from_fn(f: F) -> FromFnLayer { /// Create a middleware from an async function with the given state. /// +/// For the requirements for the function supplied see [`from_fn`]. +/// /// See [`State`](crate::extract::State) for more details about accessing state. /// /// # Example @@ -259,6 +263,7 @@ macro_rules! impl_service { I: Service + Clone + Send + + Sync + 'static, I::Response: IntoResponse, I::Future: Send + 'static, @@ -297,7 +302,7 @@ macro_rules! impl_service { }; let inner = ServiceBuilder::new() - .boxed_clone() + .layer_fn(BoxCloneService::new) .map_response(IntoResponse::into_response) .service(ready_inner); let next = Next { inner }; @@ -340,6 +345,8 @@ impl Next { pub async fn run(mut self, req: Request) -> Response { match self.inner.call(req).await { Ok(res) => res, + + #[allow(unreachable_patterns)] Err(err) => match err {}, } } diff --git a/axum/src/middleware/map_response.rs b/axum/src/middleware/map_response.rs index 2510cdc256..e4c1c397ba 100644 --- a/axum/src/middleware/map_response.rs +++ b/axum/src/middleware/map_response.rs @@ -278,6 +278,8 @@ macro_rules! impl_service { Ok(res) => { f($($ty,)* res).await.into_response() } + + #[allow(unreachable_patterns)] Err(err) => match err {} } }); diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 6cfd9b0763..632df5cbd3 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -1,6 +1,5 @@ #![doc = include_str!("../docs/response.md")] -use axum_core::body::Body; use http::{header, HeaderValue}; mod redirect; @@ -40,7 +39,7 @@ pub struct Html(pub T); impl IntoResponse for Html where - T: Into, + T: IntoResponse, { fn into_response(self) -> Response { ( @@ -48,7 +47,7 @@ where header::CONTENT_TYPE, HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), )], - self.0.into(), + self.0, ) .into_response() } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 1eb6075b22..8c50eccf74 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -78,7 +78,7 @@ macro_rules! top_level_service_fn { $(#[$m])+ pub fn $name(svc: T) -> MethodRouter where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, S: Clone, @@ -210,6 +210,7 @@ macro_rules! chained_service_fn { T: Service + Clone + Send + + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, @@ -312,7 +313,7 @@ top_level_service_fn!(trace_service, TRACE); /// ``` pub fn on_service(filter: MethodFilter, svc: T) -> MethodRouter where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, S: Clone, @@ -371,7 +372,7 @@ where /// ``` pub fn any_service(svc: T) -> MethodRouter where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, S: Clone, @@ -736,7 +737,7 @@ where #[track_caller] pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { @@ -868,7 +869,7 @@ where #[doc = include_str!("../docs/method_routing/fallback.md")] pub fn fallback_service(mut self, svc: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { @@ -879,8 +880,8 @@ where #[doc = include_str!("../docs/method_routing/layer.md")] pub fn layer(self, layer: L) -> MethodRouter where - L: Layer> + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer> + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -908,8 +909,8 @@ where #[track_caller] pub fn route_layer(mut self, layer: L) -> MethodRouter where - L: Layer> + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer> + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Future: Send + 'static, E: 'static, @@ -1151,7 +1152,7 @@ where where S: 'static, E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, E2: 'static, { match self { diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 58babe9f5f..c9fcc0a7e1 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -172,7 +172,7 @@ where #[doc = include_str!("../docs/routing/route_service.md")] pub fn route_service(self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -192,6 +192,7 @@ where } #[doc = include_str!("../docs/routing/nest.md")] + #[doc(alias = "scope")] // Some web frameworks like actix-web use this term #[track_caller] pub fn nest(self, path: &str, router: Router) -> Self { let RouterInner { @@ -217,7 +218,7 @@ where #[track_caller] pub fn nest_service(self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -281,8 +282,8 @@ where #[doc = include_str!("../docs/routing/layer.md")] pub fn layer(self, layer: L) -> Router where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -299,8 +300,8 @@ where #[track_caller] pub fn route_layer(self, layer: L) -> Self where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -313,6 +314,11 @@ where }) } + /// True if the router currently has at least one route added. + pub fn has_routes(&self) -> bool { + self.inner.path_router.has_routes() + } + #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] pub fn fallback(self, handler: H) -> Self @@ -332,7 +338,7 @@ where /// See [`Router::fallback`] for more details. pub fn fallback_service(self, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -639,7 +645,7 @@ where where S: 'static, E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, E2: 'static, { match self { @@ -702,8 +708,8 @@ where { fn layer(self, layer: L) -> Endpoint where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index 8cb8f122dc..32b3102575 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -116,7 +116,7 @@ where service: T, ) -> Result<(), Cow<'static, str>> where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -234,7 +234,7 @@ where svc: T, ) -> Result<(), Cow<'static, str>> where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -269,8 +269,8 @@ where pub(super) fn layer(self, layer: L) -> PathRouter where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -295,8 +295,8 @@ where #[track_caller] pub(super) fn route_layer(self, layer: L) -> Self where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -325,6 +325,10 @@ where } } + pub(super) fn has_routes(&self) -> bool { + !self.routes.is_empty() + } + pub(super) fn with_state(self, state: S) -> PathRouter { let routes = self .routes diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index 2bde8c8c5f..da1e848a49 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -1,7 +1,7 @@ use crate::{ body::{Body, HttpBody}, + box_clone_service::BoxCloneService, response::Response, - util::AxumMutex, }; use axum_core::{extract::Request, response::IntoResponse}; use bytes::Bytes; @@ -18,7 +18,7 @@ use std::{ task::{Context, Poll}, }; use tower::{ - util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot}, + util::{MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot}, ServiceExt, }; use tower_layer::Layer; @@ -28,31 +28,31 @@ use tower_service::Service; /// /// You normally shouldn't need to care about this type. It's used in /// [`Router::layer`](super::Router::layer). -pub struct Route(AxumMutex>); +pub struct Route(BoxCloneService); impl Route { pub(crate) fn new(svc: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { - Self(AxumMutex::new(BoxCloneService::new( + Self(BoxCloneService::new( svc.map_response(IntoResponse::into_response), - ))) + )) } pub(crate) fn oneshot_inner( &mut self, req: Request, ) -> Oneshot, Request> { - self.0.get_mut().unwrap().clone().oneshot(req) + self.0.clone().oneshot(req) } pub(crate) fn layer(self, layer: L) -> Route where L: Layer> + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -72,7 +72,7 @@ impl Route { impl Clone for Route { #[track_caller] fn clone(&self) -> Self { - Self(AxumMutex::new(self.0.lock().unwrap().clone())) + Self(self.0.clone()) } } @@ -236,6 +236,8 @@ impl Future for InfallibleRouteFuture { fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match futures_util::ready!(self.project().future.poll(cx)) { Ok(response) => Poll::Ready(response), + + #[allow(unreachable_patterns)] Err(err) => match err {}, } } diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index 7209607af7..3209da3b12 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -104,7 +104,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option { } // if the prefix matches it will always do so up until a `/`, it cannot match only - // part of a segment. Therefore this will always be at a char boundary and `split_at` wont + // part of a segment. Therefore this will always be at a char boundary and `split_at` won't // panic let after_prefix = uri.path().split_at(matching_prefix_length?).1; diff --git a/axum/src/routing/tests/handle_error.rs b/axum/src/routing/tests/handle_error.rs index e5d575e9dc..a2fd2e6828 100644 --- a/axum/src/routing/tests/handle_error.rs +++ b/axum/src/routing/tests/handle_error.rs @@ -12,23 +12,6 @@ fn timeout() -> TimeoutLayer { TimeoutLayer::new(Duration::from_millis(10)) } -#[derive(Clone)] -struct Svc; - -impl Service for Svc { - type Response = Response; - type Error = hyper::Error; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _req: R) -> Self::Future { - ready(Ok(Response::new(Body::empty()))) - } -} - #[crate::test] async fn handler() { let app = Router::new().route( diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index e3a9d238a7..db5ca480da 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -12,7 +12,6 @@ use crate::{ tracing_helpers::{capture_tracing, TracingEvent}, *, }, - util::mutex_num_locked, BoxError, Extension, Json, Router, ServiceExt, }; use axum_core::extract::Request; @@ -915,31 +914,18 @@ async fn state_isnt_cloned_too_much() { impl Clone for AppState { fn clone(&self) -> Self { - #[rustversion::since(1.66)] - #[track_caller] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - let bt = std::backtrace::Backtrace::force_capture(); - let bt = bt - .to_string() - .lines() - .filter(|line| line.contains("axum") || line.contains("./src")) - .collect::>() - .join("\n"); - println!("AppState::Clone:\n===============\n{bt}\n"); - COUNT.fetch_add(1, Ordering::SeqCst); - } + if SETUP_DONE.load(Ordering::SeqCst) { + let bt = std::backtrace::Backtrace::force_capture(); + let bt = bt + .to_string() + .lines() + .filter(|line| line.contains("axum") || line.contains("./src")) + .collect::>() + .join("\n"); + println!("AppState::Clone:\n===============\n{bt}\n"); + COUNT.fetch_add(1, Ordering::SeqCst); } - #[rustversion::not(since(1.66))] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - count(); - Self } } @@ -969,7 +955,7 @@ async fn logging_rejections() { rejection_type: String, } - let events = capture_tracing::(|| async { + let events = capture_tracing::(|| async { let app = Router::new() .route("/extension", get(|_: Extension| async {})) .route("/string", post(|_: String| async {})); @@ -990,6 +976,7 @@ async fn logging_rejections() { StatusCode::BAD_REQUEST, ); }) + .with_filter("axum::rejection=trace") .await; assert_eq!( @@ -1071,38 +1058,6 @@ async fn impl_handler_for_into_response() { assert_eq!(res.text().await, "thing created"); } -#[crate::test] -async fn locks_mutex_very_little() { - let (num, app) = mutex_num_locked(|| async { - Router::new() - .route("/a", get(|| async {})) - .route("/b", get(|| async {})) - .route("/c", get(|| async {})) - .with_state::<()>(()) - .into_service::() - }) - .await; - // once for `Router::new` for setting the default fallback and 3 times, once per route - assert_eq!(num, 4); - - for path in ["/a", "/b", "/c"] { - // calling the router should only lock the mutex once - let (num, _res) = mutex_num_locked(|| async { - // We cannot use `TestClient` because it uses `serve` which spawns a new task per - // connection and `mutex_num_locked` uses a task local to keep track of the number of - // locks. So spawning a new task would unset the task local set by `mutex_num_locked` - // - // So instead `call` the service directly without spawning new tasks. - app.clone() - .oneshot(Request::builder().uri(path).body(Body::empty()).unwrap()) - .await - .unwrap() - }) - .await; - assert_eq!(num, 1); - } -} - #[crate::test] #[should_panic( expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router." diff --git a/axum/src/serve.rs b/axum/src/serve.rs index c5c540867c..1ba9a1452c 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -7,9 +7,7 @@ use std::{ io, marker::PhantomData, net::SocketAddr, - pin::Pin, sync::Arc, - task::{Context, Poll}, time::Duration, }; @@ -18,13 +16,12 @@ use futures_util::{pin_mut, FutureExt}; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; #[cfg(any(feature = "http1", feature = "http2"))] -use hyper_util::server::conn::auto::Builder; -use pin_project_lite::pin_project; +use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::{ net::{TcpListener, TcpStream}, sync::watch, }; -use tower::util::{Oneshot, ServiceExt}; +use tower::ServiceExt as _; use tower_service::Service; /// Serve the service with the supplied listener. @@ -176,6 +173,11 @@ impl Serve { ..self } } + + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.tcp_listener.local_addr() + } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] @@ -211,62 +213,8 @@ where type IntoFuture = private::ServeFuture; fn into_future(self) -> Self::IntoFuture { - private::ServeFuture(Box::pin(async move { - let Self { - tcp_listener, - mut make_service, - tcp_nodelay, - _marker: _, - } = self; - - loop { - let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await { - Some(conn) => conn, - None => continue, - }; - - if let Some(nodelay) = tcp_nodelay { - if let Err(err) = tcp_stream.set_nodelay(nodelay) { - trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); - } - } - - let tcp_stream = TokioIo::new(tcp_stream); - - poll_fn(|cx| make_service.poll_ready(cx)) - .await - .unwrap_or_else(|err| match err {}); - - let tower_service = make_service - .call(IncomingStream { - tcp_stream: &tcp_stream, - remote_addr, - }) - .await - .unwrap_or_else(|err| match err {}); - - let hyper_service = TowerToHyperService { - service: tower_service, - }; - - tokio::spawn(async move { - match Builder::new(TokioExecutor::new()) - // upgrades needed for websockets - .serve_connection_with_upgrades(tcp_stream, hyper_service) - .await - { - Ok(()) => {} - Err(_err) => { - // This error only appears when the client doesn't send a request and - // terminate the connection. - // - // If client sends one request then terminate connection whenever, it doesn't - // appear. - } - } - }); - } - })) + self.with_graceful_shutdown(std::future::pending()) + .into_future() } } @@ -311,6 +259,11 @@ impl WithGracefulShutdown { ..self } } + + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.tcp_listener.local_addr() + } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] @@ -404,11 +357,10 @@ where remote_addr, }) .await - .unwrap_or_else(|err| match err {}); + .unwrap_or_else(|err| match err {}) + .map_request(|req: Request| req.map(Body::new)); - let hyper_service = TowerToHyperService { - service: tower_service, - }; + let hyper_service = TowerToHyperService::new(tower_service); let signal_tx = Arc::clone(&signal_tx); @@ -518,49 +470,6 @@ mod private { } } -#[derive(Debug, Copy, Clone)] -struct TowerToHyperService { - service: S, -} - -impl hyper::service::Service> for TowerToHyperService -where - S: tower_service::Service + Clone, -{ - type Response = S::Response; - type Error = S::Error; - type Future = TowerToHyperServiceFuture; - - fn call(&self, req: Request) -> Self::Future { - let req = req.map(Body::new); - TowerToHyperServiceFuture { - future: self.service.clone().oneshot(req), - } - } -} - -pin_project! { - struct TowerToHyperServiceFuture - where - S: tower_service::Service, - { - #[pin] - future: Oneshot, - } -} - -impl Future for TowerToHyperServiceFuture -where - S: tower_service::Service, -{ - type Output = Result; - - #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().future.poll(cx) - } -} - /// An incoming stream. /// /// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. @@ -592,6 +501,10 @@ mod tests { routing::get, Router, }; + use std::{ + future::pending, + net::{IpAddr, Ipv4Addr}, + }; #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { @@ -655,4 +568,29 @@ mod tests { } async fn handler() {} + + #[crate::test] + async fn test_serve_local_addr() { + let router: Router = Router::new(); + let addr = "0.0.0.0:0"; + + let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + let address = server.local_addr().unwrap(); + + assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + assert_ne!(address.port(), 0); + } + + #[crate::test] + async fn test_with_graceful_shutdown_local_addr() { + let router: Router = Router::new(); + let addr = "0.0.0.0:0"; + + let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone()) + .with_graceful_shutdown(pending()); + let address = server.local_addr().unwrap(); + + assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + assert_ne!(address.port(), 0); + } } diff --git a/axum/src/test_helpers/tracing_helpers.rs b/axum/src/test_helpers/tracing_helpers.rs index 2240717ee4..09800f8fb9 100644 --- a/axum/src/test_helpers/tracing_helpers.rs +++ b/axum/src/test_helpers/tracing_helpers.rs @@ -1,7 +1,13 @@ -use crate::util::AxumMutex; -use std::{future::Future, io, sync::Arc}; +use std::{ + future::{Future, IntoFuture}, + io, + marker::PhantomData, + pin::Pin, + sync::{Arc, Mutex}, +}; use serde::{de::DeserializeOwned, Deserialize}; +use tracing::instrument::WithSubscriber; use tracing_subscriber::prelude::*; use tracing_subscriber::{filter::Targets, fmt::MakeWriter}; @@ -14,45 +20,78 @@ pub(crate) struct TracingEvent { } /// Run an async closure and capture the tracing output it produces. -pub(crate) async fn capture_tracing(f: F) -> Vec> +pub(crate) fn capture_tracing(f: F) -> CaptureTracing where - F: Fn() -> Fut, - Fut: Future, T: DeserializeOwned, { - let (make_writer, handle) = TestMakeWriter::new(); - - let subscriber = tracing_subscriber::registry().with( - tracing_subscriber::fmt::layer() - .with_writer(make_writer) - .with_target(true) - .without_time() - .with_ansi(false) - .json() - .flatten_event(false) - .with_filter("axum=trace".parse::().unwrap()), - ); - - let guard = tracing::subscriber::set_default(subscriber); - - f().await; - - drop(guard); - - handle - .take() - .lines() - .map(|line| serde_json::from_str(line).unwrap()) - .collect() + CaptureTracing { + f, + filter: None, + _phantom: PhantomData, + } +} + +pub(crate) struct CaptureTracing { + f: F, + filter: Option, + _phantom: PhantomData T>, +} + +impl CaptureTracing { + pub(crate) fn with_filter(mut self, filter_string: &str) -> Self { + self.filter = Some(filter_string.parse().unwrap()); + self + } +} + +impl IntoFuture for CaptureTracing +where + F: Fn() -> Fut + Send + Sync + 'static, + Fut: Future + Send, + T: DeserializeOwned, +{ + type Output = Vec>; + type IntoFuture = Pin + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { f, filter, .. } = self; + Box::pin(async move { + let (make_writer, handle) = TestMakeWriter::new(); + + let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap()); + let subscriber = tracing_subscriber::registry().with( + tracing_subscriber::fmt::layer() + .with_writer(make_writer) + .with_target(true) + .without_time() + .with_ansi(false) + .json() + .flatten_event(false) + .with_filter(filter), + ); + + let guard = tracing::subscriber::set_default(subscriber); + + f().with_current_subscriber().await; + + drop(guard); + + handle + .take() + .lines() + .map(|line| serde_json::from_str(line).unwrap()) + .collect() + }) + } } struct TestMakeWriter { - write: Arc>>>, + write: Arc>>>, } impl TestMakeWriter { fn new() -> (Self, Handle) { - let write = Arc::new(AxumMutex::new(Some(Vec::::new()))); + let write = Arc::new(Mutex::new(Some(Vec::::new()))); ( Self { @@ -94,7 +133,7 @@ impl<'a> io::Write for Writer<'a> { } struct Handle { - write: Arc>>>, + write: Arc>>>, } impl Handle { diff --git a/axum/src/util.rs b/axum/src/util.rs index bae803db88..7c9b7864e9 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -1,8 +1,6 @@ use pin_project_lite::pin_project; use std::{ops::Deref, sync::Arc}; -pub(crate) use self::mutex::*; - #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub(crate) struct PercentDecodedStr(Arc); @@ -57,64 +55,3 @@ fn test_try_downcast() { assert_eq!(try_downcast::(5_u32), Err(5_u32)); assert_eq!(try_downcast::(5_i32), Ok(5_i32)); } - -// `AxumMutex` is a wrapper around `std::sync::Mutex` which, in test mode, tracks the number of -// times it's been locked on the current task. That way we can write a test to ensure we don't -// accidentally introduce more locking. -// -// When not in test mode, it is just a type alias for `std::sync::Mutex`. -#[cfg(not(test))] -mod mutex { - #[allow(clippy::disallowed_types)] - pub(crate) type AxumMutex = std::sync::Mutex; -} - -#[cfg(test)] -#[allow(clippy::disallowed_types)] -mod mutex { - use std::sync::{ - atomic::{AtomicUsize, Ordering}, - LockResult, Mutex, MutexGuard, - }; - - tokio::task_local! { - pub(crate) static NUM_LOCKED: AtomicUsize; - } - - pub(crate) async fn mutex_num_locked(f: F) -> (usize, Fut::Output) - where - F: FnOnce() -> Fut, - Fut: std::future::IntoFuture, - { - NUM_LOCKED - .scope(AtomicUsize::new(0), async move { - let output = f().await; - let num = NUM_LOCKED.with(|num| num.load(Ordering::SeqCst)); - (num, output) - }) - .await - } - - pub(crate) struct AxumMutex(Mutex); - - impl AxumMutex { - pub(crate) fn new(value: T) -> Self { - Self(Mutex::new(value)) - } - - pub(crate) fn get_mut(&mut self) -> LockResult<&mut T> { - self.0.get_mut() - } - - pub(crate) fn into_inner(self) -> LockResult { - self.0.into_inner() - } - - pub(crate) fn lock(&self) -> LockResult> { - _ = NUM_LOCKED.try_with(|num| { - num.fetch_add(1, Ordering::SeqCst); - }); - self.0.lock() - } - } -} diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000000..625309989d --- /dev/null +++ b/clippy.toml @@ -0,0 +1,3 @@ +disallowed-types = [ + { path = "tower::util::BoxCloneService", reason = "Use our internal BoxCloneService which is Sync" }, +] diff --git a/deny.toml b/deny.toml index d95bdf294b..c32c8715a8 100644 --- a/deny.toml +++ b/deny.toml @@ -33,6 +33,11 @@ skip-tree = [ { name = "regex-automata" }, # pulled in by hyper { name = "socket2" }, + # hyper-util hasn't upgraded to 0.5 yet, but it's the same service / layer + # crates beneath + { name = "tower" }, + # tower hasn't upgraded to 1.0 yet + { name = "sync_wrapper" }, ] [sources] diff --git a/examples/chat/Cargo.toml b/examples/chat/Cargo.toml index 90d88246fa..2beb99cf85 100644 --- a/examples/chat/Cargo.toml +++ b/examples/chat/Cargo.toml @@ -8,6 +8,5 @@ publish = false axum = { path = "../../axum", features = ["ws"] } futures = "0.3" tokio = { version = "1", features = ["full"] } -tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 02e3bdc060..77baada1b5 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -36,7 +36,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_chat=trace".into()), + .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -130,8 +130,8 @@ async fn websocket(stream: WebSocket, state: Arc) { // If any one of the tasks run to completion, we abort the other. tokio::select! { - _ = (&mut send_task) => recv_task.abort(), - _ = (&mut recv_task) => send_task.abort(), + _ = &mut send_task => recv_task.abort(), + _ = &mut recv_task => send_task.abort(), }; // Send "user left" message (similar to "joined" above). diff --git a/examples/compression/Cargo.toml b/examples/compression/Cargo.toml new file mode 100644 index 0000000000..a65d9d0d1b --- /dev/null +++ b/examples/compression/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "example-compression" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +serde_json = "1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tower = "0.5.1" +tower-http = { version = "0.6.1", features = ["compression-full", "decompression-full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[dev-dependencies] +assert-json-diff = "2.0" +brotli = "6.0" +flate2 = "1" +http = "1" +zstd = "0.13" diff --git a/examples/compression/README.md b/examples/compression/README.md new file mode 100644 index 0000000000..3f0ed94da7 --- /dev/null +++ b/examples/compression/README.md @@ -0,0 +1,32 @@ +# compression + +This example shows how to: +- automatically decompress request bodies when necessary +- compress response bodies based on the `accept` header. + +## Running + +``` +cargo run -p example-compression +``` + +## Sending compressed requests + +``` +curl -v -g 'http://localhost:3000/' \ + -H "Content-Type: application/json" \ + -H "Content-Encoding: gzip" \ + --compressed \ + --data-binary @data/products.json.gz +``` + +(Notice the `Content-Encoding: gzip` in the request, and `content-encoding: gzip` in the response.) + +## Sending uncompressed requests + +``` +curl -v -g 'http://localhost:3000/' \ + -H "Content-Type: application/json" \ + --compressed \ + --data-binary @data/products.json +``` diff --git a/examples/compression/data/products.json b/examples/compression/data/products.json new file mode 100644 index 0000000000..a234fbdd2a --- /dev/null +++ b/examples/compression/data/products.json @@ -0,0 +1,12 @@ +{ + "products": [ + { + "id": 1, + "name": "Product 1" + }, + { + "id": 2, + "name": "Product 2" + } + ] +} diff --git a/examples/compression/data/products.json.gz b/examples/compression/data/products.json.gz new file mode 100644 index 0000000000..91d398955b Binary files /dev/null and b/examples/compression/data/products.json.gz differ diff --git a/examples/compression/src/main.rs b/examples/compression/src/main.rs new file mode 100644 index 0000000000..b487f34e4f --- /dev/null +++ b/examples/compression/src/main.rs @@ -0,0 +1,39 @@ +use axum::{routing::post, Json, Router}; +use serde_json::Value; +use tower::ServiceBuilder; +use tower_http::{compression::CompressionLayer, decompression::RequestDecompressionLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[cfg(test)] +mod tests; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let app: Router = app(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +fn app() -> Router { + Router::new().route("/", post(root)).layer( + ServiceBuilder::new() + .layer(RequestDecompressionLayer::new()) + .layer(CompressionLayer::new()), + ) +} + +async fn root(Json(value): Json) -> Json { + Json(value) +} diff --git a/examples/compression/src/tests.rs b/examples/compression/src/tests.rs new file mode 100644 index 0000000000..c91ccaa649 --- /dev/null +++ b/examples/compression/src/tests.rs @@ -0,0 +1,245 @@ +use assert_json_diff::assert_json_eq; +use axum::{ + body::{Body, Bytes}, + response::Response, +}; +use brotli::enc::BrotliEncoderParams; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use http::{header, StatusCode}; +use serde_json::{json, Value}; +use std::io::{Read, Write}; +use tower::ServiceExt; + +use super::*; + +#[tokio::test] +async fn handle_uncompressed_request_bodies() { + // Given + + let body = json(); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .body(json_body(&body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_gzip_request_bodies() { + // Given + + let body = compress_gzip(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "gzip") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_br_request_bodies() { + // Given + + let body = compress_br(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "br") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_zstd_request_bodies() { + // Given + + let body = compress_zstd(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "zstd") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn do_not_compress_response_bodies() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn compress_response_bodies_with_gzip() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "gzip") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let mut decoder = GzDecoder::new(response_body.as_ref()); + let mut decompress_body = String::new(); + decoder.read_to_string(&mut decompress_body).unwrap(); + assert_json_eq!( + serde_json::from_str::(&decompress_body).unwrap(), + json() + ); +} + +#[tokio::test] +async fn compress_response_bodies_with_br() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "br") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let mut decompress_body = Vec::new(); + brotli::BrotliDecompress(&mut response_body.as_ref(), &mut decompress_body).unwrap(); + assert_json_eq!( + serde_json::from_slice::(&decompress_body).unwrap(), + json() + ); +} + +#[tokio::test] +async fn compress_response_bodies_with_zstd() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "zstd") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let decompress_body = zstd::stream::decode_all(std::io::Cursor::new(response_body)).unwrap(); + assert_json_eq!( + serde_json::from_slice::(&decompress_body).unwrap(), + json() + ); +} + +fn json() -> Value { + json!({ + "name": "foo", + "mainProduct": { + "typeId": "product", + "id": "p1" + }, + }) +} + +fn json_body(input: &Value) -> Body { + Body::from(serde_json::to_vec(&input).unwrap()) +} + +async fn json_from_response(response: Response) -> Value { + let body = byte_from_response(response).await; + body_as_json(body) +} + +async fn byte_from_response(response: Response) -> Bytes { + axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() +} + +fn body_as_json(body: Bytes) -> Value { + serde_json::from_slice(body.as_ref()).unwrap() +} + +fn compress_gzip(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&request_body).unwrap(); + encoder.finish().unwrap() +} + +fn compress_br(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + let mut result = Vec::new(); + + let params = BrotliEncoderParams::default(); + let _ = brotli::enc::BrotliCompress(&mut &request_body[..], &mut result, ¶ms).unwrap(); + + result +} + +fn compress_zstd(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + zstd::stream::encode_all(std::io::Cursor::new(request_body), 4).unwrap() +} diff --git a/examples/consume-body-in-extractor-or-middleware/Cargo.toml b/examples/consume-body-in-extractor-or-middleware/Cargo.toml index 9aeb864d61..6688588582 100644 --- a/examples/consume-body-in-extractor-or-middleware/Cargo.toml +++ b/examples/consume-body-in-extractor-or-middleware/Cargo.toml @@ -7,9 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = "1.0.0" tokio = { version = "1.0", features = ["full"] } -tower = "0.4" -tower-http = { version = "0.5.0", features = ["map-request-body", "util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 107edb6f1b..3239d6ac6d 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, body::{Body, Bytes}, extract::{FromRequest, Request}, http::StatusCode, @@ -22,7 +21,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_consume_body_in_extractor_or_middleware=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -50,7 +49,7 @@ async fn print_request_body(request: Request, next: Next) -> Result Result { let (parts, body) = request.into_parts(); - // this wont work if the body is an long running stream + // this won't work if the body is an long running stream let bytes = body .collect() .await @@ -74,7 +73,6 @@ async fn handler(BufferRequestBody(body): BufferRequestBody) { struct BufferRequestBody(Bytes); // we must implement `FromRequest` (and not `FromRequestParts`) to consume the body -#[async_trait] impl FromRequest for BufferRequestBody where S: Send + Sync, diff --git a/examples/cors/Cargo.toml b/examples/cors/Cargo.toml index 5d5d2edae5..654538fa22 100644 --- a/examples/cors/Cargo.toml +++ b/examples/cors/Cargo.toml @@ -7,4 +7,4 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["cors"] } +tower-http = { version = "0.6.1", features = ["cors"] } diff --git a/examples/customize-extractor-error/src/custom_extractor.rs b/examples/customize-extractor-error/src/custom_extractor.rs index 3611fba796..4f75fb440d 100644 --- a/examples/customize-extractor-error/src/custom_extractor.rs +++ b/examples/customize-extractor-error/src/custom_extractor.rs @@ -5,7 +5,6 @@ //! - Boilerplate: Requires creating a new extractor for every custom rejection //! - Complexity: Manually implementing `FromRequest` results on more complex code use axum::{ - async_trait, extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request}, http::StatusCode, response::IntoResponse, @@ -20,7 +19,6 @@ pub async fn handler(Json(value): Json) -> impl IntoResponse { // We define our own `Json` extractor that customizes the error from `axum::Json` pub struct Json(pub T); -#[async_trait] impl FromRequest for Json where axum::Json: FromRequest, diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index e8820326f9..48188352e5 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -16,7 +16,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_customize_extractor_error=trace".into()), + .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/customize-path-rejection/Cargo.toml b/examples/customize-path-rejection/Cargo.toml index 8f5b1e4487..c1c4884d43 100644 --- a/examples/customize-path-rejection/Cargo.toml +++ b/examples/customize-path-rejection/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index c5f0ef9eb6..e784a969b8 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}, http::{request::Parts, StatusCode}, response::IntoResponse, @@ -20,7 +19,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_customize_path_rejection=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -49,7 +48,6 @@ struct Params { // We define our own `Path` extractor that customizes the error from `axum::extract::Path` struct Path(T); -#[async_trait] impl FromRequestParts for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` diff --git a/examples/dependency-injection/src/main.rs b/examples/dependency-injection/src/main.rs index dc4ce165e7..7a4719e768 100644 --- a/examples/dependency-injection/src/main.rs +++ b/examples/dependency-injection/src/main.rs @@ -25,7 +25,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_dependency_injection=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/diesel-async-postgres/Cargo.toml b/examples/diesel-async-postgres/Cargo.toml index d86db1516d..efec344044 100644 --- a/examples/diesel-async-postgres/Cargo.toml +++ b/examples/diesel-async-postgres/Cargo.toml @@ -8,9 +8,8 @@ publish = false axum = { path = "../../axum", features = ["macros"] } bb8 = "0.8" diesel = "2" -diesel-async = { version = "0.3", features = ["postgres", "bb8"] } +diesel-async = { version = "0.5", features = ["postgres", "bb8"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/diesel-async-postgres/src/main.rs b/examples/diesel-async-postgres/src/main.rs index ee42ac1002..44fbb54643 100644 --- a/examples/diesel-async-postgres/src/main.rs +++ b/examples/diesel-async-postgres/src/main.rs @@ -13,7 +13,6 @@ //! for a real world application using axum and diesel use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, response::Json, @@ -57,7 +56,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_diesel_async_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -102,7 +101,6 @@ struct DatabaseConnection( bb8::PooledConnection<'static, AsyncDieselConnectionManager>, ); -#[async_trait] impl FromRequestParts for DatabaseConnection where S: Send + Sync, diff --git a/examples/diesel-postgres/Cargo.toml b/examples/diesel-postgres/Cargo.toml index ff42a0db68..a68b9df89f 100644 --- a/examples/diesel-postgres/Cargo.toml +++ b/examples/diesel-postgres/Cargo.toml @@ -6,11 +6,10 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["macros"] } -deadpool-diesel = { version = "0.4.1", features = ["postgres"] } +deadpool-diesel = { version = "0.6.1", features = ["postgres"] } diesel = { version = "2", features = ["postgres"] } diesel_migrations = "2" serde = { version = "1.0", features = ["derive"] } -serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/diesel-postgres/src/main.rs b/examples/diesel-postgres/src/main.rs index 605660d073..0c5852d20b 100644 --- a/examples/diesel-postgres/src/main.rs +++ b/examples/diesel-postgres/src/main.rs @@ -54,7 +54,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/error-handling/Cargo.toml b/examples/error-handling/Cargo.toml index 26fc3b98ee..7aebc903b8 100644 --- a/examples/error-handling/Cargo.toml +++ b/examples/error-handling/Cargo.toml @@ -8,6 +8,6 @@ publish = false axum = { path = "../../axum", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs index 6981f59eee..0ad9f43cfa 100644 --- a/examples/error-handling/src/main.rs +++ b/examples/error-handling/src/main.rs @@ -45,8 +45,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_error_handling=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/form/src/main.rs b/examples/form/src/main.rs index 3f9ed09560..02ea23525b 100644 --- a/examples/form/src/main.rs +++ b/examples/form/src/main.rs @@ -13,7 +13,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_form=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/global-404-handler/Cargo.toml b/examples/global-404-handler/Cargo.toml index 9848d9e830..a453cab57b 100644 --- a/examples/global-404-handler/Cargo.toml +++ b/examples/global-404-handler/Cargo.toml @@ -7,6 +7,5 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs index 38b029439b..bf1d8a95ac 100644 --- a/examples/global-404-handler/src/main.rs +++ b/examples/global-404-handler/src/main.rs @@ -17,7 +17,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_global_404_handler=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/graceful-shutdown/Cargo.toml b/examples/graceful-shutdown/Cargo.toml index 86dfd52763..c7a5727423 100644 --- a/examples/graceful-shutdown/Cargo.toml +++ b/examples/graceful-shutdown/Cargo.toml @@ -6,10 +6,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["tracing"] } -hyper = { version = "1.0", features = [] } -hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5", features = ["timeout", "trace"] } -tracing = "0.1" +tower-http = { version = "0.6.1", features = ["timeout", "trace"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/graceful-shutdown/src/main.rs b/examples/graceful-shutdown/src/main.rs index d3388c8359..533cf8f145 100644 --- a/examples/graceful-shutdown/src/main.rs +++ b/examples/graceful-shutdown/src/main.rs @@ -21,7 +21,11 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { - "example_graceful_shutdown=debug,tower_http=debug,axum=trace".into() + format!( + "{}=debug,tower_http=debug,axum=trace", + env!("CARGO_CRATE_NAME") + ) + .into() }), ) .with(tracing_subscriber::fmt::layer().without_time()) diff --git a/examples/handle-head-request/Cargo.toml b/examples/handle-head-request/Cargo.toml index 83a8a66e25..8497b08957 100644 --- a/examples/handle-head-request/Cargo.toml +++ b/examples/handle-head-request/Cargo.toml @@ -11,4 +11,4 @@ tokio = { version = "1.0", features = ["full"] } [dev-dependencies] http-body-util = "0.1.0" hyper = { version = "1.0.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/http-proxy/Cargo.toml b/examples/http-proxy/Cargo.toml index aa6070020a..8dc2f19539 100644 --- a/examples/http-proxy/Cargo.toml +++ b/examples/http-proxy/Cargo.toml @@ -9,6 +9,6 @@ axum = { path = "../../axum" } hyper = { version = "1", features = ["full"] } hyper-util = "0.1.1" tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["make"] } +tower = { version = "0.5.1", features = ["make", "util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs index b60ed03daa..90aa5aa817 100644 --- a/examples/http-proxy/src/main.rs +++ b/examples/http-proxy/src/main.rs @@ -36,8 +36,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=trace,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/jwt/Cargo.toml b/examples/jwt/Cargo.toml index b0c76c25d1..54378fe3ee 100644 --- a/examples/jwt/Cargo.toml +++ b/examples/jwt/Cargo.toml @@ -7,7 +7,7 @@ publish = false [dependencies] axum = { path = "../../axum" } axum-extra = { path = "../../axum-extra", features = ["typed-header"] } -jsonwebtoken = "8.0" +jsonwebtoken = "9.3" once_cell = "1.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 85211851b2..8b7a7cbe6b 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -7,7 +7,6 @@ //! ``` use axum::{ - async_trait, extract::FromRequestParts, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, @@ -61,7 +60,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_jwt=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -122,7 +121,6 @@ impl AuthBody { } } -#[async_trait] impl FromRequestParts for Claims where S: Send + Sync, diff --git a/examples/key-value-store/Cargo.toml b/examples/key-value-store/Cargo.toml index c23b28d268..ccd28c2558 100644 --- a/examples/key-value-store/Cargo.toml +++ b/examples/key-value-store/Cargo.toml @@ -7,14 +7,13 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] } -tower-http = { version = "0.5.0", features = [ +tower = { version = "0.5.1", features = ["util", "timeout", "load-shed", "limit"] } +tower-http = { version = "0.6.1", features = [ "add-extension", "auth", "compression-full", "limit", "trace", ] } -tower-layer = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index d8713a2cd7..c2b3f51cda 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -33,8 +33,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_key_value_store=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/listen-multiple-addrs/Cargo.toml b/examples/listen-multiple-addrs/Cargo.toml index 8940b94332..ed146ca5e7 100644 --- a/examples/listen-multiple-addrs/Cargo.toml +++ b/examples/listen-multiple-addrs/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "listen-multiple-addrs" +name = "example-listen-multiple-addrs" version = "0.1.0" edition = "2021" publish = false @@ -9,4 +9,4 @@ axum = { path = "../../axum" } hyper = { version = "1.0.0", features = ["full"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/low-level-native-tls/Cargo.toml b/examples/low-level-native-tls/Cargo.toml new file mode 100644 index 0000000000..eee80081c9 --- /dev/null +++ b/examples/low-level-native-tls/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "example-low-level-native-tls" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +futures-util = { version = "0.3", default-features = false } +hyper = { version = "1.0.0", features = ["full"] } +hyper-util = { version = "0.1" } +tokio = { version = "1", features = ["full"] } +tokio-native-tls = "0.3.1" +tower = { version = "0.5.1", features = ["make"] } +tower-service = "0.3.2" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/low-level-native-tls/self_signed_certs/cert.pem b/examples/low-level-native-tls/self_signed_certs/cert.pem new file mode 100644 index 0000000000..656aa88055 --- /dev/null +++ b/examples/low-level-native-tls/self_signed_certs/cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 +daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 +kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq +dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT +bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 +J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK +NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 +yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W +ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU +XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg ++MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 +Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 +fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr +AopvZ09uEQ== +-----END CERTIFICATE----- diff --git a/examples/low-level-native-tls/self_signed_certs/key.pem b/examples/low-level-native-tls/self_signed_certs/key.pem new file mode 100644 index 0000000000..3de14eb32f --- /dev/null +++ b/examples/low-level-native-tls/self_signed_certs/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD +BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS +tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw +RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l +YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t +HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV +W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB +12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI ++2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw +zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt +fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty +RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT +ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 +myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ +XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY +5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD +ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD +ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 +k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 +wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV +5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg +3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa +Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C +rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m +y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW +w37pCS7ykL+7gp7V0WShYsw= +-----END PRIVATE KEY----- diff --git a/examples/low-level-native-tls/src/main.rs b/examples/low-level-native-tls/src/main.rs new file mode 100644 index 0000000000..d676238dfa --- /dev/null +++ b/examples/low-level-native-tls/src/main.rs @@ -0,0 +1,101 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-low-level-native-tls +//! ``` + +use axum::{extract::Request, routing::get, Router}; +use futures_util::pin_mut; +use hyper::body::Incoming; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use std::path::PathBuf; +use tokio::net::TcpListener; +use tokio_native_tls::{ + native_tls::{Identity, Protocol, TlsAcceptor as NativeTlsAcceptor}, + TlsAcceptor, +}; +use tower_service::Service; +use tracing::{error, info, warn}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_low_level_rustls=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let tls_acceptor = native_tls_acceptor( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("key.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("cert.pem"), + ); + + let tls_acceptor = TlsAcceptor::from(tls_acceptor); + let bind = "[::1]:3000"; + let tcp_listener = TcpListener::bind(bind).await.unwrap(); + info!("HTTPS server listening on {bind}. To contact curl -k https://localhost:3000"); + let app = Router::new().route("/", get(handler)); + + pin_mut!(tcp_listener); + loop { + let tower_service = app.clone(); + let tls_acceptor = tls_acceptor.clone(); + + // Wait for new tcp connection + let (cnx, addr) = tcp_listener.accept().await.unwrap(); + + tokio::spawn(async move { + // Wait for tls handshake to happen + let Ok(stream) = tls_acceptor.accept(cnx).await else { + error!("error during tls handshake connection from {}", addr); + return; + }; + + // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. + // `TokioIo` converts between them. + let stream = TokioIo::new(stream); + + // Hyper also has its own `Service` trait and doesn't use tower. We can use + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through + // `tower::Service::call`. + let hyper_service = hyper::service::service_fn(move |request: Request| { + // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas + // tower's `Service` requires `&mut self`. + // + // We don't need to call `poll_ready` since `Router` is always ready. + tower_service.clone().call(request) + }); + + let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(stream, hyper_service) + .await; + + if let Err(err) = ret { + warn!("error serving connection from {addr}: {err}"); + } + }); + } +} + +async fn handler() -> &'static str { + "Hello, World!" +} + +fn native_tls_acceptor(key_file: PathBuf, cert_file: PathBuf) -> NativeTlsAcceptor { + let key_pem = std::fs::read_to_string(&key_file).unwrap(); + let cert_pem = std::fs::read_to_string(&cert_file).unwrap(); + + let id = Identity::from_pkcs8(cert_pem.as_bytes(), key_pem.as_bytes()).unwrap(); + NativeTlsAcceptor::builder(id) + // let's be modern + .min_protocol_version(Some(Protocol::Tlsv12)) + .build() + .unwrap() +} diff --git a/examples/low-level-openssl/Cargo.toml b/examples/low-level-openssl/Cargo.toml index c5247dec9c..a74a950e56 100644 --- a/examples/low-level-openssl/Cargo.toml +++ b/examples/low-level-openssl/Cargo.toml @@ -12,6 +12,6 @@ hyper-util = { version = "0.1" } openssl = "0.10" tokio = { version = "1", features = ["full"] } tokio-openssl = "0.6" -tower = { version = "0.4", features = ["make"] } +tower = { version = "0.5.1", features = ["make"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/low-level-openssl/src/main.rs b/examples/low-level-openssl/src/main.rs index f2839d61a9..7c483010c5 100644 --- a/examples/low-level-openssl/src/main.rs +++ b/examples/low-level-openssl/src/main.rs @@ -15,7 +15,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_low_level_openssl=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/low-level-rustls/Cargo.toml b/examples/low-level-rustls/Cargo.toml index 3975fcb917..1eaf04b3f7 100644 --- a/examples/low-level-rustls/Cargo.toml +++ b/examples/low-level-rustls/Cargo.toml @@ -12,7 +12,6 @@ hyper-util = { version = "0.1" } rustls-pemfile = "1.0.4" tokio = { version = "1", features = ["full"] } tokio-rustls = "0.24.1" -tower = { version = "0.4", features = ["make"] } tower-service = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/low-level-rustls/src/main.rs b/examples/low-level-rustls/src/main.rs index 660225d7cf..7f5c994e50 100644 --- a/examples/low-level-rustls/src/main.rs +++ b/examples/low-level-rustls/src/main.rs @@ -29,7 +29,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_low_level_rustls=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/mongodb/Cargo.toml b/examples/mongodb/Cargo.toml new file mode 100644 index 0000000000..c084a36f7d --- /dev/null +++ b/examples/mongodb/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-mongodb" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +mongodb = "3.1.0" +serde = { version = "1.0", features = ["derive"] } +tokio = { version = "1.0", features = ["full"] } +tower-http = { version = "0.6.1", features = ["add-extension", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/mongodb/src/main.rs b/examples/mongodb/src/main.rs new file mode 100644 index 0000000000..2cf25f4f31 --- /dev/null +++ b/examples/mongodb/src/main.rs @@ -0,0 +1,132 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-mongodb +//! ``` + +use axum::{ + extract::{Path, State}, + http::StatusCode, + routing::{delete, get, post, put}, + Json, Router, +}; +use mongodb::{ + bson::doc, + results::{DeleteResult, InsertOneResult, UpdateResult}, + Client, Collection, +}; +use serde::{Deserialize, Serialize}; +use tower_http::trace::TraceLayer; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + // connecting to mongodb + let db_connection_str = std::env::var("DATABASE_URL").unwrap_or_else(|_| { + "mongodb://admin:password@127.0.0.1:27017/?authSource=admin".to_string() + }); + let client = Client::with_uri_str(db_connection_str).await.unwrap(); + + // pinging the database + client + .database("axum-mongo") + .run_command(doc! { "ping": 1 }) + .await + .unwrap(); + println!("Pinged your database. Successfully connected to MongoDB!"); + + // logging middleware + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("Listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app(client)).await.unwrap(); +} + +// defining routes and state +fn app(client: Client) -> Router { + let collection: Collection = client.database("axum-mongo").collection("members"); + + Router::new() + .route("/create", post(create_member)) + .route("/read/:id", get(read_member)) + .route("/update", put(update_member)) + .route("/delete/:id", delete(delete_member)) + .layer(TraceLayer::new_for_http()) + .with_state(collection) +} + +// handler to create a new member +async fn create_member( + State(db): State>, + Json(input): Json, +) -> Result, (StatusCode, String)> { + let result = db.insert_one(input).await.map_err(internal_error)?; + + Ok(Json(result)) +} + +// handler to read an existing member +async fn read_member( + State(db): State>, + Path(id): Path, +) -> Result>, (StatusCode, String)> { + let result = db + .find_one(doc! { "_id": id }) + .await + .map_err(internal_error)?; + + Ok(Json(result)) +} + +// handler to update an existing member +async fn update_member( + State(db): State>, + Json(input): Json, +) -> Result, (StatusCode, String)> { + let result = db + .replace_one(doc! { "_id": input.id }, input) + .await + .map_err(internal_error)?; + + Ok(Json(result)) +} + +// handler to delete an existing member +async fn delete_member( + State(db): State>, + Path(id): Path, +) -> Result, (StatusCode, String)> { + let result = db + .delete_one(doc! { "_id": id }) + .await + .map_err(internal_error)?; + + Ok(Json(result)) +} + +fn internal_error(err: E) -> (StatusCode, String) +where + E: std::error::Error, +{ + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) +} + +// defining Member type +#[derive(Debug, Deserialize, Serialize)] +struct Member { + #[serde(rename = "_id")] + id: u32, + name: String, + active: bool, +} diff --git a/examples/multipart-form/Cargo.toml b/examples/multipart-form/Cargo.toml index d93b9c08e8..143154e89d 100644 --- a/examples/multipart-form/Cargo.toml +++ b/examples/multipart-form/Cargo.toml @@ -7,6 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["multipart"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["limit", "trace"] } +tower-http = { version = "0.6.1", features = ["limit", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/multipart-form/src/main.rs b/examples/multipart-form/src/main.rs index ecf5191f2a..30fcfc70a9 100644 --- a/examples/multipart-form/src/main.rs +++ b/examples/multipart-form/src/main.rs @@ -17,8 +17,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_multipart_form=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 659ce2619c..30d7d41c88 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -11,7 +11,6 @@ use anyhow::{Context, Result}; use async_session::{MemoryStore, Session, SessionStore}; use axum::{ - async_trait, extract::{FromRef, FromRequestParts, Query, State}, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, @@ -35,7 +34,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_oauth=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -254,7 +253,6 @@ impl IntoResponse for AuthRedirect { } } -#[async_trait] impl FromRequestParts for User where MemoryStore: FromRef, diff --git a/examples/parse-body-based-on-content-type/src/main.rs b/examples/parse-body-based-on-content-type/src/main.rs index bae4ec1d29..1e4fc1ac43 100644 --- a/examples/parse-body-based-on-content-type/src/main.rs +++ b/examples/parse-body-based-on-content-type/src/main.rs @@ -7,7 +7,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRequest, Request}, http::{header::CONTENT_TYPE, StatusCode}, response::{IntoResponse, Response}, @@ -22,7 +21,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { - "example_parse_body_based_on_content_type=debug,tower_http=debug".into() + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() }), ) .with(tracing_subscriber::fmt::layer()) @@ -48,7 +47,6 @@ async fn handler(JsonOrForm(payload): JsonOrForm) { struct JsonOrForm(T); -#[async_trait] impl FromRequest for JsonOrForm where S: Send + Sync, diff --git a/examples/print-request-response/Cargo.toml b/examples/print-request-response/Cargo.toml index a314b5b7fe..d6e064bb63 100644 --- a/examples/print-request-response/Cargo.toml +++ b/examples/print-request-response/Cargo.toml @@ -7,8 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = { version = "1.0.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util", "filter"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/print-request-response/src/main.rs b/examples/print-request-response/src/main.rs index 5e0d4d1d97..84f14f2d50 100644 --- a/examples/print-request-response/src/main.rs +++ b/examples/print-request-response/src/main.rs @@ -20,8 +20,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_print_request_response=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/prometheus-metrics/Cargo.toml b/examples/prometheus-metrics/Cargo.toml index a30e443659..56ccdd05b0 100644 --- a/examples/prometheus-metrics/Cargo.toml +++ b/examples/prometheus-metrics/Cargo.toml @@ -6,8 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -metrics = { version = "0.22", default-features = false } -metrics-exporter-prometheus = { version = "0.13", default-features = false } +metrics = { version = "0.23", default-features = false } +metrics-exporter-prometheus = { version = "0.15", default-features = false } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/prometheus-metrics/src/main.rs b/examples/prometheus-metrics/src/main.rs index bb0e9c9c37..fe76121ce9 100644 --- a/examples/prometheus-metrics/src/main.rs +++ b/examples/prometheus-metrics/src/main.rs @@ -63,8 +63,9 @@ async fn start_metrics_server() { async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_prometheus_metrics=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/query-params-with-empty-strings/Cargo.toml b/examples/query-params-with-empty-strings/Cargo.toml index 7a52e98d1e..6dde9a3ac6 100644 --- a/examples/query-params-with-empty-strings/Cargo.toml +++ b/examples/query-params-with-empty-strings/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = "1.0.0" serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/readme/Cargo.toml b/examples/readme/Cargo.toml index 4a79c9bb88..17669567da 100644 --- a/examples/readme/Cargo.toml +++ b/examples/readme/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0.68" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/request-id/Cargo.toml b/examples/request-id/Cargo.toml new file mode 100644 index 0000000000..22879e0824 --- /dev/null +++ b/examples/request-id/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-request-id" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +tokio = { version = "1.0", features = ["full"] } +tower = "0.5" +tower-http = { version = "0.5", features = ["request-id", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/request-id/src/main.rs b/examples/request-id/src/main.rs new file mode 100644 index 0000000000..552d8d4a81 --- /dev/null +++ b/examples/request-id/src/main.rs @@ -0,0 +1,81 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-request-id +//! ``` + +use axum::{ + http::{HeaderName, Request}, + response::Html, + routing::get, + Router, +}; +use tower::ServiceBuilder; +use tower_http::{ + request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}, + trace::TraceLayer, +}; +use tracing::{error, info, info_span}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const REQUEST_ID_HEADER: &str = "x-request-id"; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + // axum logs rejections from built-in extractors with the `axum::rejection` + // target, at `TRACE` level. `axum::rejection=trace` enables showing those events + format!( + "{}=debug,tower_http=debug,axum::rejection=trace", + env!("CARGO_CRATE_NAME") + ) + .into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let x_request_id = HeaderName::from_static(REQUEST_ID_HEADER); + + let middleware = ServiceBuilder::new() + .layer(SetRequestIdLayer::new( + x_request_id.clone(), + MakeRequestUuid, + )) + .layer( + TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { + // Log the request id as generated. + let request_id = request.headers().get(REQUEST_ID_HEADER); + + match request_id { + Some(request_id) => info_span!( + "http_request", + request_id = ?request_id, + ), + None => { + error!("could not extract request_id"); + info_span!("http_request") + } + } + }), + ) + // send headers from request to response headers + .layer(PropagateRequestIdLayer::new(x_request_id)); + + // build our application with a route + let app = Router::new().route("/", get(handler)).layer(middleware); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +async fn handler() -> Html<&'static str> { + info!("Hello world!"); + Html("

Hello, World!

") +} diff --git a/examples/reqwest-response/Cargo.toml b/examples/reqwest-response/Cargo.toml index cfb91bdcc4..3ea740e3cb 100644 --- a/examples/reqwest-response/Cargo.toml +++ b/examples/reqwest-response/Cargo.toml @@ -9,6 +9,6 @@ axum = { path = "../../axum" } reqwest = { version = "0.12", features = ["stream"] } tokio = { version = "1.0", features = ["full"] } tokio-stream = "0.1" -tower-http = { version = "0.5.0", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/reqwest-response/src/main.rs b/examples/reqwest-response/src/main.rs index b6dfeb70b5..a1ba33cf26 100644 --- a/examples/reqwest-response/src/main.rs +++ b/examples/reqwest-response/src/main.rs @@ -4,18 +4,16 @@ //! cargo run -p example-reqwest-response //! ``` -use std::{convert::Infallible, time::Duration}; - -use axum::http::{HeaderMap, StatusCode}; use axum::{ body::{Body, Bytes}, extract::State, - http::{HeaderName, HeaderValue}, + http::StatusCode, response::{IntoResponse, Response}, routing::get, Router, }; use reqwest::Client; +use std::{convert::Infallible, time::Duration}; use tokio_stream::StreamExt; use tower_http::trace::TraceLayer; use tracing::Span; @@ -25,8 +23,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_reqwest_response=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -34,7 +33,7 @@ async fn main() { let client = Client::new(); let app = Router::new() - .route("/", get(proxy_via_reqwest)) + .route("/", get(stream_reqwest_response)) .route("/stream", get(stream_some_data)) // Add some logging so we can see the streams going through .layer(TraceLayer::new_for_http().on_body_chunk( @@ -51,7 +50,7 @@ async fn main() { axum::serve(listener, app).await.unwrap(); } -async fn proxy_via_reqwest(State(client): State) -> Response { +async fn stream_reqwest_response(State(client): State) -> Response { let reqwest_response = match client.get("http://127.0.0.1:3000/stream").send().await { Ok(res) => res, Err(err) => { @@ -60,16 +59,8 @@ async fn proxy_via_reqwest(State(client): State) -> Response { } }; - let response_builder = Response::builder().status(reqwest_response.status().as_u16()); - - // Here the mapping of headers is required due to reqwest and axum differ on the http crate versions - let mut headers = HeaderMap::with_capacity(reqwest_response.headers().len()); - headers.extend(reqwest_response.headers().into_iter().map(|(name, value)| { - let name = HeaderName::from_bytes(name.as_ref()).unwrap(); - let value = HeaderValue::from_bytes(value.as_ref()).unwrap(); - (name, value) - })); - + let mut response_builder = Response::builder().status(reqwest_response.status()); + *response_builder.headers_mut().unwrap() = reqwest_response.headers().clone(); response_builder .body(Body::from_stream(reqwest_response.bytes_stream())) // This unwrap is fine because the body is empty here diff --git a/examples/rest-grpc-multiplex/Cargo.toml b/examples/rest-grpc-multiplex/Cargo.toml index 11a6a3a2b4..69ece3632e 100644 --- a/examples/rest-grpc-multiplex/Cargo.toml +++ b/examples/rest-grpc-multiplex/Cargo.toml @@ -8,13 +8,13 @@ publish = false axum = { path = "../../axum" } futures = "0.3" hyper = { version = "1.0.0", features = ["full"] } -prost = "0.11" -tokio = { version = "1", features = ["full"] } -tonic = { version = "0.9" } -tonic-reflection = "0.9" -tower = { version = "0.4", features = ["full"] } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } +#prost = "0.11" +#tokio = { version = "1", features = ["full"] } +#tonic = { version = "0.9" } +#tonic-reflection = "0.9" +tower = { version = "0.5.1", features = ["full"] } +#tracing = "0.1" +#tracing-subscriber = { version = "0.3", features = ["env-filter"] } [build-dependencies] tonic-build = { version = "0.9", features = ["prost"] } diff --git a/examples/rest-grpc-multiplex/src/multiplex_service.rs b/examples/rest-grpc-multiplex/src/multiplex_service.rs index 80b612e12e..51550ec5ba 100644 --- a/examples/rest-grpc-multiplex/src/multiplex_service.rs +++ b/examples/rest-grpc-multiplex/src/multiplex_service.rs @@ -38,7 +38,7 @@ where Self { rest: self.rest.clone(), grpc: self.grpc.clone(), - // the cloned services probably wont be ready + // the cloned services probably won't be ready rest_ready: false, grpc_ready: false, } diff --git a/examples/serve-with-hyper/Cargo.toml b/examples/serve-with-hyper/Cargo.toml index 06f4607053..81553eb08b 100644 --- a/examples/serve-with-hyper/Cargo.toml +++ b/examples/serve-with-hyper/Cargo.toml @@ -9,4 +9,4 @@ axum = { path = "../../axum" } hyper = { version = "1.0", features = [] } hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/serve-with-hyper/src/main.rs b/examples/serve-with-hyper/src/main.rs index 8aaab9b047..9da67fc1f8 100644 --- a/examples/serve-with-hyper/src/main.rs +++ b/examples/serve-with-hyper/src/main.rs @@ -11,6 +11,8 @@ //! //! [hyper-util]: https://crates.io/crates/hyper-util +#![allow(unreachable_patterns)] + use std::convert::Infallible; use std::net::SocketAddr; @@ -43,7 +45,7 @@ async fn serve_plain() { // We don't need to call `poll_ready` because `Router` is always ready. let tower_service = app.clone(); - // Spawn a task to handle the connection. That way we can multiple connections + // Spawn a task to handle the connection. That way we can handle multiple connections // concurrently. tokio::spawn(async move { // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. diff --git a/examples/simple-router-wasm/Cargo.toml b/examples/simple-router-wasm/Cargo.toml index c3041e4a58..d250cc2387 100644 --- a/examples/simple-router-wasm/Cargo.toml +++ b/examples/simple-router-wasm/Cargo.toml @@ -14,3 +14,6 @@ axum-extra = { path = "../../axum-extra", default-features = false } futures-executor = "0.3.21" http = "1.0.0" tower-service = "0.3.1" + +[package.metadata.cargo-machete] +ignored = ["axum-extra"] diff --git a/examples/sqlx-postgres/Cargo.toml b/examples/sqlx-postgres/Cargo.toml index 3bc40302ed..0a0c437630 100644 --- a/examples/sqlx-postgres/Cargo.toml +++ b/examples/sqlx-postgres/Cargo.toml @@ -10,4 +10,4 @@ tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "any", "postgres"] } +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any", "postgres"] } diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 465711157e..904a5a8aad 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -14,7 +14,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, @@ -31,7 +30,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -75,7 +74,6 @@ async fn using_connection_pool_extractor( // which setup is appropriate depends on your application struct DatabaseConnection(sqlx::pool::PoolConnection); -#[async_trait] impl FromRequestParts for DatabaseConnection where PgPool: FromRef, diff --git a/examples/sse/Cargo.toml b/examples/sse/Cargo.toml index b2b33159fe..138820db16 100644 --- a/examples/sse/Cargo.toml +++ b/examples/sse/Cargo.toml @@ -11,11 +11,11 @@ futures = "0.3" headers = "0.4" tokio = { version = "1.0", features = ["full"] } tokio-stream = "0.1" -tower-http = { version = "0.5.0", features = ["fs", "trace"] } +tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] eventsource-stream = "0.2" reqwest = { version = "0.12", features = ["stream"] } -reqwest-eventsource = "0.5" +reqwest-eventsource = "0.6" diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 53f7dc49b8..4f616f6b05 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -24,8 +24,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_sse=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -96,7 +97,7 @@ mod tests { let listening_url = spawn_app("127.0.0.1").await; let mut event_stream = reqwest::Client::new() - .get(&format!("{}/sse", listening_url)) + .get(format!("{}/sse", listening_url)) .header("User-Agent", "integration_test") .send() .await diff --git a/examples/static-file-server/Cargo.toml b/examples/static-file-server/Cargo.toml index 3f41d60816..ce1955432f 100644 --- a/examples/static-file-server/Cargo.toml +++ b/examples/static-file-server/Cargo.toml @@ -6,9 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-extra = { path = "../../axum-extra" } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5.0", features = ["fs", "trace"] } +tower = { version = "0.5.1", features = ["util"] } +tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index 707d2ee3f3..148af57c04 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -19,8 +19,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_static_file_server=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/stream-to-file/src/main.rs b/examples/stream-to-file/src/main.rs index f016993270..7c44286d87 100644 --- a/examples/stream-to-file/src/main.rs +++ b/examples/stream-to-file/src/main.rs @@ -25,7 +25,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_stream_to_file=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/templates-minijinja/Cargo.toml b/examples/templates-minijinja/Cargo.toml index 9b525a8c29..692ea0ca12 100644 --- a/examples/templates-minijinja/Cargo.toml +++ b/examples/templates-minijinja/Cargo.toml @@ -6,5 +6,5 @@ publish = false [dependencies] axum = { path = "../../axum" } -minijinja = "1.0.11" +minijinja = "2.3.1" tokio = { version = "1.0", features = ["full"] } diff --git a/examples/templates/Cargo.toml b/examples/templates/Cargo.toml index 2f5aba2791..6cba09469f 100644 --- a/examples/templates/Cargo.toml +++ b/examples/templates/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" publish = false [dependencies] -askama = "0.11" +askama = "0.12" axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" diff --git a/examples/templates/src/main.rs b/examples/templates/src/main.rs index 36d9e68e4c..872471c235 100644 --- a/examples/templates/src/main.rs +++ b/examples/templates/src/main.rs @@ -19,7 +19,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_templates=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/testing-websockets/Cargo.toml b/examples/testing-websockets/Cargo.toml index 842624c9e5..31ed2601f0 100644 --- a/examples/testing-websockets/Cargo.toml +++ b/examples/testing-websockets/Cargo.toml @@ -7,6 +7,5 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["ws"] } futures = "0.3" -hyper = { version = "1.0.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] } -tokio-tungstenite = "0.21" +tokio-tungstenite = "0.24" diff --git a/examples/testing/Cargo.toml b/examples/testing/Cargo.toml index 00e8132f73..811e4f6056 100644 --- a/examples/testing/Cargo.toml +++ b/examples/testing/Cargo.toml @@ -7,14 +7,13 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = { version = "1.0.0", features = ["full"] } hyper-util = { version = "0.1", features = ["client", "http1", "client-legacy"] } mime = "0.3" serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index 9e33027e22..c6dbcf5c16 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -18,8 +18,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_testing=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/tls-graceful-shutdown/Cargo.toml b/examples/tls-graceful-shutdown/Cargo.toml index 40e489030a..15b3b73b1f 100644 --- a/examples/tls-graceful-shutdown/Cargo.toml +++ b/examples/tls-graceful-shutdown/Cargo.toml @@ -6,8 +6,7 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-server = { version = "0.6", features = ["tls-rustls"] } -hyper = { version = "0.14", features = ["full"] } +axum-server = { version = "0.7", features = ["tls-rustls"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tls-graceful-shutdown/src/main.rs b/examples/tls-graceful-shutdown/src/main.rs index cc5b6ecb8d..f42ac435b6 100644 --- a/examples/tls-graceful-shutdown/src/main.rs +++ b/examples/tls-graceful-shutdown/src/main.rs @@ -28,7 +28,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/tls-rustls/Cargo.toml b/examples/tls-rustls/Cargo.toml index 4c255c2763..9b976f160e 100644 --- a/examples/tls-rustls/Cargo.toml +++ b/examples/tls-rustls/Cargo.toml @@ -6,7 +6,7 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-server = { version = "0.6", features = ["tls-rustls"] } +axum-server = { version = "0.7", features = ["tls-rustls"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs index 3649c75cb1..88da6e535b 100644 --- a/examples/tls-rustls/src/main.rs +++ b/examples/tls-rustls/src/main.rs @@ -30,7 +30,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tls_rustls=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/todos/Cargo.toml b/examples/todos/Cargo.toml index dbd8b7125a..127fbee47c 100644 --- a/examples/todos/Cargo.toml +++ b/examples/todos/Cargo.toml @@ -8,8 +8,8 @@ publish = false axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util", "timeout"] } -tower-http = { version = "0.5.0", features = ["add-extension", "trace"] } +tower = { version = "0.5.1", features = ["util", "timeout"] } +tower-http = { version = "0.6.1", features = ["add-extension", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.0", features = ["serde", "v4"] } diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index da3cb4a1c2..6f115daf3c 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -36,8 +36,9 @@ use uuid::Uuid; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_todos=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/tokio-postgres/Cargo.toml b/examples/tokio-postgres/Cargo.toml index 74806044ca..e14520d23e 100644 --- a/examples/tokio-postgres/Cargo.toml +++ b/examples/tokio-postgres/Cargo.toml @@ -6,8 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -bb8 = "0.7.1" -bb8-postgres = "0.7.0" +bb8 = "0.8.5" +bb8-postgres = "0.8.1" tokio = { version = "1.0", features = ["full"] } tokio-postgres = "0.7.2" tracing = "0.1" diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index effc032089..7df9917b92 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, @@ -21,7 +20,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -68,7 +67,6 @@ async fn using_connection_pool_extractor( // which setup is appropriate depends on your application struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); -#[async_trait] impl FromRequestParts for DatabaseConnection where ConnectionPool: FromRef, diff --git a/examples/tokio-redis/Cargo.toml b/examples/tokio-redis/Cargo.toml index fb276849e8..86d8513b90 100644 --- a/examples/tokio-redis/Cargo.toml +++ b/examples/tokio-redis/Cargo.toml @@ -6,9 +6,9 @@ publish = false [dependencies] axum = { path = "../../axum" } -bb8 = "0.7.1" -bb8-redis = "0.14.0" -redis = "0.24.0" +bb8 = "0.8.5" +bb8-redis = "0.17.0" +redis = "0.27.2" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tokio-redis/src/main.rs b/examples/tokio-redis/src/main.rs index f0109f2127..105b1de46c 100644 --- a/examples/tokio-redis/src/main.rs +++ b/examples/tokio-redis/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, @@ -23,7 +22,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_redis=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -71,7 +70,6 @@ async fn using_connection_pool_extractor( // which setup is appropriate depends on your application struct DatabaseConnection(PooledConnection<'static, RedisConnectionManager>); -#[async_trait] impl FromRequestParts for DatabaseConnection where ConnectionPool: FromRef, diff --git a/examples/tracing-aka-logging/Cargo.toml b/examples/tracing-aka-logging/Cargo.toml index 4004cd596b..3d1204723d 100644 --- a/examples/tracing-aka-logging/Cargo.toml +++ b/examples/tracing-aka-logging/Cargo.toml @@ -7,6 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["tracing"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tracing-aka-logging/src/main.rs b/examples/tracing-aka-logging/src/main.rs index 74a2055a07..30c16f1962 100644 --- a/examples/tracing-aka-logging/src/main.rs +++ b/examples/tracing-aka-logging/src/main.rs @@ -25,7 +25,11 @@ async fn main() { tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { // axum logs rejections from built-in extractors with the `axum::rejection` // target, at `TRACE` level. `axum::rejection=trace` enables showing those events - "example_tracing_aka_logging=debug,tower_http=debug,axum::rejection=trace".into() + format!( + "{}=debug,tower_http=debug,axum::rejection=trace", + env!("CARGO_CRATE_NAME") + ) + .into() }), ) .with(tracing_subscriber::fmt::layer()) diff --git a/examples/unix-domain-socket/Cargo.toml b/examples/unix-domain-socket/Cargo.toml index 7f157c7dcb..94ceb04080 100644 --- a/examples/unix-domain-socket/Cargo.toml +++ b/examples/unix-domain-socket/Cargo.toml @@ -10,6 +10,5 @@ http-body-util = "0.1" hyper = { version = "1.0.0", features = ["full"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tracing = "0.1" +tower = { version = "0.5.1", features = ["util"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index d11792dd70..fbb4c3b067 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -3,6 +3,7 @@ //! ```not_rust //! cargo run -p example-unix-domain-socket //! ``` +#![allow(unreachable_patterns)] #[cfg(unix)] #[tokio::main] diff --git a/examples/validator/Cargo.toml b/examples/validator/Cargo.toml index a1adc075a8..8a7e6928d8 100644 --- a/examples/validator/Cargo.toml +++ b/examples/validator/Cargo.toml @@ -5,12 +5,10 @@ publish = false version = "0.1.0" [dependencies] -async-trait = "0.1.67" axum = { path = "../../axum" } -http-body = "1.0.0" serde = { version = "1.0", features = ["derive"] } thiserror = "1.0.29" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -validator = { version = "0.14.0", features = ["derive"] } +validator = { version = "0.18.1", features = ["derive"] } diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index 85c4ac1843..00e46173c4 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -10,7 +10,6 @@ //! ->

Hello, LT!

//! ``` -use async_trait::async_trait; use axum::{ extract::{rejection::FormRejection, Form, FromRequest, Request}, http::StatusCode, @@ -29,7 +28,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_validator=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -56,7 +55,6 @@ async fn handler(ValidatedForm(input): ValidatedForm) -> Html #[derive(Debug, Clone, Copy, Default)] pub struct ValidatedForm(pub T); -#[async_trait] impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index ee353f06bc..7b3ca5a581 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRequestParts, Path}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, @@ -20,7 +19,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_versioning=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -47,7 +46,6 @@ enum Version { V3, } -#[async_trait] impl FromRequestParts for Version where S: Send + Sync, diff --git a/examples/websockets/Cargo.toml b/examples/websockets/Cargo.toml index f62a8b03ee..541d82805a 100644 --- a/examples/websockets/Cargo.toml +++ b/examples/websockets/Cargo.toml @@ -11,9 +11,8 @@ futures = "0.3" futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } headers = "0.4" tokio = { version = "1.0", features = ["full"] } -tokio-tungstenite = "0.21" -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5.0", features = ["fs", "trace"] } +tokio-tungstenite = "0.24.0" +tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 9c3b9dbbf5..7b964404ee 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -45,8 +45,9 @@ use futures::{sink::SinkExt, stream::StreamExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_websockets=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init();