Skip to main content

cratestack_sqlx/
isolation.rs

1//! Helpers for running banking-grade multi-row mutations under explicit
2//! transaction isolation, with retry on serialization failure.
3//!
4//! Procedures opt in via `@isolation("serializable")` in the schema; the
5//! macro records the requested level on a `ProcedureMetadata` const and
6//! handler code can wrap its body in [`run_in_isolated_tx`] to actually
7//! enforce it. A follow-up will auto-wrap procedure dispatch so opting in
8//! requires only the attribute.
9use crate::sqlx;
10
11use std::future::Future;
12
13use cratestack_core::{CoolError, TransactionIsolation};
14
15use crate::error::cool_error_from_sqlx;
16
/// Default retry budget used by [`run_in_isolated_tx`]: how many times the
/// body may be re-run after the first attempt fails with a retriable error.
const MAX_RETRIES_DEFAULT: u32 = 3;

/// PostgreSQL SQLSTATE for `serialization_failure`.
const PG_SERIALIZATION_FAILURE_SQLSTATE: &str = "40001";
/// PostgreSQL SQLSTATE for `deadlock_detected`.
const PG_DEADLOCK_DETECTED_SQLSTATE: &str = "40P01";
20
21/// Begin a transaction at the requested isolation level, run `body` against
22/// the live transaction, and commit. On `40001` (serialization_failure) or
23/// `40P01` (deadlock_detected) the transaction is rolled back and the body
24/// runs again, up to `MAX_RETRIES_DEFAULT` times. Other errors propagate
25/// immediately.
26///
27/// `body` receives a mutable transaction reference; it should run all of
28/// its SQL through that reference so the writes participate in the same
29/// transaction.
30pub async fn run_in_isolated_tx<F, Fut, T>(
31    pool: &sqlx::PgPool,
32    isolation: TransactionIsolation,
33    body: F,
34) -> Result<T, CoolError>
35where
36    F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
37    Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
38{
39    run_in_isolated_tx_with_retries(pool, isolation, MAX_RETRIES_DEFAULT, body).await
40}
41
42/// Same as [`run_in_isolated_tx`] but with a caller-chosen retry budget.
43/// Banks running long-tail contended writes sometimes want a higher cap
44/// (5–10); single-row CAS workflows can drop to 1 to fail fast.
45pub async fn run_in_isolated_tx_with_retries<F, Fut, T>(
46    pool: &sqlx::PgPool,
47    isolation: TransactionIsolation,
48    max_retries: u32,
49    mut body: F,
50) -> Result<T, CoolError>
51where
52    F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
53    Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
54{
55    let mut attempts = 0u32;
56    loop {
57        attempts += 1;
58        let mut tx = pool.begin().await.map_err(cool_error_from_sqlx)?;
59        let set_stmt = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation.as_sql());
60        sqlx::query(&set_stmt)
61            .execute(&mut *tx)
62            .await
63            .map_err(cool_error_from_sqlx)?;
64
65        match body(tx).await {
66            Ok((value, tx)) => match tx.commit().await {
67                Ok(()) => return Ok(value),
68                Err(commit_error) => {
69                    // PG can defer a serialization anomaly all the way to
70                    // COMMIT: the body's SQL runs cleanly, then the engine
71                    // detects the conflict during the predicate-lock check
72                    // at commit and rolls the transaction back with
73                    // SQLSTATE 40001 (the docs are explicit that the
74                    // *entire* transaction must be retried). Without this
75                    // branch we'd advertise automatic retries but still
76                    // leak a transient 40001 to callers when the conflict
77                    // is detected at the commit boundary.
78                    let promoted = cool_error_from_sqlx(commit_error);
79                    if attempts <= max_retries && is_retriable(&promoted) {
80                        tokio::task::yield_now().await;
81                        continue;
82                    }
83                    return Err(promoted);
84                }
85            },
86            Err(error) => {
87                if attempts <= max_retries && is_retriable(&error) {
88                    // Backoff is intentionally trivial — banks running this
89                    // under heavy contention should swap to a more thoughtful
90                    // jittered backoff. Sub-millisecond pause yields the
91                    // current task without keeping a tx open.
92                    tokio::task::yield_now().await;
93                    continue;
94                }
95                return Err(error);
96            }
97        }
98    }
99}
100
101fn is_retriable(error: &CoolError) -> bool {
102    // Fast path: typed variant surfaces the SQLSTATE directly. Only treat
103    // it as authoritative when the code matches a known retriable state;
104    // an unrecognized SQLSTATE falls through so the substring fallback
105    // still has a chance (drivers may surface retriable conditions in
106    // ways the typed path doesn't capture).
107    if let Some(code) = error.db_sqlstate()
108        && (code == PG_SERIALIZATION_FAILURE_SQLSTATE || code == PG_DEADLOCK_DETECTED_SQLSTATE)
109    {
110        return true;
111    }
112    // Fallback: legacy `Database(String)` variant — substring-match the detail
113    // string the way the original code did, so existing behaviour is preserved.
114    let detail = error.detail().unwrap_or_default();
115    detail.contains(PG_SERIALIZATION_FAILURE_SQLSTATE)
116        || detail.contains(PG_DEADLOCK_DETECTED_SQLSTATE)
117        || detail.contains("could not serialize access")
118        || detail.contains("deadlock detected")
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use cratestack_core::DbErrorInfo;

    /// Wrap `detail` in the legacy, string-only `CoolError::Database` variant.
    fn legacy_db_error(detail: &str) -> CoolError {
        CoolError::Database(detail.to_owned())
    }

    /// Build a `CoolError::DatabaseTyped` from a detail message plus an
    /// optional SQLSTATE and constraint name.
    fn typed_db_error(
        detail: &str,
        sqlstate: Option<&str>,
        constraint: Option<&str>,
    ) -> CoolError {
        CoolError::DatabaseTyped(DbErrorInfo {
            detail: detail.to_owned(),
            sqlstate: sqlstate.map(str::to_owned),
            constraint: constraint.map(str::to_owned),
        })
    }

    #[test]
    fn parses_all_isolation_levels() {
        let accepted = [
            ("serializable", TransactionIsolation::Serializable),
            ("Repeatable_Read", TransactionIsolation::RepeatableRead),
            ("read committed", TransactionIsolation::ReadCommitted),
        ];
        for (spelling, expected) in accepted {
            assert_eq!(TransactionIsolation::parse(spelling).unwrap(), expected);
        }
        assert!(TransactionIsolation::parse("snapshot").is_err());
    }

    #[test]
    fn sql_strings_match_pg_grammar() {
        let grammar = [
            (TransactionIsolation::Serializable, "SERIALIZABLE"),
            (TransactionIsolation::RepeatableRead, "REPEATABLE READ"),
            (TransactionIsolation::ReadCommitted, "READ COMMITTED"),
        ];
        for (level, sql) in grammar {
            assert_eq!(level.as_sql(), sql);
        }
    }

    #[test]
    fn retriable_on_serialization_failure_sqlstate() {
        let err = legacy_db_error(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to concurrent update\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn retriable_on_deadlock_sqlstate() {
        let err = legacy_db_error(
            "Database(PgDatabaseError { code: \"40P01\", \
             message: \"deadlock detected\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn not_retriable_on_unique_violation() {
        let err = legacy_db_error(
            "duplicate key value violates unique constraint \"accounts_pkey\"",
        );
        assert!(!is_retriable(&err));
    }

    #[test]
    fn retriable_when_serialization_failure_is_raised_at_commit_time() {
        // PG SSI can postpone the 40001 until COMMIT. The sqlx error from
        // `tx.commit()` carries the same SQLSTATE. The retry loop promotes it
        // into a `CoolError` and runs it through `is_retriable`, so a
        // commit-time conflict is retried rather than leaked to the caller.
        let err = legacy_db_error(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to read/write dependencies among transactions\" })",
        );
        assert!(is_retriable(&err));
    }

    // --- typed-variant paths ---

    #[test]
    fn retriable_typed_serialization_failure() {
        let err = typed_db_error(
            "could not serialize access due to concurrent update",
            Some("40001"),
            None,
        );
        assert!(
            is_retriable(&err),
            "DatabaseTyped with 40001 sqlstate must be retriable via the fast path",
        );
    }

    #[test]
    fn retriable_typed_deadlock() {
        let err = typed_db_error("deadlock detected", Some("40P01"), None);
        assert!(
            is_retriable(&err),
            "DatabaseTyped with 40P01 sqlstate must be retriable via the fast path",
        );
    }

    #[test]
    fn not_retriable_typed_unique_violation() {
        let err = typed_db_error(
            "duplicate key value violates unique constraint \"accounts_pkey\"",
            Some("23505"),
            Some("accounts_pkey"),
        );
        assert!(
            !is_retriable(&err),
            "unique_violation (23505) must not be retried",
        );
    }

    #[test]
    fn typed_variant_with_unknown_sqlstate_falls_through_to_detail_match() {
        // The driver reports an unfamiliar SQLSTATE, but the detail still
        // carries a known retriable phrase. The typed fast path must not
        // short-circuit to `false`; the detail fallback has to get its turn.
        let err = typed_db_error(
            "could not serialize access due to read/write dependencies",
            Some("XX999"),
            None,
        );
        assert!(
            is_retriable(&err),
            "unknown sqlstate must fall through to detail-substring fallback",
        );
    }

    #[test]
    fn typed_variant_exposes_constraint_for_unique_violation() {
        let err = typed_db_error(
            "duplicate key value violates unique constraint \"wallets_owner_key\"",
            Some("23505"),
            Some("wallets_owner_key"),
        );
        assert_eq!(err.db_sqlstate(), Some("23505"));
        assert_eq!(err.db_constraint(), Some("wallets_owner_key"));
        // The public message stays canned so no database detail leaks out.
        assert_eq!(err.public_message(), "internal error");
    }
}