Skip to main content

cratestack_axum/ratelimit/
layer.rs

1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::Request;
5use axum::response::Response;
6use http::{HeaderValue, StatusCode, header};
7use sha2::{Digest, Sha256};
8use tower::{Layer, Service};
9
10use super::config::{RateLimitConfig, RateLimitDecision};
11use super::store::RateLimitStore;
12
13#[derive(Clone)]
14pub struct RateLimitLayer {
15    store: Arc<dyn RateLimitStore>,
16    config: RateLimitConfig,
17    key_fn: Arc<dyn Fn(&Request) -> String + Send + Sync>,
18}
19
20impl RateLimitLayer {
21    pub fn new(store: Arc<dyn RateLimitStore>, config: RateLimitConfig) -> Self {
22        Self {
23            store,
24            config,
25            key_fn: Arc::new(default_key_fn),
26        }
27    }
28
29    pub fn with_key_fn(mut self, f: impl Fn(&Request) -> String + Send + Sync + 'static) -> Self {
30        self.key_fn = Arc::new(f);
31        self
32    }
33}
34
35fn default_key_fn(req: &Request) -> String {
36    req.headers()
37        .get(header::AUTHORIZATION)
38        .and_then(|v| v.to_str().ok())
39        .map(|s| {
40            let mut h = Sha256::new();
41            h.update(s.as_bytes());
42            format!("auth:{:x}", h.finalize())
43        })
44        .unwrap_or_else(|| "anonymous".to_owned())
45}
46
47impl<S> Layer<S> for RateLimitLayer {
48    type Service = RateLimitService<S>;
49
50    fn layer(&self, inner: S) -> Self::Service {
51        RateLimitService {
52            inner,
53            store: self.store.clone(),
54            config: self.config,
55            key_fn: self.key_fn.clone(),
56        }
57    }
58}
59
60#[derive(Clone)]
61pub struct RateLimitService<S> {
62    inner: S,
63    store: Arc<dyn RateLimitStore>,
64    config: RateLimitConfig,
65    key_fn: Arc<dyn Fn(&Request) -> String + Send + Sync>,
66}
67
68impl<S> Service<Request> for RateLimitService<S>
69where
70    S: Service<Request, Response = Response, Error = std::convert::Infallible>
71        + Clone
72        + Send
73        + 'static,
74    S::Future: Send + 'static,
75{
76    type Response = Response;
77    type Error = std::convert::Infallible;
78    type Future =
79        std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Self::Error>> + Send>>;
80
81    fn poll_ready(
82        &mut self,
83        cx: &mut std::task::Context<'_>,
84    ) -> std::task::Poll<Result<(), Self::Error>> {
85        self.inner.poll_ready(cx)
86    }
87
88    fn call(&mut self, req: Request) -> Self::Future {
89        let mut inner = self.inner.clone();
90        let store = self.store.clone();
91        let config = self.config;
92        let key = (self.key_fn)(&req);
93        Box::pin(async move {
94            match store.consume(&key, config).await {
95                Ok(RateLimitDecision::Allowed { remaining }) => {
96                    let mut response = inner.call(req).await?;
97                    if let Ok(value) = HeaderValue::from_str(&config.burst.to_string()) {
98                        response.headers_mut().insert("X-RateLimit-Limit", value);
99                    }
100                    if let Ok(value) = HeaderValue::from_str(&remaining.to_string()) {
101                        response
102                            .headers_mut()
103                            .insert("X-RateLimit-Remaining", value);
104                    }
105                    Ok(response)
106                }
107                Ok(RateLimitDecision::Throttled { retry_after_secs }) => {
108                    let mut response = Response::new(Body::from("rate limit exceeded"));
109                    *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
110                    if let Ok(value) = HeaderValue::from_str(&retry_after_secs.to_string()) {
111                        response.headers_mut().insert(header::RETRY_AFTER, value);
112                    }
113                    response.headers_mut().insert(
114                        header::CONTENT_TYPE,
115                        HeaderValue::from_static("text/plain; charset=utf-8"),
116                    );
117                    Ok(response)
118                }
119                Err(error) => {
120                    let mut response =
121                        Response::new(Body::from(error.public_message().into_owned()));
122                    *response.status_mut() = error.status_code();
123                    Ok(response)
124                }
125            }
126        })
127    }
128}