cratestack_sqlx/
isolation.rs1use crate::sqlx;
10
11use std::future::Future;
12
13use cratestack_core::{CoolError, TransactionIsolation};
14
15use crate::error::cool_error_from_sqlx;
16
/// Retries `run_in_isolated_tx` allows after the first attempt, so the
/// transaction body runs at most `MAX_RETRIES_DEFAULT + 1` times.
const MAX_RETRIES_DEFAULT: u32 = 3;

/// Postgres SQLSTATE `serialization_failure`.
const PG_SERIALIZATION_FAILURE_SQLSTATE: &str = "40001";
/// Postgres SQLSTATE `deadlock_detected`.
const PG_DEADLOCK_DETECTED_SQLSTATE: &str = "40P01";
20
21pub async fn run_in_isolated_tx<F, Fut, T>(
31 pool: &sqlx::PgPool,
32 isolation: TransactionIsolation,
33 body: F,
34) -> Result<T, CoolError>
35where
36 F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
37 Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
38{
39 run_in_isolated_tx_with_retries(pool, isolation, MAX_RETRIES_DEFAULT, body).await
40}
41
42pub async fn run_in_isolated_tx_with_retries<F, Fut, T>(
46 pool: &sqlx::PgPool,
47 isolation: TransactionIsolation,
48 max_retries: u32,
49 mut body: F,
50) -> Result<T, CoolError>
51where
52 F: FnMut(sqlx::Transaction<'static, sqlx::Postgres>) -> Fut,
53 Fut: Future<Output = Result<(T, sqlx::Transaction<'static, sqlx::Postgres>), CoolError>>,
54{
55 let mut attempts = 0u32;
56 loop {
57 attempts += 1;
58 let mut tx = pool.begin().await.map_err(cool_error_from_sqlx)?;
59 let set_stmt = format!("SET TRANSACTION ISOLATION LEVEL {}", isolation.as_sql());
60 sqlx::query(&set_stmt)
61 .execute(&mut *tx)
62 .await
63 .map_err(cool_error_from_sqlx)?;
64
65 match body(tx).await {
66 Ok((value, tx)) => match tx.commit().await {
67 Ok(()) => return Ok(value),
68 Err(commit_error) => {
69 let promoted = cool_error_from_sqlx(commit_error);
79 if attempts <= max_retries && is_retriable(&promoted) {
80 tokio::task::yield_now().await;
81 continue;
82 }
83 return Err(promoted);
84 }
85 },
86 Err(error) => {
87 if attempts <= max_retries && is_retriable(&error) {
88 tokio::task::yield_now().await;
93 continue;
94 }
95 return Err(error);
96 }
97 }
98 }
99}
100
101fn is_retriable(error: &CoolError) -> bool {
102 if let Some(code) = error.db_sqlstate()
108 && (code == PG_SERIALIZATION_FAILURE_SQLSTATE || code == PG_DEADLOCK_DETECTED_SQLSTATE)
109 {
110 return true;
111 }
112 let detail = error.detail().unwrap_or_default();
115 detail.contains(PG_SERIALIZATION_FAILURE_SQLSTATE)
116 || detail.contains(PG_DEADLOCK_DETECTED_SQLSTATE)
117 || detail.contains("could not serialize access")
118 || detail.contains("deadlock detected")
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use cratestack_core::DbErrorInfo;

    /// Wraps a raw driver error string in the untyped `CoolError::Database` variant.
    fn untyped(raw: &str) -> CoolError {
        CoolError::Database(raw.to_owned())
    }

    /// Builds a `CoolError::DatabaseTyped` from a detail, SQLSTATE and optional constraint.
    fn typed(detail: &str, sqlstate: &str, constraint: Option<&str>) -> CoolError {
        CoolError::DatabaseTyped(DbErrorInfo {
            detail: detail.to_owned(),
            sqlstate: Some(sqlstate.to_owned()),
            constraint: constraint.map(str::to_owned),
        })
    }

    #[test]
    fn parses_all_isolation_levels() {
        let accepted = [
            ("serializable", TransactionIsolation::Serializable),
            ("Repeatable_Read", TransactionIsolation::RepeatableRead),
            ("read committed", TransactionIsolation::ReadCommitted),
        ];
        for (input, expected) in accepted {
            assert_eq!(TransactionIsolation::parse(input).unwrap(), expected);
        }
        assert!(TransactionIsolation::parse("snapshot").is_err());
    }

    #[test]
    fn sql_strings_match_pg_grammar() {
        let expectations = [
            (TransactionIsolation::Serializable, "SERIALIZABLE"),
            (TransactionIsolation::RepeatableRead, "REPEATABLE READ"),
            (TransactionIsolation::ReadCommitted, "READ COMMITTED"),
        ];
        for (level, sql) in expectations {
            assert_eq!(level.as_sql(), sql);
        }
    }

    #[test]
    fn retriable_on_serialization_failure_sqlstate() {
        let err = untyped(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to concurrent update\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn retriable_on_deadlock_sqlstate() {
        let err = untyped(
            "Database(PgDatabaseError { code: \"40P01\", \
             message: \"deadlock detected\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn not_retriable_on_unique_violation() {
        let err = untyped("duplicate key value violates unique constraint \"accounts_pkey\"");
        assert!(!is_retriable(&err));
    }

    #[test]
    fn retriable_when_serialization_failure_is_raised_at_commit_time() {
        // Under SERIALIZABLE, Postgres may only detect the conflict at COMMIT.
        let err = untyped(
            "Database(PgDatabaseError { severity: ERROR, code: \"40001\", \
             message: \"could not serialize access due to read/write dependencies among transactions\" })",
        );
        assert!(is_retriable(&err));
    }

    #[test]
    fn retriable_typed_serialization_failure() {
        let err = typed(
            "could not serialize access due to concurrent update",
            "40001",
            None,
        );
        assert!(
            is_retriable(&err),
            "DatabaseTyped with 40001 sqlstate must be retriable via the fast path",
        );
    }

    #[test]
    fn retriable_typed_deadlock() {
        let err = typed("deadlock detected", "40P01", None);
        assert!(
            is_retriable(&err),
            "DatabaseTyped with 40P01 sqlstate must be retriable via the fast path",
        );
    }

    #[test]
    fn not_retriable_typed_unique_violation() {
        let err = typed(
            "duplicate key value violates unique constraint \"accounts_pkey\"",
            "23505",
            Some("accounts_pkey"),
        );
        assert!(
            !is_retriable(&err),
            "unique_violation (23505) must not be retried",
        );
    }

    #[test]
    fn typed_variant_with_unknown_sqlstate_falls_through_to_detail_match() {
        let err = typed(
            "could not serialize access due to read/write dependencies",
            "XX999",
            None,
        );
        assert!(
            is_retriable(&err),
            "unknown sqlstate must fall through to detail-substring fallback",
        );
    }

    #[test]
    fn typed_variant_exposes_constraint_for_unique_violation() {
        let err = typed(
            "duplicate key value violates unique constraint \"wallets_owner_key\"",
            "23505",
            Some("wallets_owner_key"),
        );
        assert_eq!(err.db_sqlstate(), Some("23505"));
        assert_eq!(err.db_constraint(), Some("wallets_owner_key"));
        assert_eq!(err.public_message(), "internal error");
    }
}