Skip to main content

cratestack_axum/idempotency/
service.rs

1//! `IdempotencyService` — the tower `Service` that owns the per-request
2//! state machine (reserve → run → complete/release).
3
4use std::sync::Arc;
5use std::time::{Duration, SystemTime};
6
7use axum::body::Body;
8use axum::extract::Request;
9use axum::response::Response;
10use cratestack_core::CoolError;
11use http::{StatusCode, header};
12use tower::Service;
13
14use super::hash::{hash_request, is_idempotent_target_method};
15use super::headers::encode_headers;
16use super::parse::parse_idempotency_key;
17use super::record::ReservationOutcome;
18use super::responses::{error_response, in_flight_response, replay_response};
19use super::store::{IdempotencyStore, MAX_BODY_BYTES};
20
21#[derive(Clone)]
22pub struct IdempotencyService<S> {
23    pub(super) inner: S,
24    pub(super) store: Arc<dyn IdempotencyStore>,
25    pub(super) ttl: Duration,
26    pub(super) principal_fingerprint: Arc<dyn Fn(&Request) -> String + Send + Sync>,
27}
28
29impl<S> Service<Request> for IdempotencyService<S>
30where
31    S: Service<Request, Response = Response, Error = std::convert::Infallible>
32        + Clone
33        + Send
34        + 'static,
35    S::Future: Send + 'static,
36{
37    type Response = Response;
38    type Error = std::convert::Infallible;
39    type Future =
40        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Self::Error>> + Send>>;
41
42    fn poll_ready(
43        &mut self,
44        cx: &mut std::task::Context<'_>,
45    ) -> std::task::Poll<Result<(), Self::Error>> {
46        self.inner.poll_ready(cx)
47    }
48
49    fn call(&mut self, req: Request) -> Self::Future {
50        let mut inner = self.inner.clone();
51        let store = self.store.clone();
52        let ttl = self.ttl;
53        let principal_fp = self.principal_fingerprint.clone();
54        Box::pin(async move {
55            let method = req.method().clone();
56            if !is_idempotent_target_method(&method) {
57                return inner.call(req).await;
58            }
59            let key = match parse_idempotency_key(req.headers()) {
60                Ok(Some(k)) => k,
61                Ok(None) => return inner.call(req).await,
62                Err(error) => return Ok(error_response(error)),
63            };
64            let principal = (principal_fp)(&req);
65            // Hash the full path + query string. Skipping the query
66            // makes `POST /transfer?dry_run=true` collide with
67            // `POST /transfer?dry_run=false` under the same key, so a
68            // dry-run preview would replay the live execution's
69            // response (or vice versa). Banks routinely encode
70            // operation modifiers like `?confirm=true` or
71            // `?settlement=instant` in the query string — those must
72            // produce distinct idempotency hashes.
73            let path = req
74                .uri()
75                .path_and_query()
76                .map(|pq| pq.as_str().to_owned())
77                .unwrap_or_else(|| req.uri().path().to_owned());
78            let content_type = req
79                .headers()
80                .get(header::CONTENT_TYPE)
81                .and_then(|v| v.to_str().ok())
82                .map(|s| s.to_owned());
83
84            // Buffer the request body so we can both hash it and replay
85            // it into the inner handler.
86            let (parts, body) = req.into_parts();
87            let bytes = match axum::body::to_bytes(body, MAX_BODY_BYTES).await {
88                Ok(b) => b,
89                Err(_) => {
90                    return Ok(error_response(CoolError::BadRequest(
91                        "request body exceeds idempotency buffer limit".to_owned(),
92                    )));
93                }
94            };
95            let hash = hash_request(&method, &path, content_type.as_deref(), &bytes);
96
97            // Atomic reservation: exactly one caller gets `Reserved`,
98            // and only then do we let the handler run. Concurrent
99            // callers with the same key + same hash see `InFlight`;
100            // different-body conflicts see `Conflict`. This is the
101            // banking-grade duplicate-execution guarantee that the
102            // previous fetch-then-put pattern could not provide.
103            let expires_at = SystemTime::now() + ttl;
104            let outcome = match store
105                .reserve_or_fetch(&principal, &key, hash, expires_at)
106                .await
107            {
108                Ok(outcome) => outcome,
109                Err(error) => return Ok(error_response(error)),
110            };
111
112            let token = match outcome {
113                ReservationOutcome::Replay(record) => {
114                    return Ok(replay_response(&record));
115                }
116                ReservationOutcome::Conflict => {
117                    return Ok(error_response(CoolError::Validation(
118                        "idempotency_key_conflict: key reused with a different request body"
119                            .to_owned(),
120                    )));
121                }
122                ReservationOutcome::InFlight => {
123                    return Ok(in_flight_response());
124                }
125                ReservationOutcome::Reserved { token } => token,
126            };
127
128            // We hold the reservation. Run the handler.
129            let req2 = Request::from_parts(parts, Body::from(bytes));
130            let response_result = inner.call(req2).await;
131            let response = match response_result {
132                Ok(response) => response,
133                Err(_) => {
134                    // `Service::Error = Infallible` so this branch is
135                    // unreachable in practice. The release-on-error path
136                    // is still here for if/when a fallible inner service
137                    // is plugged in. Guarding on `token` ensures a
138                    // handler whose reservation has already been
139                    // reclaimed (TTL ran out) doesn't drop the new
140                    // owner's row.
141                    let _ = store.release(&principal, &key, token).await;
142                    return Ok(error_response(CoolError::Internal(
143                        "handler returned an unrecoverable error".to_owned(),
144                    )));
145                }
146            };
147            let (rparts, rbody) = response.into_parts();
148            let rbytes = match axum::body::to_bytes(rbody, MAX_BODY_BYTES).await {
149                Ok(b) => b,
150                Err(_) => {
151                    // Drop the reservation so retries can attempt
152                    // again — but only if our token still holds.
153                    let _ = store.release(&principal, &key, token).await;
154                    let mut e = error_response(CoolError::Internal(
155                        "response body exceeded idempotency buffer".to_owned(),
156                    ));
157                    *e.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
158                    return Ok(e);
159                }
160            };
161            // Capture the full header set so the replay reproduces the
162            // original handler's `Location`, `ETag`, cache directives,
163            // `Content-Type`, etc. Hop-by-hop and framework-computed
164            // headers are filtered inside `encode_headers`. Pre-fix
165            // the middleware only persisted `Content-Type`, so a
166            // `201 Created` with a `Location` header replayed as
167            // `201 Created` with no `Location` — different observable
168            // behaviour from the original execution.
169            let headers_blob = encode_headers(&rparts.headers);
170
171            // Persist the completion. Best-effort: on store failure we
172            // still return the live response so the caller observes the
173            // mutation that DID happen; banks running strict mode can
174            // wrap the store in a fail-closed adapter. The `token`
175            // guard means a handler whose reservation got reclaimed
176            // (TTL expired, retry took over) silently fails this
177            // write rather than poisoning the newer reservation's row.
178            let _ = store
179                .complete(
180                    &principal,
181                    &key,
182                    token,
183                    rparts.status.as_u16(),
184                    &headers_blob,
185                    &rbytes,
186                )
187                .await;
188            Ok(Response::from_parts(rparts, Body::from(rbytes)))
189        })
190    }
191}