openzeppelin_relayer/repositories/
transaction_counter.rs

1//! This module provides an in-memory implementation of a transaction counter.
2//!
3//! The `InMemoryTransactionCounter` struct is used to track and manage transaction nonces
4//! for different relayers and addresses. It supports operations to get, increment, decrement,
5//! and set nonce values. This implementation uses a `DashMap` for concurrent access and
6//! modification of the nonce values.
7use dashmap::DashMap;
8use serde::Serialize;
9use thiserror::Error;
10
11#[cfg(test)]
12use mockall::automock;
13
14#[derive(Debug, Default, Clone)]
15pub struct InMemoryTransactionCounter {
16    store: DashMap<(String, String), u64>, // (relayer_id, address) -> nonce/sequence
17}
18
19impl InMemoryTransactionCounter {
20    pub fn new() -> Self {
21        Self {
22            store: DashMap::new(),
23        }
24    }
25}
26
27#[derive(Error, Debug, Serialize)]
28pub enum TransactionCounterError {
29    #[error("No sequence found for relayer {relayer_id} and address {address}")]
30    SequenceNotFound { relayer_id: String, address: String },
31    #[error("Counter not found for {0}")]
32    NotFound(String),
33}
34
35#[allow(dead_code)]
36#[cfg_attr(test, automock)]
37pub trait TransactionCounterTrait {
38    fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, TransactionCounterError>;
39
40    fn get_and_increment(
41        &self,
42        relayer_id: &str,
43        address: &str,
44    ) -> Result<u64, TransactionCounterError>;
45
46    fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, TransactionCounterError>;
47
48    fn set(
49        &self,
50        relayer_id: &str,
51        address: &str,
52        value: u64,
53    ) -> Result<(), TransactionCounterError>;
54}
55
56impl TransactionCounterTrait for InMemoryTransactionCounter {
57    fn get(&self, relayer_id: &str, address: &str) -> Result<Option<u64>, TransactionCounterError> {
58        Ok(self
59            .store
60            .get(&(relayer_id.to_string(), address.to_string()))
61            .map(|n| *n))
62    }
63
64    fn get_and_increment(
65        &self,
66        relayer_id: &str,
67        address: &str,
68    ) -> Result<u64, TransactionCounterError> {
69        let mut entry = self
70            .store
71            .entry((relayer_id.to_string(), address.to_string()))
72            .or_insert(0);
73        let current = *entry;
74        *entry += 1;
75        Ok(current)
76    }
77
78    fn decrement(&self, relayer_id: &str, address: &str) -> Result<u64, TransactionCounterError> {
79        let mut entry = self
80            .store
81            .get_mut(&(relayer_id.to_string(), address.to_string()))
82            .ok_or_else(|| {
83                TransactionCounterError::NotFound(format!("Counter not found for {}", address))
84            })?;
85        if *entry > 0 {
86            *entry -= 1;
87        }
88        Ok(*entry)
89    }
90
91    fn set(
92        &self,
93        relayer_id: &str,
94        address: &str,
95        value: u64,
96    ) -> Result<(), TransactionCounterError> {
97        self.store
98            .insert((relayer_id.to_string(), address.to_string()), value);
99        Ok(())
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Decrementing a counter that was never created reports `NotFound`.
    #[test]
    fn test_decrement_not_found() {
        let store = InMemoryTransactionCounter::new();
        let result = store.decrement("nonexistent", "0x1234");
        assert!(matches!(result, Err(TransactionCounterError::NotFound(_))));
    }

    /// Walks one counter through get / set / increment / decrement.
    #[test]
    fn test_nonce_store() {
        let store = InMemoryTransactionCounter::new();
        let relayer_id = "relayer_1";
        let address = "0x1234";

        // Initially should be None
        assert_eq!(store.get(relayer_id, address).unwrap(), None);

        // Set a value explicitly
        store.set(relayer_id, address, 100).unwrap();
        assert_eq!(store.get(relayer_id, address).unwrap(), Some(100));

        // Increment returns the previous value and stores the next one
        assert_eq!(store.get_and_increment(relayer_id, address).unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).unwrap(), Some(101));

        // Decrement returns the new value
        assert_eq!(store.decrement(relayer_id, address).unwrap(), 100);
        assert_eq!(store.get(relayer_id, address).unwrap(), Some(100));
    }

    /// Counters for different relayer/address pairs do not affect each other.
    #[test]
    fn test_multiple_relayers() {
        let store = InMemoryTransactionCounter::new();

        // Setup different relayer/address combinations
        store.set("relayer_1", "0x1234", 100).unwrap();
        store.set("relayer_1", "0x5678", 200).unwrap();
        store.set("relayer_2", "0x1234", 300).unwrap();

        // Verify independent counters
        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(100));
        assert_eq!(store.get("relayer_1", "0x5678").unwrap(), Some(200));
        assert_eq!(store.get("relayer_2", "0x1234").unwrap(), Some(300));

        // Verify independent increments
        assert_eq!(store.get_and_increment("relayer_1", "0x1234").unwrap(), 100);
        assert_eq!(store.get_and_increment("relayer_1", "0x1234").unwrap(), 101);
        assert_eq!(store.get_and_increment("relayer_1", "0x5678").unwrap(), 200);
        assert_eq!(store.get_and_increment("relayer_1", "0x5678").unwrap(), 201);
        assert_eq!(store.get("relayer_2", "0x1234").unwrap(), Some(300));
    }

    /// The first increment on an unknown pair creates the counter at zero.
    #[test]
    fn test_get_and_increment_initializes_missing_counter() {
        let store = InMemoryTransactionCounter::new();

        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), None);
        assert_eq!(store.get_and_increment("relayer_1", "0x1234").unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(1));
    }

    /// Decrementing a counter that is already zero leaves it at zero.
    #[test]
    fn test_decrement_saturates_at_zero() {
        let store = InMemoryTransactionCounter::new();
        store.set("relayer_1", "0x1234", 0).unwrap();

        assert_eq!(store.decrement("relayer_1", "0x1234").unwrap(), 0);
        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(0));
    }

    /// `set` replaces a previously stored value instead of accumulating.
    #[test]
    fn test_set_overwrites_existing_value() {
        let store = InMemoryTransactionCounter::new();
        store.set("relayer_1", "0x1234", 5).unwrap();
        store.set("relayer_1", "0x1234", 2).unwrap();

        assert_eq!(store.get("relayer_1", "0x1234").unwrap(), Some(2));
    }
}