Skip to main content

cratestack_policy/
lib.rs

1use cratestack_core::{CoolContext, CoolError, Value};
/// Quantifier carried by [`ReadPredicate::Relation`] for the rows reached through a relation.
///
/// This module only declares the quantifier; nothing here evaluates it. The per-variant notes
/// below are the meaning suggested by the names and should be confirmed against the code that
/// consumes read policies.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RelationQuantifier {
    /// Presumably a to-one relation (a single related row) — TODO confirm.
    ToOne,
    /// Presumably "at least one related row satisfies the nested expression".
    Some,
    /// Presumably "every related row satisfies the nested expression".
    Every,
    /// Presumably "no related row satisfies the nested expression".
    None,
}
9
/// Constant operand compared against auth claims or columns in a [`ReadPredicate`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PolicyLiteral {
    /// A boolean constant.
    Bool(bool),
    /// A signed 64-bit integer constant.
    Int(i64),
    /// A string constant. Borrowing `&'static str` (rather than owning a `String`) is what lets
    /// this type, and every policy type built from it, derive `Copy`.
    String(&'static str),
}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum ReadPredicate {
19    AuthNotNull,
20    AuthIsNull,
21    HasRole {
22        role: &'static str,
23    },
24    InTenant {
25        tenant_id: &'static str,
26    },
27    AuthFieldEqLiteral {
28        auth_field: &'static str,
29        value: PolicyLiteral,
30    },
31    AuthFieldNeLiteral {
32        auth_field: &'static str,
33        value: PolicyLiteral,
34    },
35    FieldIsTrue {
36        column: &'static str,
37    },
38    FieldEqLiteral {
39        column: &'static str,
40        value: PolicyLiteral,
41    },
42    FieldNeLiteral {
43        column: &'static str,
44        value: PolicyLiteral,
45    },
46    FieldEqAuth {
47        column: &'static str,
48        auth_field: &'static str,
49    },
50    FieldNeAuth {
51        column: &'static str,
52        auth_field: &'static str,
53    },
54    Relation {
55        quantifier: RelationQuantifier,
56        parent_table: &'static str,
57        parent_column: &'static str,
58        related_table: &'static str,
59        related_column: &'static str,
60        expr: &'static PolicyExpr,
61    },
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Eq)]
65pub struct ReadPolicy {
66    pub expr: PolicyExpr,
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum PolicyExpr {
71    Predicate(ReadPredicate),
72    And(&'static [PolicyExpr]),
73    Or(&'static [PolicyExpr]),
74}
75
/// Constant operand in a [`ProcedurePredicate`].
///
/// A literal only ever matches a `Value` of the same kind (`Bool` ↔ `Value::Bool`,
/// `Int` ↔ `Value::Int`, `String` ↔ `Value::String`); there is no cross-type coercion.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ProcedurePolicyLiteral {
    /// A boolean constant.
    Bool(bool),
    /// A signed 64-bit integer constant.
    Int(i64),
    /// A string constant, borrowed for `'static` so the type can be `Copy`.
    String(&'static str),
}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq)]
84pub enum ProcedurePredicate {
85    AuthNotNull,
86    AuthIsNull,
87    HasRole {
88        role: &'static str,
89    },
90    InTenant {
91        tenant_id: &'static str,
92    },
93    AuthFieldEqLiteral {
94        auth_field: &'static str,
95        value: ProcedurePolicyLiteral,
96    },
97    AuthFieldNeLiteral {
98        auth_field: &'static str,
99        value: ProcedurePolicyLiteral,
100    },
101    InputFieldIsTrue {
102        field: &'static str,
103    },
104    InputFieldEqLiteral {
105        field: &'static str,
106        value: ProcedurePolicyLiteral,
107    },
108    InputFieldNeLiteral {
109        field: &'static str,
110        value: ProcedurePolicyLiteral,
111    },
112    InputFieldEqAuth {
113        field: &'static str,
114        auth_field: &'static str,
115    },
116    InputFieldNeAuth {
117        field: &'static str,
118        auth_field: &'static str,
119    },
120    InputFieldEqInput {
121        field: &'static str,
122        other_field: &'static str,
123    },
124    InputFieldNeInput {
125        field: &'static str,
126        other_field: &'static str,
127    },
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq)]
131pub struct ProcedurePolicy {
132    pub expr: ProcedurePolicyExpr,
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub enum ProcedurePolicyExpr {
137    Predicate(ProcedurePredicate),
138    And(&'static [ProcedurePolicyExpr]),
139    Or(&'static [ProcedurePolicyExpr]),
140}
141
142pub trait ProcedureArgs {
143    fn procedure_arg_value(&self, field: &str) -> Option<Value>;
144}
145
146impl ProcedureArgs for () {
147    fn procedure_arg_value(&self, _field: &str) -> Option<Value> {
148        None
149    }
150}
151
152pub fn authorize_procedure<A: ProcedureArgs + ?Sized>(
153    allow_policies: &[ProcedurePolicy],
154    deny_policies: &[ProcedurePolicy],
155    args: &A,
156    ctx: &CoolContext,
157) -> Result<(), CoolError> {
158    if allow_policies.is_empty() {
159        return Err(CoolError::Forbidden(
160            "procedure policy denied this operation".to_owned(),
161        ));
162    }
163
164    if deny_policies
165        .iter()
166        .any(|policy| procedure_policy_expr_matches(policy.expr, args, ctx))
167    {
168        return Err(CoolError::Forbidden(
169            "procedure policy denied this operation".to_owned(),
170        ));
171    }
172
173    if allow_policies
174        .iter()
175        .any(|policy| procedure_policy_expr_matches(policy.expr, args, ctx))
176    {
177        Ok(())
178    } else {
179        Err(CoolError::Forbidden(
180            "procedure policy denied this operation".to_owned(),
181        ))
182    }
183}
184
185pub fn context_has_role(ctx: &CoolContext, role: &str) -> bool {
186    ctx.auth_field("role")
187        .or_else(|| ctx.auth_field("actor.role"))
188        .is_some_and(|value| matches!(value, Value::String(candidate) if candidate == role))
189}
190
191pub fn context_in_tenant(ctx: &CoolContext, tenant_id: &str) -> bool {
192    ctx.auth_field("tenant.id")
193        .is_some_and(|value| matches!(value, Value::String(candidate) if candidate == tenant_id))
194}
195
196fn procedure_policy_expr_matches<A: ProcedureArgs + ?Sized>(
197    expr: ProcedurePolicyExpr,
198    args: &A,
199    ctx: &CoolContext,
200) -> bool {
201    match expr {
202        ProcedurePolicyExpr::Predicate(predicate) => {
203            procedure_predicate_matches(predicate, args, ctx)
204        }
205        ProcedurePolicyExpr::And(exprs) => exprs
206            .iter()
207            .copied()
208            .all(|expr| procedure_policy_expr_matches(expr, args, ctx)),
209        ProcedurePolicyExpr::Or(exprs) => exprs
210            .iter()
211            .copied()
212            .any(|expr| procedure_policy_expr_matches(expr, args, ctx)),
213    }
214}
215
216fn procedure_predicate_matches<A: ProcedureArgs + ?Sized>(
217    predicate: ProcedurePredicate,
218    args: &A,
219    ctx: &CoolContext,
220) -> bool {
221    match predicate {
222        ProcedurePredicate::AuthNotNull => ctx.is_authenticated(),
223        ProcedurePredicate::AuthIsNull => !ctx.is_authenticated(),
224        ProcedurePredicate::HasRole { role } => context_has_role(ctx, role),
225        ProcedurePredicate::InTenant { tenant_id } => context_in_tenant(ctx, tenant_id),
226        ProcedurePredicate::AuthFieldEqLiteral { auth_field, value } => ctx
227            .auth_field(auth_field)
228            .is_some_and(|candidate| value_matches_literal(candidate, value)),
229        ProcedurePredicate::AuthFieldNeLiteral { auth_field, value } => ctx
230            .auth_field(auth_field)
231            .is_some_and(|candidate| !value_matches_literal(candidate, value)),
232        ProcedurePredicate::InputFieldIsTrue { field } => args
233            .procedure_arg_value(field)
234            .is_some_and(|value| value == Value::Bool(true)),
235        ProcedurePredicate::InputFieldEqLiteral { field, value } => args
236            .procedure_arg_value(field)
237            .is_some_and(|candidate| value_matches_literal(&candidate, value)),
238        ProcedurePredicate::InputFieldNeLiteral { field, value } => args
239            .procedure_arg_value(field)
240            .is_some_and(|candidate| !value_matches_literal(&candidate, value)),
241        ProcedurePredicate::InputFieldEqAuth { field, auth_field } => {
242            match (args.procedure_arg_value(field), ctx.auth_field(auth_field)) {
243                (Some(left), Some(right)) => &left == right,
244                _ => false,
245            }
246        }
247        ProcedurePredicate::InputFieldNeAuth { field, auth_field } => {
248            match (args.procedure_arg_value(field), ctx.auth_field(auth_field)) {
249                (Some(left), Some(right)) => &left != right,
250                _ => false,
251            }
252        }
253        ProcedurePredicate::InputFieldEqInput { field, other_field } => {
254            match (
255                args.procedure_arg_value(field),
256                args.procedure_arg_value(other_field),
257            ) {
258                (Some(left), Some(right)) => left == right,
259                _ => false,
260            }
261        }
262        ProcedurePredicate::InputFieldNeInput { field, other_field } => {
263            match (
264                args.procedure_arg_value(field),
265                args.procedure_arg_value(other_field),
266            ) {
267                (Some(left), Some(right)) => left != right,
268                _ => false,
269            }
270        }
271    }
272}
273
274fn value_matches_literal(value: &Value, literal: ProcedurePolicyLiteral) -> bool {
275    match (value, literal) {
276        (Value::Bool(left), ProcedurePolicyLiteral::Bool(right)) => *left == right,
277        (Value::Int(left), ProcedurePolicyLiteral::Int(right)) => *left == right,
278        (Value::String(left), ProcedurePolicyLiteral::String(right)) => left == right,
279        _ => false,
280    }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Allow rule: the caller must hold role `admin`.
    const ALLOW_ADMIN: &[ProcedurePolicy] = &[ProcedurePolicy {
        expr: ProcedurePolicyExpr::Predicate(ProcedurePredicate::HasRole { role: "admin" }),
    }];

    /// Deny rule: matches whenever the `amount` input is the integer 10.
    const DENY_AMOUNT_TEN: &[ProcedurePolicy] = &[ProcedurePolicy {
        expr: ProcedurePolicyExpr::Predicate(ProcedurePredicate::InputFieldEqLiteral {
            field: "amount",
            value: ProcedurePolicyLiteral::Int(10),
        }),
    }];

    /// Fixed procedure input used to exercise the input-field predicates.
    struct SampleArgs;

    impl ProcedureArgs for SampleArgs {
        fn procedure_arg_value(&self, field: &str) -> Option<Value> {
            match field {
                "owner_id" | "other_owner_id" => Some(Value::String("user_1".to_owned())),
                "amount" => Some(Value::Int(10)),
                "confirmed" => Some(Value::Bool(true)),
                _ => None,
            }
        }
    }

    /// Authenticated context with top-level `role = "admin"` and `id = "user_1"` claims.
    fn admin_ctx() -> CoolContext {
        CoolContext::authenticated([
            ("role".to_owned(), Value::String("admin".to_owned())),
            ("id".to_owned(), Value::String("user_1".to_owned())),
        ])
    }

    #[test]
    fn has_role_checks_top_level_and_actor_role() {
        let top_level =
            CoolContext::authenticated([("role".to_owned(), Value::String("admin".to_owned()))]);
        assert!(context_has_role(&top_level, "admin"));
        assert!(!context_has_role(&top_level, "member"));

        let actor_role = CoolContext::authenticated([(
            "actor".to_owned(),
            Value::Map(BTreeMap::from([(
                "role".to_owned(),
                Value::String("merchant".to_owned()),
            )])),
        )]);
        assert!(context_has_role(&actor_role, "merchant"));
    }

    #[test]
    fn top_level_role_shadows_actor_role() {
        let ctx = CoolContext::authenticated([
            ("role".to_owned(), Value::String("member".to_owned())),
            (
                "actor".to_owned(),
                Value::Map(BTreeMap::from([(
                    "role".to_owned(),
                    Value::String("admin".to_owned()),
                )])),
            ),
        ]);
        assert!(context_has_role(&ctx, "member"));
        assert!(!context_has_role(&ctx, "admin"));
    }

    #[test]
    fn in_tenant_checks_structured_tenant_id() {
        let ctx = CoolContext::authenticated([(
            "tenant".to_owned(),
            Value::Map(BTreeMap::from([(
                "id".to_owned(),
                Value::String("tenant_1".to_owned()),
            )])),
        )]);
        assert!(context_in_tenant(&ctx, "tenant_1"));
        assert!(!context_in_tenant(&ctx, "tenant_2"));
    }

    #[test]
    fn authorize_procedure_fails_closed_without_allow_policies() {
        let ctx = admin_ctx();
        assert!(matches!(
            authorize_procedure(&[], &[], &(), &ctx),
            Err(CoolError::Forbidden(_))
        ));
    }

    #[test]
    fn authorize_procedure_requires_a_matching_allow_policy() {
        assert!(authorize_procedure(ALLOW_ADMIN, &[], &(), &admin_ctx()).is_ok());

        let member = CoolContext::authenticated([(
            "role".to_owned(),
            Value::String("member".to_owned()),
        )]);
        assert!(matches!(
            authorize_procedure(ALLOW_ADMIN, &[], &(), &member),
            Err(CoolError::Forbidden(_))
        ));
    }

    #[test]
    fn deny_policy_overrides_matching_allow_policy() {
        let ctx = admin_ctx();
        assert!(matches!(
            authorize_procedure(ALLOW_ADMIN, DENY_AMOUNT_TEN, &SampleArgs, &ctx),
            Err(CoolError::Forbidden(_))
        ));
        // Without an `amount` input the deny rule cannot match, so the allow rule decides.
        assert!(authorize_procedure(ALLOW_ADMIN, DENY_AMOUNT_TEN, &(), &ctx).is_ok());
    }

    #[test]
    fn expression_combinators_and_empty_lists() {
        const OWNER_AND_CONFIRMED: ProcedurePolicyExpr = ProcedurePolicyExpr::And(&[
            ProcedurePolicyExpr::Predicate(ProcedurePredicate::InputFieldEqAuth {
                field: "owner_id",
                auth_field: "id",
            }),
            ProcedurePolicyExpr::Predicate(ProcedurePredicate::InputFieldIsTrue {
                field: "confirmed",
            }),
        ]);
        let ctx = admin_ctx();
        assert!(procedure_policy_expr_matches(OWNER_AND_CONFIRMED, &SampleArgs, &ctx));
        assert!(!procedure_policy_expr_matches(OWNER_AND_CONFIRMED, &(), &ctx));

        // Vacuous truth for `And`, vacuous falsity for `Or`.
        assert!(procedure_policy_expr_matches(ProcedurePolicyExpr::And(&[]), &(), &ctx));
        assert!(!procedure_policy_expr_matches(ProcedurePolicyExpr::Or(&[]), &(), &ctx));

        let inputs_equal = ProcedurePredicate::InputFieldEqInput {
            field: "owner_id",
            other_field: "other_owner_id",
        };
        assert!(procedure_predicate_matches(inputs_equal, &SampleArgs, &ctx));
    }

    #[test]
    fn missing_values_never_match_even_for_not_equal() {
        let ctx = admin_ctx();

        let ne_missing_input = ProcedurePredicate::InputFieldNeLiteral {
            field: "missing",
            value: ProcedurePolicyLiteral::Int(1),
        };
        assert!(!procedure_predicate_matches(ne_missing_input, &SampleArgs, &ctx));

        let ne_missing_claim = ProcedurePredicate::AuthFieldNeLiteral {
            auth_field: "missing",
            value: ProcedurePolicyLiteral::Int(1),
        };
        assert!(!procedure_predicate_matches(ne_missing_claim, &SampleArgs, &ctx));
    }

    #[test]
    fn literals_only_match_values_of_the_same_type() {
        assert!(value_matches_literal(&Value::Int(3), ProcedurePolicyLiteral::Int(3)));
        assert!(!value_matches_literal(&Value::Int(1), ProcedurePolicyLiteral::Bool(true)));
        assert!(!value_matches_literal(
            &Value::String("1".to_owned()),
            ProcedurePolicyLiteral::Int(1)
        ));
        assert!(value_matches_literal(
            &Value::String("a".to_owned()),
            ProcedurePolicyLiteral::String("a")
        ));
    }
}