openzeppelin_relayer/config/
rate_limit.rs

1//! This module provides rate limiting functionality using API keys.
2
3use crate::constants::{AUTHORIZATION_HEADER_NAME, PUBLIC_ENDPOINTS};
4use actix_governor::governor::clock::{Clock, DefaultClock, QuantaInstant};
5use actix_governor::governor::NotUntil;
6use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
7use actix_web::{
8    dev::ServiceRequest,
9    http::{header::ContentType, StatusCode},
10    HttpResponse, HttpResponseBuilder,
11};
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
15pub struct ApiKeyRateLimit;
16
17impl KeyExtractor for ApiKeyRateLimit {
18    type Key = String;
19    type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
20
21    fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
22        let path = req.path();
23        let is_public_endpoint = PUBLIC_ENDPOINTS
24            .iter()
25            .any(|prefix| path.starts_with(prefix));
26
27        if is_public_endpoint {
28            return Ok("swagger-ui-exempt".to_string());
29        }
30        req.headers()
31            .get(AUTHORIZATION_HEADER_NAME)
32            .and_then(|token| token.to_str().ok())
33            .map(|token| token.trim().to_owned())
34            .ok_or_else(|| {
35                Self::KeyExtractionError::new(
36					r#"{"success": false, "code": 401, "error": "Unauthorized", "message": "Unauthorized"}"#,
37				)
38				.set_content_type(ContentType::json())
39				.set_status_code(StatusCode::UNAUTHORIZED)
40            })
41    }
42
43    fn exceed_rate_limit_response(
44        &self,
45        negative: &NotUntil<QuantaInstant>,
46        mut response: HttpResponseBuilder,
47    ) -> HttpResponse {
48        let wait_time = negative
49            .wait_time_from(DefaultClock::default().now())
50            .as_secs();
51        response.content_type(ContentType::json())
52            .body(
53                format!(
54                    r#"{{ "success": false, "code":429, "error": "TooManyRequests", "message": "Too Many Requests", "after": {wait_time}}}"#
55                )
56            )
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use actix_governor::governor::{Quota, RateLimiter};
    use actix_web::test::TestRequest;
    use actix_web::{body::MessageBody, http::header::HeaderValue, ResponseError};
    use std::num::NonZeroU32;

    #[tokio::test]
    async fn test_extract_with_valid_api_key() {
        let api_key = "test-api-key";
        let req = TestRequest::default()
            .insert_header((AUTHORIZATION_HEADER_NAME, api_key))
            .to_srv_request();

        let extractor = ApiKeyRateLimit;
        let result = extractor.extract(&req);

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), api_key);
    }

    #[tokio::test]
    async fn test_extract_with_whitespace_in_api_key() {
        let api_key = "  test-api-key-with-spaces  ";
        let expected_key = "test-api-key-with-spaces";
        let req = TestRequest::default()
            .insert_header((AUTHORIZATION_HEADER_NAME, api_key))
            .to_srv_request();

        let extractor = ApiKeyRateLimit;
        let result = extractor.extract(&req);

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected_key);
    }

    #[tokio::test]
    async fn test_extract_without_api_key_is_unauthorized() {
        // A (presumably) non-public path with no authorization header must be rejected,
        // not silently keyed.
        let req = TestRequest::default()
            .uri("/api/v1/relayers")
            .to_srv_request();

        let err = ApiKeyRateLimit.extract(&req).unwrap_err();

        assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_exceed_rate_limit_response() {
        let extractor = ApiKeyRateLimit;

        // Create a keyed rate limiter
        let quota = Quota::per_second(NonZeroU32::new(1).unwrap());
        let limiter = RateLimiter::keyed(quota);

        // Make two requests to trigger rate limiting
        let _ = limiter.check_key(&"test_key");
        let negative = limiter.check_key(&"test_key").unwrap_err();

        let response_builder = HttpResponse::TooManyRequests();
        let response = extractor.exceed_rate_limit_response(&negative, response_builder);

        // Check status code and content type
        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(
            response
                .headers()
                .get(actix_web::http::header::CONTENT_TYPE),
            Some(&HeaderValue::from_static("application/json"))
        );

        // Check response body
        let body = response.into_body();
        let bytes = body.try_into_bytes().unwrap();
        let body_str = std::str::from_utf8(&bytes).unwrap();

        // Verify JSON structure contains expected fields
        assert!(body_str.contains(r#""success": false"#));
        assert!(body_str.contains(r#""code":429"#));
        assert!(body_str.contains(r#""error": "TooManyRequests""#));
        assert!(body_str.contains(r#""message": "Too Many Requests""#));
        assert!(body_str.contains(r#""after":"#));
    }
}