openzeppelin_relayer/services/plugins/
socket.rs1use crate::{jobs::JobProducerTrait, models::AppState};
52use actix_web::web;
53use std::sync::Arc;
54use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
55use tokio::net::{UnixListener, UnixStream};
56use tokio::sync::oneshot;
57
58use super::{
59 relayer_api::{RelayerApiTrait, Request},
60 PluginError,
61};
62
63pub struct SocketService {
64 socket_path: String,
65 listener: UnixListener,
66}
67
68impl SocketService {
69 pub fn new(socket_path: &str) -> Result<Self, PluginError> {
75 let _ = std::fs::remove_file(socket_path);
77
78 let listener =
79 UnixListener::bind(socket_path).map_err(|e| PluginError::SocketError(e.to_string()))?;
80
81 Ok(Self {
82 socket_path: socket_path.to_string(),
83 listener,
84 })
85 }
86
87 pub fn socket_path(&self) -> &str {
88 &self.socket_path
89 }
90
91 pub async fn listen<
103 J: JobProducerTrait + 'static,
104 R: RelayerApiTrait + 'static + Send + Sync,
105 >(
106 self,
107 shutdown_rx: oneshot::Receiver<()>,
108 state: Arc<web::ThinData<AppState<J>>>,
109 relayer_api: Arc<R>,
110 ) -> Result<Vec<serde_json::Value>, PluginError> {
111 let mut shutdown = shutdown_rx;
112
113 let mut traces = Vec::new();
114
115 loop {
116 let state = Arc::clone(&state);
117 let relayer_api = Arc::clone(&relayer_api);
118 tokio::select! {
119 Ok((stream, _)) = self.listener.accept() => {
120 let result = tokio::spawn(Self::handle_connection::<J, R>(stream, state, relayer_api))
121 .await
122 .map_err(|e| PluginError::SocketError(e.to_string()))?;
123
124 match result {
125 Ok(trace) => traces.extend(trace),
126 Err(e) => return Err(e),
127 }
128 }
129 _ = &mut shutdown => {
130 println!("Shutdown signal received. Closing listener.");
131 break;
132 }
133 }
134 }
135
136 Ok(traces)
137 }
138
139 async fn handle_connection<
151 J: JobProducerTrait + 'static,
152 R: RelayerApiTrait + 'static + Send + Sync,
153 >(
154 stream: UnixStream,
155 state: Arc<web::ThinData<AppState<J>>>,
156 relayer_api: Arc<R>,
157 ) -> Result<Vec<serde_json::Value>, PluginError> {
158 let (r, mut w) = stream.into_split();
159 let mut reader = BufReader::new(r).lines();
160 let mut traces = Vec::new();
161
162 while let Ok(Some(line)) = reader.next_line().await {
163 let trace: serde_json::Value = serde_json::from_str(&line)
164 .map_err(|e| PluginError::PluginError(format!("Failed to parse trace: {}", e)))?;
165 traces.push(trace);
166
167 let request: Request =
168 serde_json::from_str(&line).map_err(|e| PluginError::PluginError(e.to_string()))?;
169
170 let response = relayer_api.handle_request(request, &state).await;
171
172 let response_str = serde_json::to_string(&response)
173 .map_err(|e| PluginError::PluginError(e.to_string()))?
174 + "\n";
175
176 let _ = w.write_all(response_str.as_bytes()).await;
177 }
178
179 Ok(traces)
180 }
181}
182
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::{
        jobs::MockJobProducerTrait,
        services::plugins::{MockRelayerApiTrait, PluginMethod, Response},
        utils::mocks::mockutils::{create_mock_app_state, create_mock_evm_transaction_request},
    };

    use super::*;

    use tempfile::tempdir;
    use tokio::{
        io::{AsyncBufReadExt, BufReader},
        time::timeout,
    };

    /// Upper bound for operations that should finish almost immediately. It is generous so
    /// the tests do not flake on a slow or heavily loaded CI machine.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_socket_service_listen_and_shutdown() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mock_relayer = MockRelayerApiTrait::default();

        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        shutdown_tx.send(()).unwrap();

        let result = timeout(TEST_TIMEOUT, listen_handle).await;
        assert!(result.is_ok(), "Listen handle timed out");

        // Check both that the task did not panic and that `listen` itself returned `Ok`.
        let traces = result
            .unwrap()
            .expect("Listen task panicked")
            .expect("Listen handle returned error");
        assert!(traces.is_empty(), "No connections were made, so no traces expected");
    }

    #[tokio::test]
    async fn test_socket_service_handle_connection() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mut mock_relayer = MockRelayerApiTrait::default();

        mock_relayer
            .expect_handle_request::<MockJobProducerTrait>()
            .returning(|_, _| Response {
                request_id: "test".to_string(),
                result: Some(serde_json::json!("test")),
                error: None,
            });

        // The listener is bound (and listening) inside `new`, so the client below can
        // connect without waiting for the spawned task to reach `accept`.
        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
            .await
            .unwrap();

        let request = Request {
            request_id: "test".to_string(),
            relayer_id: "test".to_string(),
            method: PluginMethod::SendTransaction,
            payload: serde_json::json!(create_mock_evm_transaction_request()),
        };

        let request_json = serde_json::to_string(&request).unwrap() + "\n";

        client.write_all(request_json.as_bytes()).await.unwrap();

        let mut reader = BufReader::new(&mut client);
        let mut response_str = String::new();
        let read_result = timeout(TEST_TIMEOUT, reader.read_line(&mut response_str)).await;

        assert!(
            read_result.is_ok(),
            "Reading response timed out: {:?}",
            read_result
        );
        let bytes_read = read_result.unwrap().unwrap();
        assert!(bytes_read > 0, "No data received");
        shutdown_tx.send(()).unwrap();

        let response: Response = serde_json::from_str(&response_str).unwrap();

        assert!(response.error.is_none(), "Error should be none");
        assert!(response.result.is_some(), "Result should be some");
        assert_eq!(
            response.request_id, request.request_id,
            "Request id mismatch"
        );

        client.shutdown().await.unwrap();

        // Bound the wait so a regression fails the test instead of hanging the suite.
        let traces = timeout(TEST_TIMEOUT, listen_handle)
            .await
            .expect("Listen handle timed out")
            .unwrap()
            .unwrap();

        assert_eq!(traces.len(), 1);
        let expected: serde_json::Value = serde_json::from_str(&request_json).unwrap();
        assert_eq!(expected, traces[0], "Request json mismatch with trace");
    }
}