//! Rate limiting keyed by API key, implemented as an `actix-governor` key extractor.
//!
//! Source path: `openzeppelin_relayer/config/rate_limit.rs`.
use crate::constants::{AUTHORIZATION_HEADER_NAME, PUBLIC_ENDPOINTS};
use actix_governor::governor::clock::{Clock, DefaultClock, QuantaInstant};
use actix_governor::governor::NotUntil;
use actix_governor::{KeyExtractor, SimpleKeyExtractionError};
use actix_web::{
    dev::ServiceRequest,
    http::{header::ContentType, StatusCode},
    HttpResponse, HttpResponseBuilder,
};
use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Serialize, Deserialize, Clone, Eq, PartialEq)]
15pub struct ApiKeyRateLimit;
16
17impl KeyExtractor for ApiKeyRateLimit {
18 type Key = String;
19 type KeyExtractionError = SimpleKeyExtractionError<&'static str>;
20
21 fn extract(&self, req: &ServiceRequest) -> Result<Self::Key, Self::KeyExtractionError> {
22 let path = req.path();
23 let is_public_endpoint = PUBLIC_ENDPOINTS
24 .iter()
25 .any(|prefix| path.starts_with(prefix));
26
27 if is_public_endpoint {
28 return Ok("swagger-ui-exempt".to_string());
29 }
30 req.headers()
31 .get(AUTHORIZATION_HEADER_NAME)
32 .and_then(|token| token.to_str().ok())
33 .map(|token| token.trim().to_owned())
34 .ok_or_else(|| {
35 Self::KeyExtractionError::new(
36 r#"{"success": false, "code": 401, "error": "Unauthorized", "message": "Unauthorized"}"#,
37 )
38 .set_content_type(ContentType::json())
39 .set_status_code(StatusCode::UNAUTHORIZED)
40 })
41 }
42
43 fn exceed_rate_limit_response(
44 &self,
45 negative: &NotUntil<QuantaInstant>,
46 mut response: HttpResponseBuilder,
47 ) -> HttpResponse {
48 let wait_time = negative
49 .wait_time_from(DefaultClock::default().now())
50 .as_secs();
51 response.content_type(ContentType::json())
52 .body(
53 format!(
54 r#"{{ "success": false, "code":429, "error": "TooManyRequests", "message": "Too Many Requests", "after": {wait_time}}}"#
55 )
56 )
57 }
58}
59
#[cfg(test)]
mod tests {

    use super::*;
    use actix_governor::governor::{Quota, RateLimiter};
    use actix_web::test::TestRequest;
    use actix_web::ResponseError;
    use actix_web::{body::MessageBody, http::header::HeaderValue};
    use std::num::NonZeroU32;

    #[tokio::test]
    async fn test_extract_with_valid_api_key() {
        let api_key = "test-api-key";
        let req = TestRequest::default()
            .insert_header((AUTHORIZATION_HEADER_NAME, api_key))
            .to_srv_request();

        let extractor = ApiKeyRateLimit;
        let result = extractor.extract(&req);

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), api_key);
    }

    #[tokio::test]
    async fn test_extract_with_whitespace_in_api_key() {
        let api_key = " test-api-key-with-spaces ";
        let expected_key = "test-api-key-with-spaces";
        let req = TestRequest::default()
            .insert_header((AUTHORIZATION_HEADER_NAME, api_key))
            .to_srv_request();

        let extractor = ApiKeyRateLimit;
        let result = extractor.extract(&req);

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), expected_key);
    }

    #[tokio::test]
    async fn test_extract_without_api_key_is_unauthorized() {
        // The default request path is not a public endpoint (the tests above rely on
        // this too), so a missing header must be rejected.
        let req = TestRequest::default().to_srv_request();

        let extractor = ApiKeyRateLimit;
        let err = extractor.extract(&req).unwrap_err();

        assert_eq!(err.status_code(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_extract_public_endpoint_uses_shared_key() {
        let extractor = ApiKeyRateLimit;

        // Only absolute-path prefixes can be used directly as a request URI.
        for prefix in PUBLIC_ENDPOINTS.iter().filter(|p| p.starts_with('/')) {
            let req = TestRequest::default().uri(prefix).to_srv_request();

            let result = extractor.extract(&req);

            assert_eq!(result.unwrap(), "swagger-ui-exempt", "prefix {prefix}");
        }
    }

    #[tokio::test]
    async fn test_exceed_rate_limit_response() {
        let extractor = ApiKeyRateLimit;

        let quota = Quota::per_second(NonZeroU32::new(1).unwrap());
        let limiter = RateLimiter::keyed(quota);

        // The first check consumes the single permitted request; the second is rejected.
        let _ = limiter.check_key(&"test_key");
        let negative = limiter.check_key(&"test_key").unwrap_err();

        let response_builder = HttpResponse::TooManyRequests();
        let response = extractor.exceed_rate_limit_response(&negative, response_builder);

        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
        assert_eq!(
            response
                .headers()
                .get(actix_web::http::header::CONTENT_TYPE),
            Some(&HeaderValue::from_static("application/json"))
        );

        let body = response.into_body();
        let bytes = body.try_into_bytes().unwrap();
        let body_str = std::str::from_utf8(&bytes).unwrap();

        assert!(body_str.contains(r#""success": false"#));
        assert!(body_str.contains(r#""code":429"#));
        assert!(body_str.contains(r#""error": "TooManyRequests""#));
        assert!(body_str.contains(r#""message": "Too Many Requests""#));
        assert!(body_str.contains(r#""after":"#));

        // `after` must be a bare whole number of seconds.
        let after = body_str
            .split(r#""after":"#)
            .nth(1)
            .map(|rest| rest.trim_end_matches('}').trim())
            .expect("body should contain an `after` field");
        assert!(
            after.parse::<u64>().is_ok(),
            "`after` should be whole seconds, got {after:?}"
        );
    }
}