cratestack_axum/idempotency/service.rs
1//! `IdempotencyService` — the tower `Service` that owns the per-request
2//! state machine (reserve → run → complete/release).
3
4use std::sync::Arc;
5use std::time::{Duration, SystemTime};
6
7use axum::body::Body;
8use axum::extract::Request;
9use axum::response::Response;
10use cratestack_core::CoolError;
11use http::{StatusCode, header};
12use tower::Service;
13
14use super::hash::{hash_request, is_idempotent_target_method};
15use super::headers::encode_headers;
16use super::parse::parse_idempotency_key;
17use super::record::ReservationOutcome;
18use super::responses::{error_response, in_flight_response, replay_response};
19use super::store::{IdempotencyStore, MAX_BODY_BYTES};
20
21#[derive(Clone)]
22pub struct IdempotencyService<S> {
23 pub(super) inner: S,
24 pub(super) store: Arc<dyn IdempotencyStore>,
25 pub(super) ttl: Duration,
26 pub(super) principal_fingerprint: Arc<dyn Fn(&Request) -> String + Send + Sync>,
27}
28
29impl<S> Service<Request> for IdempotencyService<S>
30where
31 S: Service<Request, Response = Response, Error = std::convert::Infallible>
32 + Clone
33 + Send
34 + 'static,
35 S::Future: Send + 'static,
36{
37 type Response = Response;
38 type Error = std::convert::Infallible;
39 type Future =
40 std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Self::Error>> + Send>>;
41
42 fn poll_ready(
43 &mut self,
44 cx: &mut std::task::Context<'_>,
45 ) -> std::task::Poll<Result<(), Self::Error>> {
46 self.inner.poll_ready(cx)
47 }
48
49 fn call(&mut self, req: Request) -> Self::Future {
50 let mut inner = self.inner.clone();
51 let store = self.store.clone();
52 let ttl = self.ttl;
53 let principal_fp = self.principal_fingerprint.clone();
54 Box::pin(async move {
55 let method = req.method().clone();
56 if !is_idempotent_target_method(&method) {
57 return inner.call(req).await;
58 }
59 let key = match parse_idempotency_key(req.headers()) {
60 Ok(Some(k)) => k,
61 Ok(None) => return inner.call(req).await,
62 Err(error) => return Ok(error_response(error)),
63 };
64 let principal = (principal_fp)(&req);
65 // Hash the full path + query string. Skipping the query
66 // makes `POST /transfer?dry_run=true` collide with
67 // `POST /transfer?dry_run=false` under the same key, so a
68 // dry-run preview would replay the live execution's
69 // response (or vice versa). Banks routinely encode
70 // operation modifiers like `?confirm=true` or
71 // `?settlement=instant` in the query string — those must
72 // produce distinct idempotency hashes.
73 let path = req
74 .uri()
75 .path_and_query()
76 .map(|pq| pq.as_str().to_owned())
77 .unwrap_or_else(|| req.uri().path().to_owned());
78 let content_type = req
79 .headers()
80 .get(header::CONTENT_TYPE)
81 .and_then(|v| v.to_str().ok())
82 .map(|s| s.to_owned());
83
84 // Buffer the request body so we can both hash it and replay
85 // it into the inner handler.
86 let (parts, body) = req.into_parts();
87 let bytes = match axum::body::to_bytes(body, MAX_BODY_BYTES).await {
88 Ok(b) => b,
89 Err(_) => {
90 return Ok(error_response(CoolError::BadRequest(
91 "request body exceeds idempotency buffer limit".to_owned(),
92 )));
93 }
94 };
95 let hash = hash_request(&method, &path, content_type.as_deref(), &bytes);
96
97 // Atomic reservation: exactly one caller gets `Reserved`,
98 // and only then do we let the handler run. Concurrent
99 // callers with the same key + same hash see `InFlight`;
100 // different-body conflicts see `Conflict`. This is the
101 // banking-grade duplicate-execution guarantee that the
102 // previous fetch-then-put pattern could not provide.
103 let expires_at = SystemTime::now() + ttl;
104 let outcome = match store
105 .reserve_or_fetch(&principal, &key, hash, expires_at)
106 .await
107 {
108 Ok(outcome) => outcome,
109 Err(error) => return Ok(error_response(error)),
110 };
111
112 let token = match outcome {
113 ReservationOutcome::Replay(record) => {
114 return Ok(replay_response(&record));
115 }
116 ReservationOutcome::Conflict => {
117 return Ok(error_response(CoolError::Validation(
118 "idempotency_key_conflict: key reused with a different request body"
119 .to_owned(),
120 )));
121 }
122 ReservationOutcome::InFlight => {
123 return Ok(in_flight_response());
124 }
125 ReservationOutcome::Reserved { token } => token,
126 };
127
128 // We hold the reservation. Run the handler.
129 let req2 = Request::from_parts(parts, Body::from(bytes));
130 let response_result = inner.call(req2).await;
131 let response = match response_result {
132 Ok(response) => response,
133 Err(_) => {
134 // `Service::Error = Infallible` so this branch is
135 // unreachable in practice. The release-on-error path
136 // is still here for if/when a fallible inner service
137 // is plugged in. Guarding on `token` ensures a
138 // handler whose reservation has already been
139 // reclaimed (TTL ran out) doesn't drop the new
140 // owner's row.
141 let _ = store.release(&principal, &key, token).await;
142 return Ok(error_response(CoolError::Internal(
143 "handler returned an unrecoverable error".to_owned(),
144 )));
145 }
146 };
147 let (rparts, rbody) = response.into_parts();
148 let rbytes = match axum::body::to_bytes(rbody, MAX_BODY_BYTES).await {
149 Ok(b) => b,
150 Err(_) => {
151 // Drop the reservation so retries can attempt
152 // again — but only if our token still holds.
153 let _ = store.release(&principal, &key, token).await;
154 let mut e = error_response(CoolError::Internal(
155 "response body exceeded idempotency buffer".to_owned(),
156 ));
157 *e.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
158 return Ok(e);
159 }
160 };
161 // Capture the full header set so the replay reproduces the
162 // original handler's `Location`, `ETag`, cache directives,
163 // `Content-Type`, etc. Hop-by-hop and framework-computed
164 // headers are filtered inside `encode_headers`. Pre-fix
165 // the middleware only persisted `Content-Type`, so a
166 // `201 Created` with a `Location` header replayed as
167 // `201 Created` with no `Location` — different observable
168 // behaviour from the original execution.
169 let headers_blob = encode_headers(&rparts.headers);
170
171 // Persist the completion. Best-effort: on store failure we
172 // still return the live response so the caller observes the
173 // mutation that DID happen; banks running strict mode can
174 // wrap the store in a fail-closed adapter. The `token`
175 // guard means a handler whose reservation got reclaimed
176 // (TTL expired, retry took over) silently fails this
177 // write rather than poisoning the newer reservation's row.
178 let _ = store
179 .complete(
180 &principal,
181 &key,
182 token,
183 rparts.status.as_u16(),
184 &headers_blob,
185 &rbytes,
186 )
187 .await;
188 Ok(Response::from_parts(rparts, Body::from(rbytes)))
189 })
190 }
191}