cratestack_axum/ratelimit/
layer.rs1use std::sync::Arc;
2
3use axum::body::Body;
4use axum::extract::Request;
5use axum::response::Response;
6use http::{HeaderValue, StatusCode, header};
7use sha2::{Digest, Sha256};
8use tower::{Layer, Service};
9
10use super::config::{RateLimitConfig, RateLimitDecision};
11use super::store::RateLimitStore;
12
13#[derive(Clone)]
14pub struct RateLimitLayer {
15 store: Arc<dyn RateLimitStore>,
16 config: RateLimitConfig,
17 key_fn: Arc<dyn Fn(&Request) -> String + Send + Sync>,
18}
19
20impl RateLimitLayer {
21 pub fn new(store: Arc<dyn RateLimitStore>, config: RateLimitConfig) -> Self {
22 Self {
23 store,
24 config,
25 key_fn: Arc::new(default_key_fn),
26 }
27 }
28
29 pub fn with_key_fn(mut self, f: impl Fn(&Request) -> String + Send + Sync + 'static) -> Self {
30 self.key_fn = Arc::new(f);
31 self
32 }
33}
34
35fn default_key_fn(req: &Request) -> String {
36 req.headers()
37 .get(header::AUTHORIZATION)
38 .and_then(|v| v.to_str().ok())
39 .map(|s| {
40 let mut h = Sha256::new();
41 h.update(s.as_bytes());
42 format!("auth:{:x}", h.finalize())
43 })
44 .unwrap_or_else(|| "anonymous".to_owned())
45}
46
47impl<S> Layer<S> for RateLimitLayer {
48 type Service = RateLimitService<S>;
49
50 fn layer(&self, inner: S) -> Self::Service {
51 RateLimitService {
52 inner,
53 store: self.store.clone(),
54 config: self.config,
55 key_fn: self.key_fn.clone(),
56 }
57 }
58}
59
60#[derive(Clone)]
61pub struct RateLimitService<S> {
62 inner: S,
63 store: Arc<dyn RateLimitStore>,
64 config: RateLimitConfig,
65 key_fn: Arc<dyn Fn(&Request) -> String + Send + Sync>,
66}
67
68impl<S> Service<Request> for RateLimitService<S>
69where
70 S: Service<Request, Response = Response, Error = std::convert::Infallible>
71 + Clone
72 + Send
73 + 'static,
74 S::Future: Send + 'static,
75{
76 type Response = Response;
77 type Error = std::convert::Infallible;
78 type Future =
79 std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response, Self::Error>> + Send>>;
80
81 fn poll_ready(
82 &mut self,
83 cx: &mut std::task::Context<'_>,
84 ) -> std::task::Poll<Result<(), Self::Error>> {
85 self.inner.poll_ready(cx)
86 }
87
88 fn call(&mut self, req: Request) -> Self::Future {
89 let mut inner = self.inner.clone();
90 let store = self.store.clone();
91 let config = self.config;
92 let key = (self.key_fn)(&req);
93 Box::pin(async move {
94 match store.consume(&key, config).await {
95 Ok(RateLimitDecision::Allowed { remaining }) => {
96 let mut response = inner.call(req).await?;
97 if let Ok(value) = HeaderValue::from_str(&config.burst.to_string()) {
98 response.headers_mut().insert("X-RateLimit-Limit", value);
99 }
100 if let Ok(value) = HeaderValue::from_str(&remaining.to_string()) {
101 response
102 .headers_mut()
103 .insert("X-RateLimit-Remaining", value);
104 }
105 Ok(response)
106 }
107 Ok(RateLimitDecision::Throttled { retry_after_secs }) => {
108 let mut response = Response::new(Body::from("rate limit exceeded"));
109 *response.status_mut() = StatusCode::TOO_MANY_REQUESTS;
110 if let Ok(value) = HeaderValue::from_str(&retry_after_secs.to_string()) {
111 response.headers_mut().insert(header::RETRY_AFTER, value);
112 }
113 response.headers_mut().insert(
114 header::CONTENT_TYPE,
115 HeaderValue::from_static("text/plain; charset=utf-8"),
116 );
117 Ok(response)
118 }
119 Err(error) => {
120 let mut response =
121 Response::new(Body::from(error.public_message().into_owned()));
122 *response.status_mut() = error.status_code();
123 Ok(response)
124 }
125 }
126 })
127 }
128}