openzeppelin_relayer/repositories/
transaction_counter.rs1use dashmap::DashMap;
8use serde::Serialize;
9use thiserror::Error;
10
11#[cfg(test)]
12use mockall::automock;
13
14#[derive(Debug, Default, Clone)]
15pub struct InMemoryTransactionCounter {
16 store: DashMap<(String, String), u64>, }
18
19impl InMemoryTransactionCounter {
20 pub fn new() -> Self {
21 Self {
22 store: DashMap::new(),
23 }
24 }
25}
26
27#[derive(Error, Debug, Serialize)]
28pub enum TransactionCounterError {
29 #[error("No sequence found for relayer {relayer_id} and address {address}")]
30 SequenceNotFound { relayer_id: String, address: String },
31 #[error("Counter not found for {0}")]
32 NotFound(String),
33}
34
35#[allow(dead_code)]
36#[cfg_attr(test, automock)]
37pub trait TransactionCounterTrait {
38 fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, TransactionCounterError>;
39
40 fn get_and_increment(
41 &self,
42 relayer_id: &str,
43 address: &str,
44 ) -> Result<u64, TransactionCounterError>;
45
46 fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, TransactionCounterError>;
47
48 fn set(
49 &self,
50 relayer_id: &str,
51 address: &str,
52 value: u64,
53 ) -> Result<(), TransactionCounterError>;
54}
55
56impl TransactionCounterTrait for InMemoryTransactionCounter {
57 fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, TransactionCounterError> {
58 Ok(self
59 .store
60 .get(&(relayer_id.to_string(), address.to_string()))
61 .map(|n| *n))
62 }
63
64 fn get_and_increment(
65 &self,
66 relayer_id: &str,
67 address: &str,
68 ) -> Result<u64, TransactionCounterError> {
69 let mut entry = self
70 .store
71 .entry((relayer_id.to_string(), address.to_string()))
72 .or_insert(0);
73 let current = *entry;
74 *entry += 1;
75 Ok(current)
76 }
77
78 fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, TransactionCounterError> {
79 let mut entry = self
80 .store
81 .get_mut(&(relayer_id.to_string(), address.to_string()))
82 .ok_or_else(|| {
83 TransactionCounterError::NotFound(format!("Counter not found for {}", address))
84 })?;
85 if *entry > 0 {
86 *entry -= 1;
87 }
88 Ok(*entry)
89 }
90
91 fn set(
92 &self,
93 relayer_id: &str,
94 address: &str,
95 value: u64,
96 ) -> Result<(), TransactionCounterError> {
97 self.store
98 .insert((relayer_id.to_string(), address.to_string()), value);
99 Ok(())
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decrement_not_found() {
        let store = InMemoryTransactionCounter::new();
        let result = store.decrement("nonexistent", "0x1234");
        assert!(matches!(result, Err(TransactionCounterError::NotFound(_))));
    }

    /// A pair that was never set starts at 0: the first call returns 0 and stores 1.
    #[test]
    fn test_get_and_increment_initialises_missing_counter_to_zero() {
        let store = InMemoryTransactionCounter::new();
        assert_eq!(store.get_and_increment("relayer_1", "0x1234").unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(1));
    }

    /// Decrementing a counter that is already 0 leaves it at 0 instead of underflowing.
    #[test]
    fn test_decrement_saturates_at_zero() {
        let store = InMemoryTransactionCounter::new();
        store.set("relayer_1", "0x1234", 0).unwrap();
        assert_eq!(store.decrement("relayer_1", "0x1234").unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(0));
    }

    #[test]
    fn test_nonce_store() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        assert_eq!(store.get(relayer_id, address).unwrap(), None);

        store.set(relayer_id, address, 100).unwrap();
        assert_eq!(store.get(relayer_id, address).unwrap(), Some(100));

        assert_eq!(store.get_and_increment(relayer_id, address).unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).unwrap(), Some(101));

        assert_eq!(store.decrement(relayer_id, address).unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).unwrap(), Some(100));
    }

    #[test]
    fn test_multiple_relayers() {
        let store = InMemoryTransactionCounter::new();

        store.set("relayer_1", "0x1234", 100).unwrap();
        store.set("relayer_1", "0x5678", 200).unwrap();
        store.set("relayer_2", "0x1234", 300).unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(100));
        assert_eq!(store.get("relayer_1", "0x5678").unwrap(), Some(200));
        assert_eq!(store.get("relayer_2", "0x1234").unwrap(), Some(300));

        assert_eq!(store.get_and_increment("relayer_1", "0x1234").unwrap(), 100);
        assert_eq!(store.get_and_increment("relayer_1", "0x1234").unwrap(), 101);
        assert_eq!(store.get_and_increment("relayer_1", "0x5678").unwrap(), 200);
        assert_eq!(store.get_and_increment("relayer_1", "0x5678").unwrap(), 201);
        assert_eq!(store.get("relayer_2", "0x1234").unwrap(), Some(300));
    }
}