cratestack_axum/ratelimit/
store.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::time::Instant;
5
6use async_trait::async_trait;
7use cratestack_core::CoolError;
8
9use super::config::{RateLimitConfig, RateLimitDecision};
10
11#[async_trait]
15pub trait RateLimitStore: Send + Sync + 'static {
16 async fn consume(
19 &self,
20 key: &str,
21 config: RateLimitConfig,
22 ) -> Result<RateLimitDecision, CoolError>;
23}
24
/// Token-bucket state kept for a single rate-limit key.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available. This is an `f64` because refill is
    /// proportional to elapsed wall time and may be fractional.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}
30
31#[derive(Debug, Clone, Default)]
35pub struct InMemoryRateLimitStore {
36 buckets: Arc<Mutex<HashMap<String, Bucket>>>,
37}
38
39impl InMemoryRateLimitStore {
40 pub fn new() -> Self {
41 Self::default()
42 }
43}
44
45#[async_trait]
46impl RateLimitStore for InMemoryRateLimitStore {
47 async fn consume(
48 &self,
49 key: &str,
50 config: RateLimitConfig,
51 ) -> Result<RateLimitDecision, CoolError> {
52 let mut buckets = self
53 .buckets
54 .lock()
55 .map_err(|_| CoolError::Internal("rate limit store poisoned".to_owned()))?;
56 let now = Instant::now();
57 let bucket = buckets.entry(key.to_owned()).or_insert(Bucket {
58 tokens: config.burst as f64,
59 last_refill: now,
60 });
61 let elapsed = now
62 .saturating_duration_since(bucket.last_refill)
63 .as_secs_f64();
64 bucket.tokens =
65 (bucket.tokens + elapsed * config.refill_per_second).min(config.burst as f64);
66 bucket.last_refill = now;
67 if bucket.tokens >= 1.0 {
68 bucket.tokens -= 1.0;
69 Ok(RateLimitDecision::Allowed {
70 remaining: bucket.tokens.floor() as u32,
71 })
72 } else {
73 let need = 1.0 - bucket.tokens;
74 let secs = (need / config.refill_per_second).ceil() as u32;
75 Ok(RateLimitDecision::Throttled {
76 retry_after_secs: secs.max(1),
77 })
78 }
79 }
80}