Skip to main content

cratestack_sqlx/query/
write.rs

1use cratestack_core::{CoolContext, CoolError, ModelEventKind};
2
3use crate::{
4    CreateModelInput, ModelDescriptor, SqlxRuntime, UpdateModelInput,
5    descriptor::{enqueue_event_outbox, ensure_event_outbox_table},
6};
7
8use super::support::{
9    apply_create_defaults, evaluate_create_policies, push_action_policy_query, push_bind_value,
10};
11
12#[derive(Debug, Clone)]
13pub struct CreateRecord<'a, M: 'static, PK: 'static, I> {
14    pub(crate) runtime: &'a SqlxRuntime,
15    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
16    pub(crate) input: I,
17}
18
19impl<'a, M: 'static, PK: 'static, I> CreateRecord<'a, M, PK, I>
20where
21    I: CreateModelInput<M>,
22{
23    pub fn preview_sql(&self) -> String {
24        let values = self.input.sql_values();
25        let placeholders = (1..=values.len())
26            .map(|index| format!("${index}"))
27            .collect::<Vec<_>>()
28            .join(", ");
29        let columns = values
30            .iter()
31            .map(|value| value.column)
32            .collect::<Vec<_>>()
33            .join(", ");
34
35        format!(
36            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
37            self.descriptor.table_name,
38            columns,
39            placeholders,
40            self.descriptor.select_projection(),
41        )
42    }
43
44    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
45    where
46        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
47    {
48        let emits_event = self.descriptor.emits(ModelEventKind::Created);
49        let record = if emits_event {
50            let mut tx = self
51                .runtime
52                .pool()
53                .begin()
54                .await
55                .map_err(|error| CoolError::Database(error.to_string()))?;
56            ensure_event_outbox_table(&mut *tx).await?;
57            let record = create_record_with_executor(
58                &mut *tx,
59                self.runtime.pool(),
60                self.descriptor,
61                self.input,
62                ctx,
63            )
64            .await?;
65            enqueue_event_outbox(
66                &mut *tx,
67                self.descriptor.schema_name,
68                ModelEventKind::Created,
69                &record,
70            )
71            .await?;
72            tx.commit()
73                .await
74                .map_err(|error| CoolError::Database(error.to_string()))?;
75            record
76        } else {
77            create_record_with_executor(
78                self.runtime.pool(),
79                self.runtime.pool(),
80                self.descriptor,
81                self.input,
82                ctx,
83            )
84            .await?
85        };
86
87        if emits_event {
88            let _ = self.runtime.drain_event_outbox().await;
89        }
90
91        Ok(record)
92    }
93}
94
95#[derive(Debug, Clone)]
96pub struct UpdateRecord<'a, M: 'static, PK: 'static> {
97    pub(crate) runtime: &'a SqlxRuntime,
98    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
99    pub(crate) id: PK,
100}
101
102impl<'a, M: 'static, PK: 'static> UpdateRecord<'a, M, PK> {
103    pub fn set<I>(self, input: I) -> UpdateRecordSet<'a, M, PK, I> {
104        UpdateRecordSet {
105            runtime: self.runtime,
106            descriptor: self.descriptor,
107            id: self.id,
108            input,
109        }
110    }
111}
112
113#[derive(Debug, Clone)]
114pub struct UpdateRecordSet<'a, M: 'static, PK: 'static, I> {
115    pub(crate) runtime: &'a SqlxRuntime,
116    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
117    pub(crate) id: PK,
118    pub(crate) input: I,
119}
120
121impl<'a, M: 'static, PK: 'static, I> UpdateRecordSet<'a, M, PK, I>
122where
123    I: UpdateModelInput<M>,
124{
125    pub fn preview_sql(&self) -> String {
126        let values = self.input.sql_values();
127        let assignments = values
128            .iter()
129            .enumerate()
130            .map(|(index, value)| format!("{} = ${}", value.column, index + 1))
131            .collect::<Vec<_>>()
132            .join(", ");
133
134        format!(
135            "UPDATE {} SET {} WHERE {} = ${} RETURNING {}",
136            self.descriptor.table_name,
137            assignments,
138            self.descriptor.primary_key,
139            values.len() + 1,
140            self.descriptor.select_projection(),
141        )
142    }
143
144    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
145    where
146        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
147        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
148    {
149        let emits_event = self.descriptor.emits(ModelEventKind::Updated);
150        let record = if emits_event {
151            let mut tx = self
152                .runtime
153                .pool()
154                .begin()
155                .await
156                .map_err(|error| CoolError::Database(error.to_string()))?;
157            ensure_event_outbox_table(&mut *tx).await?;
158            let record =
159                update_record_with_executor(&mut *tx, self.descriptor, self.id, self.input, ctx)
160                    .await?;
161            enqueue_event_outbox(
162                &mut *tx,
163                self.descriptor.schema_name,
164                ModelEventKind::Updated,
165                &record,
166            )
167            .await?;
168            tx.commit()
169                .await
170                .map_err(|error| CoolError::Database(error.to_string()))?;
171            record
172        } else {
173            update_record_with_executor(
174                self.runtime.pool(),
175                self.descriptor,
176                self.id,
177                self.input,
178                ctx,
179            )
180            .await?
181        };
182
183        if emits_event {
184            let _ = self.runtime.drain_event_outbox().await;
185        }
186
187        Ok(record)
188    }
189}
190
191#[derive(Debug, Clone)]
192pub struct DeleteRecord<'a, M: 'static, PK: 'static> {
193    pub(crate) runtime: &'a SqlxRuntime,
194    pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
195    pub(crate) id: PK,
196}
197
198impl<'a, M: 'static, PK: 'static> DeleteRecord<'a, M, PK> {
199    pub fn preview_sql(&self) -> String {
200        format!(
201            "DELETE FROM {} WHERE {} = $1 RETURNING {}",
202            self.descriptor.table_name,
203            self.descriptor.primary_key,
204            self.descriptor.select_projection(),
205        )
206    }
207
208    pub async fn run(self, ctx: &CoolContext) -> Result<M, CoolError>
209    where
210        for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
211        PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
212    {
213        let emits_event = self.descriptor.emits(ModelEventKind::Deleted);
214        let record = if emits_event {
215            let mut tx = self
216                .runtime
217                .pool()
218                .begin()
219                .await
220                .map_err(|error| CoolError::Database(error.to_string()))?;
221            ensure_event_outbox_table(&mut *tx).await?;
222
223            let record = delete_returning_record(&mut *tx, self.descriptor, self.id, ctx).await?;
224            enqueue_event_outbox(
225                &mut *tx,
226                self.descriptor.schema_name,
227                ModelEventKind::Deleted,
228                &record,
229            )
230            .await?;
231            tx.commit()
232                .await
233                .map_err(|error| CoolError::Database(error.to_string()))?;
234            record
235        } else {
236            delete_returning_record(self.runtime.pool(), self.descriptor, self.id, ctx).await?
237        };
238
239        if emits_event {
240            let _ = self.runtime.drain_event_outbox().await;
241        }
242
243        Ok(record)
244    }
245}
246
247pub async fn create_record_with_executor<'e, E, M, PK, I>(
248    executor: E,
249    policy_pool: &sqlx::PgPool,
250    descriptor: &'static ModelDescriptor<M, PK>,
251    input: I,
252    ctx: &CoolContext,
253) -> Result<M, CoolError>
254where
255    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
256    I: CreateModelInput<M>,
257    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
258{
259    let values = apply_create_defaults(input.sql_values(), descriptor.create_defaults, ctx)?;
260    if values.is_empty() {
261        return Err(CoolError::Validation(
262            "create input must contain at least one column".to_owned(),
263        ));
264    }
265    if !evaluate_create_policies(
266        policy_pool,
267        descriptor.create_allow_policies,
268        descriptor.create_deny_policies,
269        &values,
270        ctx,
271    )
272    .await?
273    {
274        return Err(CoolError::Forbidden(
275            "create policy denied this operation".to_owned(),
276        ));
277    }
278
279    insert_returning_record(executor, descriptor, &values).await
280}
281
282pub async fn update_record_with_executor<'e, E, M, PK, I>(
283    executor: E,
284    descriptor: &'static ModelDescriptor<M, PK>,
285    id: PK,
286    input: I,
287    ctx: &CoolContext,
288) -> Result<M, CoolError>
289where
290    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
291    I: UpdateModelInput<M>,
292    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + serde::Serialize,
293    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
294{
295    let values = input.sql_values();
296    if values.is_empty() {
297        return Err(CoolError::Validation(
298            "update input must contain at least one changed column".to_owned(),
299        ));
300    }
301
302    update_returning_record(executor, descriptor, id, &values, ctx).await
303}
304
305async fn insert_returning_record<'e, E, M, PK>(
306    executor: E,
307    descriptor: &'static ModelDescriptor<M, PK>,
308    values: &[crate::SqlColumnValue],
309) -> Result<M, CoolError>
310where
311    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
312    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
313{
314    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("INSERT INTO ");
315    query.push(descriptor.table_name).push(" (");
316    for (index, value) in values.iter().enumerate() {
317        if index > 0 {
318            query.push(", ");
319        }
320        query.push(value.column);
321    }
322    query.push(") VALUES (");
323    for (index, value) in values.iter().enumerate() {
324        if index > 0 {
325            query.push(", ");
326        }
327        push_bind_value(&mut query, &value.value);
328    }
329    query
330        .push(") RETURNING ")
331        .push(descriptor.select_projection());
332
333    query
334        .build_query_as::<M>()
335        .fetch_one(executor)
336        .await
337        .map_err(|error| CoolError::Database(error.to_string()))
338}
339
340async fn update_returning_record<'e, E, M, PK>(
341    executor: E,
342    descriptor: &'static ModelDescriptor<M, PK>,
343    id: PK,
344    values: &[crate::SqlColumnValue],
345    ctx: &CoolContext,
346) -> Result<M, CoolError>
347where
348    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
349    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
350    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
351{
352    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("UPDATE ");
353    query.push(descriptor.table_name).push(" SET ");
354    for (index, value) in values.iter().enumerate() {
355        if index > 0 {
356            query.push(", ");
357        }
358        query.push(value.column).push(" = ");
359        push_bind_value(&mut query, &value.value);
360    }
361    query
362        .push(" WHERE ")
363        .push(descriptor.primary_key)
364        .push(" = ");
365    query.push_bind(id);
366    query.push(" AND ");
367    push_action_policy_query(
368        &mut query,
369        descriptor.update_allow_policies,
370        descriptor.update_deny_policies,
371        ctx,
372    );
373    query
374        .push(" RETURNING ")
375        .push(descriptor.select_projection());
376
377    query
378        .build_query_as::<M>()
379        .fetch_optional(executor)
380        .await
381        .map_err(|error| CoolError::Database(error.to_string()))?
382        .ok_or_else(|| CoolError::Forbidden("update policy denied this operation".to_owned()))
383}
384
385async fn delete_returning_record<'e, E, M, PK>(
386    executor: E,
387    descriptor: &'static ModelDescriptor<M, PK>,
388    id: PK,
389    ctx: &CoolContext,
390) -> Result<M, CoolError>
391where
392    E: sqlx::Executor<'e, Database = sqlx::Postgres>,
393    for<'r> M: Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow>,
394    PK: Send + sqlx::Type<sqlx::Postgres> + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
395{
396    let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("DELETE FROM ");
397    query.push(descriptor.table_name).push(" WHERE ");
398    query.push(descriptor.primary_key).push(" = ");
399    query.push_bind(id);
400    query.push(" AND ");
401    push_action_policy_query(
402        &mut query,
403        descriptor.delete_allow_policies,
404        descriptor.delete_deny_policies,
405        ctx,
406    );
407    query
408        .push(" RETURNING ")
409        .push(descriptor.select_projection());
410
411    query
412        .build_query_as::<M>()
413        .fetch_optional(executor)
414        .await
415        .map_err(|error| CoolError::Database(error.to_string()))?
416        .ok_or_else(|| CoolError::Forbidden("delete policy denied this operation".to_owned()))
417}