1use cratestack_core::{CoolContext, CoolError, Value};
/// How a [`ReadPredicate::Relation`] quantifies over the rows reached through a relation.
///
/// Only the variant names are visible in this file; the SQL each one lowers to is decided by
/// whatever consumes read policies (NOTE(review): confirm the exact semantics there).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RelationQuantifier {
    /// Singular relation — presumably at most one related row; TODO confirm.
    ToOne,
    /// Presumably: at least one related row satisfies the nested expression.
    Some,
    /// Presumably: every related row satisfies the nested expression.
    Every,
    /// Presumably: no related row satisfies the nested expression.
    None,
}
9
/// A literal operand that a read-policy predicate compares a column or auth field against.
///
/// The string payload is borrowed for `'static`, which keeps the whole type `Copy`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PolicyLiteral {
    /// A boolean literal.
    Bool(bool),
    /// A 64-bit signed integer literal.
    Int(i64),
    /// A string literal.
    String(&'static str),
}
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18pub enum ReadPredicate {
19 AuthNotNull,
20 AuthIsNull,
21 HasRole {
22 role: &'static str,
23 },
24 InTenant {
25 tenant_id: &'static str,
26 },
27 AuthFieldEqLiteral {
28 auth_field: &'static str,
29 value: PolicyLiteral,
30 },
31 AuthFieldNeLiteral {
32 auth_field: &'static str,
33 value: PolicyLiteral,
34 },
35 FieldIsTrue {
36 column: &'static str,
37 },
38 FieldEqLiteral {
39 column: &'static str,
40 value: PolicyLiteral,
41 },
42 FieldNeLiteral {
43 column: &'static str,
44 value: PolicyLiteral,
45 },
46 FieldEqAuth {
47 column: &'static str,
48 auth_field: &'static str,
49 },
50 FieldNeAuth {
51 column: &'static str,
52 auth_field: &'static str,
53 },
54 Relation {
55 quantifier: RelationQuantifier,
56 parent_table: &'static str,
57 parent_column: &'static str,
58 related_table: &'static str,
59 related_column: &'static str,
60 expr: &'static PolicyExpr,
61 },
62}
63
64#[derive(Debug, Clone, Copy, PartialEq, Eq)]
65pub struct ReadPolicy {
66 pub expr: PolicyExpr,
67}
68
69#[derive(Debug, Clone, Copy, PartialEq, Eq)]
70pub enum PolicyExpr {
71 Predicate(ReadPredicate),
72 And(&'static [PolicyExpr]),
73 Or(&'static [PolicyExpr]),
74}
75
/// A literal operand used by procedure-policy predicates.
///
/// `value_matches_literal` compares these against runtime `Value`s without any coercion:
/// the variant kinds must line up exactly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ProcedurePolicyLiteral {
    /// A boolean literal.
    Bool(bool),
    /// A 64-bit signed integer literal.
    Int(i64),
    /// A string literal, borrowed for `'static` so the type stays `Copy`.
    String(&'static str),
}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq)]
84pub enum ProcedurePredicate {
85 AuthNotNull,
86 AuthIsNull,
87 HasRole {
88 role: &'static str,
89 },
90 InTenant {
91 tenant_id: &'static str,
92 },
93 AuthFieldEqLiteral {
94 auth_field: &'static str,
95 value: ProcedurePolicyLiteral,
96 },
97 AuthFieldNeLiteral {
98 auth_field: &'static str,
99 value: ProcedurePolicyLiteral,
100 },
101 InputFieldIsTrue {
102 field: &'static str,
103 },
104 InputFieldEqLiteral {
105 field: &'static str,
106 value: ProcedurePolicyLiteral,
107 },
108 InputFieldNeLiteral {
109 field: &'static str,
110 value: ProcedurePolicyLiteral,
111 },
112 InputFieldEqAuth {
113 field: &'static str,
114 auth_field: &'static str,
115 },
116 InputFieldNeAuth {
117 field: &'static str,
118 auth_field: &'static str,
119 },
120 InputFieldEqInput {
121 field: &'static str,
122 other_field: &'static str,
123 },
124 InputFieldNeInput {
125 field: &'static str,
126 other_field: &'static str,
127 },
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq)]
131pub struct ProcedurePolicy {
132 pub expr: ProcedurePolicyExpr,
133}
134
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub enum ProcedurePolicyExpr {
137 Predicate(ProcedurePredicate),
138 And(&'static [ProcedurePolicyExpr]),
139 Or(&'static [ProcedurePolicyExpr]),
140}
141
142pub trait ProcedureArgs {
143 fn procedure_arg_value(&self, field: &str) -> Option<Value>;
144}
145
146impl ProcedureArgs for () {
147 fn procedure_arg_value(&self, _field: &str) -> Option<Value> {
148 None
149 }
150}
151
152pub fn authorize_procedure<A: ProcedureArgs + ?Sized>(
153 allow_policies: &[ProcedurePolicy],
154 deny_policies: &[ProcedurePolicy],
155 args: &A,
156 ctx: &CoolContext,
157) -> Result<(), CoolError> {
158 if allow_policies.is_empty() {
159 return Err(CoolError::Forbidden(
160 "procedure policy denied this operation".to_owned(),
161 ));
162 }
163
164 if deny_policies
165 .iter()
166 .any(|policy| procedure_policy_expr_matches(policy.expr, args, ctx))
167 {
168 return Err(CoolError::Forbidden(
169 "procedure policy denied this operation".to_owned(),
170 ));
171 }
172
173 if allow_policies
174 .iter()
175 .any(|policy| procedure_policy_expr_matches(policy.expr, args, ctx))
176 {
177 Ok(())
178 } else {
179 Err(CoolError::Forbidden(
180 "procedure policy denied this operation".to_owned(),
181 ))
182 }
183}
184
185pub fn context_has_role(ctx: &CoolContext, role: &str) -> bool {
186 ctx.auth_field("role")
187 .or_else(|| ctx.auth_field("actor.role"))
188 .is_some_and(|value| matches!(value, Value::String(candidate) if candidate == role))
189}
190
191pub fn context_in_tenant(ctx: &CoolContext, tenant_id: &str) -> bool {
192 ctx.auth_field("tenant.id")
193 .is_some_and(|value| matches!(value, Value::String(candidate) if candidate == tenant_id))
194}
195
196fn procedure_policy_expr_matches<A: ProcedureArgs + ?Sized>(
197 expr: ProcedurePolicyExpr,
198 args: &A,
199 ctx: &CoolContext,
200) -> bool {
201 match expr {
202 ProcedurePolicyExpr::Predicate(predicate) => {
203 procedure_predicate_matches(predicate, args, ctx)
204 }
205 ProcedurePolicyExpr::And(exprs) => exprs
206 .iter()
207 .copied()
208 .all(|expr| procedure_policy_expr_matches(expr, args, ctx)),
209 ProcedurePolicyExpr::Or(exprs) => exprs
210 .iter()
211 .copied()
212 .any(|expr| procedure_policy_expr_matches(expr, args, ctx)),
213 }
214}
215
216fn procedure_predicate_matches<A: ProcedureArgs + ?Sized>(
217 predicate: ProcedurePredicate,
218 args: &A,
219 ctx: &CoolContext,
220) -> bool {
221 match predicate {
222 ProcedurePredicate::AuthNotNull => ctx.is_authenticated(),
223 ProcedurePredicate::AuthIsNull => !ctx.is_authenticated(),
224 ProcedurePredicate::HasRole { role } => context_has_role(ctx, role),
225 ProcedurePredicate::InTenant { tenant_id } => context_in_tenant(ctx, tenant_id),
226 ProcedurePredicate::AuthFieldEqLiteral { auth_field, value } => ctx
227 .auth_field(auth_field)
228 .is_some_and(|candidate| value_matches_literal(candidate, value)),
229 ProcedurePredicate::AuthFieldNeLiteral { auth_field, value } => ctx
230 .auth_field(auth_field)
231 .is_some_and(|candidate| !value_matches_literal(candidate, value)),
232 ProcedurePredicate::InputFieldIsTrue { field } => args
233 .procedure_arg_value(field)
234 .is_some_and(|value| value == Value::Bool(true)),
235 ProcedurePredicate::InputFieldEqLiteral { field, value } => args
236 .procedure_arg_value(field)
237 .is_some_and(|candidate| value_matches_literal(&candidate, value)),
238 ProcedurePredicate::InputFieldNeLiteral { field, value } => args
239 .procedure_arg_value(field)
240 .is_some_and(|candidate| !value_matches_literal(&candidate, value)),
241 ProcedurePredicate::InputFieldEqAuth { field, auth_field } => {
242 match (args.procedure_arg_value(field), ctx.auth_field(auth_field)) {
243 (Some(left), Some(right)) => &left == right,
244 _ => false,
245 }
246 }
247 ProcedurePredicate::InputFieldNeAuth { field, auth_field } => {
248 match (args.procedure_arg_value(field), ctx.auth_field(auth_field)) {
249 (Some(left), Some(right)) => &left != right,
250 _ => false,
251 }
252 }
253 ProcedurePredicate::InputFieldEqInput { field, other_field } => {
254 match (
255 args.procedure_arg_value(field),
256 args.procedure_arg_value(other_field),
257 ) {
258 (Some(left), Some(right)) => left == right,
259 _ => false,
260 }
261 }
262 ProcedurePredicate::InputFieldNeInput { field, other_field } => {
263 match (
264 args.procedure_arg_value(field),
265 args.procedure_arg_value(other_field),
266 ) {
267 (Some(left), Some(right)) => left != right,
268 _ => false,
269 }
270 }
271 }
272}
273
274fn value_matches_literal(value: &Value, literal: ProcedurePolicyLiteral) -> bool {
275 match (value, literal) {
276 (Value::Bool(left), ProcedurePolicyLiteral::Bool(right)) => *left == right,
277 (Value::Int(left), ProcedurePolicyLiteral::Int(right)) => *left == right,
278 (Value::String(left), ProcedurePolicyLiteral::String(right)) => left == right,
279 _ => false,
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Allow policy that admits any authenticated caller.
    const ALLOW_AUTHENTICATED: ProcedurePolicy = ProcedurePolicy {
        expr: ProcedurePolicyExpr::Predicate(ProcedurePredicate::AuthNotNull),
    };

    /// Builds an authenticated context whose top-level `role` auth field is `role`.
    fn ctx_with_role(role: &str) -> CoolContext {
        CoolContext::authenticated([("role".to_owned(), Value::String(role.to_owned()))])
    }

    /// Returns `true` when `result` is the policy-denial error.
    fn is_forbidden(result: &Result<(), CoolError>) -> bool {
        matches!(result, Err(CoolError::Forbidden(_)))
    }

    /// Procedure input exposing a single boolean field, `approved`.
    struct ApprovalArgs {
        approved: bool,
    }

    impl ProcedureArgs for ApprovalArgs {
        fn procedure_arg_value(&self, field: &str) -> Option<Value> {
            match field {
                "approved" => Some(Value::Bool(self.approved)),
                _ => None,
            }
        }
    }

    #[test]
    fn has_role_checks_top_level_and_actor_role() {
        let top_level =
            CoolContext::authenticated([("role".to_owned(), Value::String("admin".to_owned()))]);
        assert!(context_has_role(&top_level, "admin"));
        assert!(!context_has_role(&top_level, "member"));

        let actor_role = CoolContext::authenticated([(
            "actor".to_owned(),
            Value::Map(BTreeMap::from([(
                "role".to_owned(),
                Value::String("merchant".to_owned()),
            )])),
        )]);
        assert!(context_has_role(&actor_role, "merchant"));
    }

    #[test]
    fn in_tenant_checks_structured_tenant_id() {
        let ctx = CoolContext::authenticated([(
            "tenant".to_owned(),
            Value::Map(BTreeMap::from([(
                "id".to_owned(),
                Value::String("tenant_1".to_owned()),
            )])),
        )]);
        assert!(context_in_tenant(&ctx, "tenant_1"));
        assert!(!context_in_tenant(&ctx, "tenant_2"));
    }

    #[test]
    fn authorize_procedure_denies_by_default_without_allow_policies() {
        let ctx = ctx_with_role("admin");
        assert!(is_forbidden(&authorize_procedure(&[], &[], &(), &ctx)));
    }

    #[test]
    fn authorize_procedure_accepts_a_matching_allow_policy() {
        let ctx = ctx_with_role("admin");
        assert!(authorize_procedure(&[ALLOW_AUTHENTICATED], &[], &(), &ctx).is_ok());
    }

    #[test]
    fn authorize_procedure_deny_overrides_allow() {
        let deny_admin = ProcedurePolicy {
            expr: ProcedurePolicyExpr::Predicate(ProcedurePredicate::HasRole { role: "admin" }),
        };

        let admin = ctx_with_role("admin");
        assert!(is_forbidden(&authorize_procedure(
            &[ALLOW_AUTHENTICATED],
            &[deny_admin],
            &(),
            &admin,
        )));

        // A deny policy that does not match leaves the allow decision in place.
        let member = ctx_with_role("member");
        assert!(
            authorize_procedure(&[ALLOW_AUTHENTICATED], &[deny_admin], &(), &member).is_ok()
        );
    }

    #[test]
    fn input_field_predicates_read_procedure_args() {
        let ctx = ctx_with_role("member");
        let allow_approved = ProcedurePolicy {
            expr: ProcedurePolicyExpr::Predicate(ProcedurePredicate::InputFieldIsTrue {
                field: "approved",
            }),
        };
        assert!(authorize_procedure(
            &[allow_approved],
            &[],
            &ApprovalArgs { approved: true },
            &ctx,
        )
        .is_ok());
        assert!(is_forbidden(&authorize_procedure(
            &[allow_approved],
            &[],
            &ApprovalArgs { approved: false },
            &ctx,
        )));

        // A missing input field satisfies neither Eq nor Ne comparisons.
        let ne_missing = ProcedurePolicy {
            expr: ProcedurePolicyExpr::Predicate(ProcedurePredicate::InputFieldNeLiteral {
                field: "missing",
                value: ProcedurePolicyLiteral::Bool(true),
            }),
        };
        assert!(is_forbidden(&authorize_procedure(
            &[ne_missing],
            &[],
            &ApprovalArgs { approved: true },
            &ctx,
        )));
    }

    #[test]
    fn empty_and_is_true_and_empty_or_is_false() {
        const EMPTY_AND: ProcedurePolicy = ProcedurePolicy {
            expr: ProcedurePolicyExpr::And(&[]),
        };
        const EMPTY_OR: ProcedurePolicy = ProcedurePolicy {
            expr: ProcedurePolicyExpr::Or(&[]),
        };

        let ctx = ctx_with_role("member");
        assert!(authorize_procedure(&[EMPTY_AND], &[], &(), &ctx).is_ok());
        assert!(is_forbidden(&authorize_procedure(&[EMPTY_OR], &[], &(), &ctx)));
    }
}