Skip to main content

cratestack_axum/ratelimit/
store.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::time::Instant;
5
6use async_trait::async_trait;
7use cratestack_core::CoolError;
8
9use super::config::{RateLimitConfig, RateLimitDecision};
10
11/// Pluggable storage for token-bucket state. Implementations must be safe
12/// to share across tasks (use a Mutex internally, or rely on the backing
13/// store's atomicity).
14#[async_trait]
15pub trait RateLimitStore: Send + Sync + 'static {
16    /// Atomically consume one token for `key`. Returns the decision based
17    /// on the bucket state after the consumption attempt.
18    async fn consume(
19        &self,
20        key: &str,
21        config: RateLimitConfig,
22    ) -> Result<RateLimitDecision, CoolError>;
23}
24
/// State of one token bucket, stored per rate-limit key.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available. Kept as a float so that a partially
    /// refilled token is carried over between calls.
    tokens: f64,
    /// Moment at which `tokens` was last brought up to date.
    last_refill: Instant,
}
30
31/// In-memory `RateLimitStore`. Suitable for single-replica deployments and
32/// development; banks running multi-replica clusters need a Redis-backed
33/// implementation so the limit is enforced cluster-wide.
34#[derive(Debug, Clone, Default)]
35pub struct InMemoryRateLimitStore {
36    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
37}
38
39impl InMemoryRateLimitStore {
40    pub fn new() -> Self {
41        Self::default()
42    }
43}
44
45#[async_trait]
46impl RateLimitStore for InMemoryRateLimitStore {
47    async fn consume(
48        &self,
49        key: &str,
50        config: RateLimitConfig,
51    ) -> Result<RateLimitDecision, CoolError> {
52        let mut buckets = self
53            .buckets
54            .lock()
55            .map_err(|_| CoolError::Internal("rate limit store poisoned".to_owned()))?;
56        let now = Instant::now();
57        let bucket = buckets.entry(key.to_owned()).or_insert(Bucket {
58            tokens: config.burst as f64,
59            last_refill: now,
60        });
61        let elapsed = now
62            .saturating_duration_since(bucket.last_refill)
63            .as_secs_f64();
64        bucket.tokens =
65            (bucket.tokens + elapsed * config.refill_per_second).min(config.burst as f64);
66        bucket.last_refill = now;
67        if bucket.tokens >= 1.0 {
68            bucket.tokens -= 1.0;
69            Ok(RateLimitDecision::Allowed {
70                remaining: bucket.tokens.floor() as u32,
71            })
72        } else {
73            let need = 1.0 - bucket.tokens;
74            let secs = (need / config.refill_per_second).ceil() as u32;
75            Ok(RateLimitDecision::Throttled {
76                retry_after_secs: secs.max(1),
77            })
78        }
79    }
80}