openzeppelin_relayer/repositories/
plugin.rs

//! This module provides an in-memory implementation of the plugin repository.
//!
//! The `InMemoryPluginRepository` struct stores plugins (an id and a script path)
//! and retrieves them by id for later execution.
5use crate::{
6    config::PluginFileConfig,
7    models::PluginModel,
8    repositories::{ConversionError, RepositoryError},
9};
10use async_trait::async_trait;
11
12#[cfg(test)]
13use mockall::automock;
14
15use std::collections::HashMap;
16use tokio::sync::{Mutex, MutexGuard};
17
/// In-memory plugin store keyed by plugin id.
///
/// Entries live only as long as this value; nothing is persisted.
#[derive(Debug)]
pub struct InMemoryPluginRepository {
    /// Plugins indexed by their `id`, guarded by an async (`tokio::sync`) mutex.
    store: Mutex<HashMap<String, PluginModel>>,
}
22
23impl InMemoryPluginRepository {
24    pub fn new() -> Self {
25        Self {
26            store: Mutex::new(HashMap::new()),
27        }
28    }
29
30    pub async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
31        let store = Self::acquire_lock(&self.store).await?;
32        Ok(store.get(id).cloned())
33    }
34
35    async fn acquire_lock<T>(lock: &Mutex<T>) -> Result<MutexGuard<T>, RepositoryError> {
36        Ok(lock.lock().await)
37    }
38}
39
40impl Default for InMemoryPluginRepository {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
/// Async storage abstraction for plugin definitions.
///
/// In test builds, `mockall::automock` derives a mock implementation of this trait.
// NOTE(review): `#[allow(dead_code)]` presumably silences warnings while some methods have
// no non-test callers — confirm and either document the reason or remove it.
#[async_trait]
#[allow(dead_code)]
#[cfg_attr(test, automock)]
pub trait PluginRepositoryTrait {
    /// Looks up the plugin with the given `id`; `Ok(None)` means it was not found.
    async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError>;
    /// Stores `plugin` in the repository.
    async fn add(&self, plugin: PluginModel) -> Result<(), RepositoryError>;
}
53
54#[async_trait]
55impl PluginRepositoryTrait for InMemoryPluginRepository {
56    async fn get_by_id(&self, id: &str) -> Result<Option<PluginModel>, RepositoryError> {
57        let store = Self::acquire_lock(&self.store).await?;
58        Ok(store.get(id).cloned())
59    }
60
61    async fn add(&self, plugin: PluginModel) -> Result<(), RepositoryError> {
62        let mut store = Self::acquire_lock(&self.store).await?;
63        store.insert(plugin.id.clone(), plugin);
64        Ok(())
65    }
66}
67
68impl TryFrom<PluginFileConfig> for PluginModel {
69    type Error = ConversionError;
70
71    fn try_from(config: PluginFileConfig) -> Result<Self, Self::Error> {
72        Ok(PluginModel {
73            id: config.id.clone(),
74            path: config.path.clone(),
75        })
76    }
77}
78
79impl PartialEq for PluginModel {
80    fn eq(&self, other: &Self) -> bool {
81        self.id == other.id && self.path == other.path
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `PluginModel` with the given id and script path.
    fn make_plugin(id: &str, path: &str) -> PluginModel {
        PluginModel {
            id: id.to_string(),
            path: path.to_string(),
        }
    }

    #[tokio::test]
    async fn test_in_memory_plugin_repository() {
        // Go through the trait explicitly: plain method-call syntax would resolve
        // `get_by_id` to the inherent method and leave the trait impl untested.
        let plugin_repository = InMemoryPluginRepository::new();
        let plugin = make_plugin("test-plugin", "test-path");

        PluginRepositoryTrait::add(&plugin_repository, plugin.clone())
            .await
            .unwrap();
        assert_eq!(
            <InMemoryPluginRepository as PluginRepositoryTrait>::get_by_id(
                &plugin_repository,
                "test-plugin"
            )
            .await
            .unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_get_nonexistent_plugin() {
        let plugin_repository = InMemoryPluginRepository::new();

        let result = plugin_repository.get_by_id("test-plugin").await;
        assert!(matches!(result, Ok(None)));

        let trait_result = <InMemoryPluginRepository as PluginRepositoryTrait>::get_by_id(
            &plugin_repository,
            "test-plugin",
        )
        .await;
        assert!(matches!(trait_result, Ok(None)));
    }

    #[test]
    fn test_try_from() {
        let config = PluginFileConfig {
            id: "test-plugin".to_string(),
            path: "test-path".to_string(),
        };
        let result = PluginModel::try_from(config);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), make_plugin("test-plugin", "test-path"));
    }

    #[tokio::test]
    async fn test_get_by_id() {
        let plugin_repository = InMemoryPluginRepository::new();
        let plugin = make_plugin("test-plugin", "test-path");

        plugin_repository.add(plugin.clone()).await.unwrap();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(plugin)
        );
    }

    #[tokio::test]
    async fn test_add_replaces_existing_plugin() {
        let plugin_repository = InMemoryPluginRepository::new();

        plugin_repository
            .add(make_plugin("test-plugin", "old-path"))
            .await
            .unwrap();
        plugin_repository
            .add(make_plugin("test-plugin", "new-path"))
            .await
            .unwrap();

        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            Some(make_plugin("test-plugin", "new-path"))
        );
    }

    #[tokio::test]
    async fn test_default_is_empty() {
        let plugin_repository = InMemoryPluginRepository::default();
        assert_eq!(
            plugin_repository.get_by_id("test-plugin").await.unwrap(),
            None
        );
    }
}