openzeppelin_relayer/services/plugins/
socket.rs

//! This module is responsible for creating a socket connection to the relayer server.
//! It is used to send requests to the relayer server and to process the responses.
//! It also intercepts the logs, errors and return values.
4//!
5//! The socket connection is created using the `UnixListener`.
6//!
7//! 1. Creates a socket connection using the `UnixListener`.
8//! 2. Each request payload is stringified by the client and added as a new line to the socket.
9//! 3. The server reads the requests from the socket and processes them.
//! 4. The server sends the responses back to the client in the same format, writing each response as a new line to the socket.
11//! 5. When the client sends the socket shutdown signal, the server closes the socket connection.
12//!
13//! Example:
14//! 1. Create a new socket connection using `/tmp/socket.sock`
//! 2. The client sends a request (writes to `/tmp/socket.sock`):
16//! ```json
17//! {
18//!   "request_id": "123",
19//!   "relayer_id": "relayer1",
20//!   "method": "sendTransaction",
21//!   "payload": {
22//!     "to": "0x1234567890123456789012345678901234567890",
23//!     "value": "1000000000000000000"
24//!   }
25//! }
26//! ```
//! 3. The server processes the request, calls the relayer API and sends back the response (writes to `/tmp/socket.sock`):
28//! ```json
29//! {
30//!   "request_id": "123",
31//!   "result": {
32//!     "id": "123",
33//!     "status": "success"
34//!   }
35//! }
36//! ```
37//! 4. Client reads the response (reads from `/tmp/socket.sock`):
38//! ```json
39//! {
40//!   "request_id": "123",
41//!   "result": {
42//!     "id": "123",
43//!     "status": "success"
44//!   }
45//! }
46//! ```
47//! 5. Once the client finishes the execution, it sends a shutdown signal to the server.
48//! 6. The server closes the socket connection.
49//!
50
51use crate::{jobs::JobProducerTrait, models::AppState};
52use actix_web::web;
53use std::sync::Arc;
54use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
55use tokio::net::{UnixListener, UnixStream};
56use tokio::sync::oneshot;
57
58use super::{
59    relayer_api::{RelayerApiTrait, Request},
60    PluginError,
61};
62
63pub struct SocketService {
64    socket_path: String,
65    listener: UnixListener,
66}
67
68impl SocketService {
69    /// Creates a new socket service.
70    ///
71    /// # Arguments
72    ///
73    /// * `socket_path` - The path to the socket file.
74    pub fn new(socket_path: &str) -> Result<Self, PluginError> {
75        // Remove existing socket file if it exists
76        let _ = std::fs::remove_file(socket_path);
77
78        let listener =
79            UnixListener::bind(socket_path).map_err(|e| PluginError::SocketError(e.to_string()))?;
80
81        Ok(Self {
82            socket_path: socket_path.to_string(),
83            listener,
84        })
85    }
86
87    pub fn socket_path(&self) -> &str {
88        &self.socket_path
89    }
90
91    /// Listens for incoming connections and processes the requests.
92    ///
93    /// # Arguments
94    ///
95    /// * `shutdown_rx` - A receiver for the shutdown signal.
96    /// * `state` - The application state.
97    /// * `relayer_api` - The relayer API.
98    ///
99    /// # Returns
100    ///
101    /// A vector of traces.
102    pub async fn listen<
103        J: JobProducerTrait + 'static,
104        R: RelayerApiTrait + 'static + Send + Sync,
105    >(
106        self,
107        shutdown_rx: oneshot::Receiver<()>,
108        state: Arc<web::ThinData<AppState<J>>>,
109        relayer_api: Arc<R>,
110    ) -> Result<Vec<serde_json::Value>, PluginError> {
111        let mut shutdown = shutdown_rx;
112
113        let mut traces = Vec::new();
114
115        loop {
116            let state = Arc::clone(&state);
117            let relayer_api = Arc::clone(&relayer_api);
118            tokio::select! {
119                Ok((stream, _)) = self.listener.accept() => {
120                    let result = tokio::spawn(Self::handle_connection::<J, R>(stream, state, relayer_api))
121                        .await
122                        .map_err(|e| PluginError::SocketError(e.to_string()))?;
123
124                    match result {
125                        Ok(trace) => traces.extend(trace),
126                        Err(e) => return Err(e),
127                    }
128                }
129                _ = &mut shutdown => {
130                    println!("Shutdown signal received. Closing listener.");
131                    break;
132                }
133            }
134        }
135
136        Ok(traces)
137    }
138
139    /// Handles a new connection.
140    ///
141    /// # Arguments
142    ///
143    /// * `stream` - The stream to the client.
144    /// * `state` - The application state.
145    /// * `relayer_api` - The relayer API.
146    ///
147    /// # Returns
148    ///
149    /// A vector of traces.
150    async fn handle_connection<
151        J: JobProducerTrait + 'static,
152        R: RelayerApiTrait + 'static + Send + Sync,
153    >(
154        stream: UnixStream,
155        state: Arc<web::ThinData<AppState<J>>>,
156        relayer_api: Arc<R>,
157    ) -> Result<Vec<serde_json::Value>, PluginError> {
158        let (r, mut w) = stream.into_split();
159        let mut reader = BufReader::new(r).lines();
160        let mut traces = Vec::new();
161
162        while let Ok(Some(line)) = reader.next_line().await {
163            let trace: serde_json::Value = serde_json::from_str(&line)
164                .map_err(|e| PluginError::PluginError(format!("Failed to parse trace: {}", e)))?;
165            traces.push(trace);
166
167            let request: Request =
168                serde_json::from_str(&line).map_err(|e| PluginError::PluginError(e.to_string()))?;
169
170            let response = relayer_api.handle_request(request, &state).await;
171
172            let response_str = serde_json::to_string(&response)
173                .map_err(|e| PluginError::PluginError(e.to_string()))?
174                + "\n";
175
176            let _ = w.write_all(response_str.as_bytes()).await;
177        }
178
179        Ok(traces)
180    }
181}
182
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::{
        jobs::MockJobProducerTrait,
        services::plugins::{MockRelayerApiTrait, PluginMethod, Response},
        utils::mocks::mockutils::{create_mock_app_state, create_mock_evm_transaction_request},
    };

    use super::*;

    use tempfile::tempdir;
    use tokio::{
        io::{AsyncBufReadExt, BufReader},
        time::timeout,
    };

    /// Upper bound for any single await in these tests.
    ///
    /// It is deliberately generous so that a slow CI runner does not cause
    /// spurious failures; a healthy run finishes far sooner.
    const TEST_TIMEOUT: Duration = Duration::from_secs(2);

    /// A service that never receives a connection stops once the shutdown
    /// signal is sent, and returns successfully with no traces.
    #[tokio::test]
    async fn test_socket_service_listen_and_shutdown() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mock_relayer = MockRelayerApiTrait::default();

        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        shutdown_tx.send(()).unwrap();

        // Check all three layers: the timeout, the task join, and the result
        // returned by `listen` itself.
        let traces = timeout(TEST_TIMEOUT, listen_handle)
            .await
            .expect("Listen handle timed out")
            .expect("Listen task panicked")
            .expect("Listen handle returned error");
        assert!(traces.is_empty(), "No connection means no traces");
    }

    /// A request written by a client is answered with the relayer API response,
    /// and is recorded verbatim as a trace.
    #[tokio::test]
    async fn test_socket_service_handle_connection() {
        let temp_dir = tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let mut mock_relayer = MockRelayerApiTrait::default();

        mock_relayer
            .expect_handle_request::<MockJobProducerTrait>()
            .returning(|_, _| Response {
                request_id: "test".to_string(),
                result: Some(serde_json::json!("test")),
                error: None,
            });

        // The listener is bound (and listening) inside `new`, so the client can
        // connect right away. The pending connection is queued by the OS until
        // `listen` accepts it; no sleep is needed.
        let service = SocketService::new(socket_path.to_str().unwrap()).unwrap();

        let state = create_mock_app_state(None, None, None, None).await;
        let (shutdown_tx, shutdown_rx) = oneshot::channel();

        let listen_handle = tokio::spawn(async move {
            service
                .listen(
                    shutdown_rx,
                    Arc::new(web::ThinData(state)),
                    Arc::new(mock_relayer),
                )
                .await
        });

        let mut client = UnixStream::connect(socket_path.to_str().unwrap())
            .await
            .unwrap();

        let request = Request {
            request_id: "test".to_string(),
            relayer_id: "test".to_string(),
            method: PluginMethod::SendTransaction,
            payload: serde_json::json!(create_mock_evm_transaction_request()),
        };

        let request_json = serde_json::to_string(&request).unwrap() + "\n";

        client.write_all(request_json.as_bytes()).await.unwrap();

        let mut reader = BufReader::new(&mut client);
        let mut response_str = String::new();
        let read_result = timeout(TEST_TIMEOUT, reader.read_line(&mut response_str)).await;

        assert!(
            read_result.is_ok(),
            "Reading response timed out: {:?}",
            read_result
        );
        let bytes_read = read_result.unwrap().unwrap();
        assert!(bytes_read > 0, "No data received");
        shutdown_tx.send(()).unwrap();

        let response: Response = serde_json::from_str(&response_str).unwrap();

        assert!(response.error.is_none(), "Error should be none");
        assert!(response.result.is_some(), "Result should be some");
        assert_eq!(
            response.request_id, request.request_id,
            "Request id mismatch"
        );

        // Closing our write half ends the connection on the server side. Only
        // then does `listen` get back to its loop and observe the shutdown signal.
        client.shutdown().await.unwrap();

        let traces = timeout(TEST_TIMEOUT, listen_handle)
            .await
            .expect("Listen handle timed out")
            .expect("Listen task panicked")
            .expect("Listen handle returned error");

        assert_eq!(traces.len(), 1);
        let expected: serde_json::Value = serde_json::from_str(&request_json).unwrap();
        assert_eq!(expected, traces[0], "Request json mismatch with trace");
    }
}