openzeppelin_relayer/models/
secret_string.rs

//! SecretString - A container for sensitive string data
//!
//! This module provides a secure string implementation that protects sensitive
//! data in memory and prevents it from being accidentally exposed through logs,
//! serialization, or debug output.
//!
//! The `SecretString` type wraps a `SecretVec<u8>` (guarded by a `Mutex`) and
//! provides methods for securely handling string data, including zeroizing the
//! memory when the string is dropped.
10use std::{fmt, sync::Mutex};
11
12use secrets::SecretVec;
13use serde::{Deserialize, Serialize};
14use zeroize::Zeroizing;
15
16pub struct SecretString(Mutex<SecretVec<u8>>);
17
18impl Clone for SecretString {
19    fn clone(&self) -> Self {
20        let secret_vec = self.with_secret_vec(|secret_vec| secret_vec.clone());
21        Self(Mutex::new(secret_vec))
22    }
23}
24
25impl SecretString {
26    /// Creates a new SecretString from a regular string
27    ///
28    /// The input string's content is copied into secure memory and protected.
29    pub fn new(s: &str) -> Self {
30        let bytes = Zeroizing::new(s.as_bytes().to_vec());
31        let secret_vec = SecretVec::new(bytes.len(), |buffer| {
32            buffer.copy_from_slice(&bytes);
33        });
34        Self(Mutex::new(secret_vec))
35    }
36
37    /// Access the SecretVec with a provided function
38    ///
39    /// This is a private helper method to safely access the locked SecretVec
40    fn with_secret_vec<F, R>(&self, f: F) -> R
41    where
42        F: FnOnce(&SecretVec<u8>) -> R,
43    {
44        let guard = match self.0.lock() {
45            Ok(guard) => guard,
46            Err(poisoned) => poisoned.into_inner(),
47        };
48
49        f(&guard)
50    }
51
52    /// Access the secret string content with a provided function
53    ///
54    /// This method allows temporary access to the string content
55    /// without creating a copy of the string.
56    pub fn as_str<F, R>(&self, f: F) -> R
57    where
58        F: FnOnce(&str) -> R,
59    {
60        self.with_secret_vec(|secret_vec| {
61            let bytes = secret_vec.borrow();
62            let s = unsafe { std::str::from_utf8_unchecked(&bytes) };
63            f(s)
64        })
65    }
66
67    /// Create a temporary copy of the string content
68    ///
69    /// Returns a zeroizing string that will be securely erased when dropped.
70    /// Only use this when absolutely necessary as it creates a copy of the secret.
71    pub fn to_str(&self) -> Zeroizing<String> {
72        self.with_secret_vec(|secret_vec| {
73            let bytes = secret_vec.borrow();
74            let s = unsafe { std::str::from_utf8_unchecked(&bytes) };
75            Zeroizing::new(s.to_string())
76        })
77    }
78
79    /// Check if the secret string is empty
80    ///
81    /// Returns true if the string contains no bytes.
82    pub fn is_empty(&self) -> bool {
83        self.with_secret_vec(|secret_vec| secret_vec.is_empty())
84    }
85
86    /// Check if the secret string meets a minimum length requirement
87    ///
88    /// Returns true if the string has at least the specified length.
89    pub fn has_minimum_length(&self, min_length: usize) -> bool {
90        self.with_secret_vec(|secret_vec| {
91            let bytes = secret_vec.borrow();
92            bytes.len() >= min_length
93        })
94    }
95}
96
97impl Serialize for SecretString {
98    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
99    where
100        S: serde::Serializer,
101    {
102        serializer.serialize_str("REDACTED")
103    }
104}
105
106impl<'de> Deserialize<'de> for SecretString {
107    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
108    where
109        D: serde::Deserializer<'de>,
110    {
111        let s = Zeroizing::new(String::deserialize(deserializer)?);
112
113        Ok(SecretString::new(&s))
114    }
115}
116
117impl PartialEq for SecretString {
118    fn eq(&self, other: &Self) -> bool {
119        self.with_secret_vec(|self_vec| {
120            other.with_secret_vec(|other_vec| {
121                let self_bytes = self_vec.borrow();
122                let other_bytes = other_vec.borrow();
123
124                self_bytes.len() == other_bytes.len()
125                    && subtle::ConstantTimeEq::ct_eq(&*self_bytes, &*other_bytes).into()
126            })
127        })
128    }
129}
130
131impl fmt::Debug for SecretString {
132    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
133        write!(f, "SecretString(REDACTED)")
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    fn test_new_creates_valid_secret_string() {
        let secret = SecretString::new("test_secret_value");

        secret.as_str(|s| {
            assert_eq!(s, "test_secret_value");
        });
    }

    #[test]
    fn test_empty_string_is_handled_correctly() {
        let empty = SecretString::new("");

        assert!(empty.is_empty());

        empty.as_str(|s| {
            assert_eq!(s, "");
        });
    }

    #[test]
    fn test_to_str_creates_correct_zeroizing_copy() {
        let secret = SecretString::new("temporary_copy");

        let copy = secret.to_str();

        assert_eq!(&*copy, "temporary_copy");
    }

    #[test]
    fn test_is_empty_returns_correct_value() {
        let empty = SecretString::new("");
        let non_empty = SecretString::new("not empty");

        assert!(empty.is_empty());
        assert!(!non_empty.is_empty());
    }

    #[test]
    fn test_serialization_redacts_content() {
        let secret = SecretString::new("should_not_appear_in_serialized_form");

        let serialized = serde_json::to_string(&secret).unwrap();

        assert_eq!(serialized, "\"REDACTED\"");
        assert!(!serialized.contains("should_not_appear_in_serialized_form"));
    }

    #[test]
    fn test_deserialization_creates_valid_secret_string() {
        let json_str = "\"deserialized_secret\"";

        let deserialized: SecretString = serde_json::from_str(json_str).unwrap();

        deserialized.as_str(|s| {
            assert_eq!(s, "deserialized_secret");
        });
    }

    #[test]
    fn test_equality_comparison_works_correctly() {
        let secret1 = SecretString::new("same_value");
        let secret2 = SecretString::new("same_value");
        let secret3 = SecretString::new("different_value");
        // Same length as `secret1`, different content.
        let secret4 = SecretString::new("same_valuX");

        assert_eq!(secret1, secret2);
        assert_eq!(secret2, secret1);
        assert_ne!(secret1, secret3);
        assert_ne!(secret1, secret4);
    }

    #[test]
    fn test_debug_output_redacts_content() {
        let secret = SecretString::new("should_not_appear_in_debug");

        let debug_str = format!("{:?}", secret);

        assert_eq!(debug_str, "SecretString(REDACTED)");
        assert!(!debug_str.contains("should_not_appear_in_debug"));
    }

    #[test]
    fn test_thread_safety() {
        // Share ONE instance so every thread contends on the same inner lock. Cloning per
        // thread would give each thread its own Mutex and never exercise shared access.
        let secret = Arc::new(SecretString::new("shared_across_threads"));
        let num_threads = 10;
        let barrier = Arc::new(Barrier::new(num_threads));

        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let thread_secret = Arc::clone(&secret);
                let thread_barrier = Arc::clone(&barrier);

                thread::spawn(move || {
                    // Wait for all threads to be ready so accesses overlap.
                    thread_barrier.wait();

                    thread_secret.as_str(|s| {
                        assert_eq!(s, "shared_across_threads");
                    });
                    assert!(!thread_secret.is_empty());
                    let copy = thread_secret.to_str();
                    assert_eq!(&*copy, "shared_across_threads");

                    // Cloning while other threads hold the same lock must also work.
                    let cloned = SecretString::clone(&thread_secret);
                    cloned.as_str(|s| {
                        assert_eq!(s, "shared_across_threads");
                    });

                    // Return the thread index to verify every thread ran.
                    i
                })
            })
            .collect();

        let mut completed_threads: Vec<_> = handles
            .into_iter()
            .map(|handle| handle.join().unwrap())
            .collect();

        completed_threads.sort();
        assert_eq!(completed_threads, (0..num_threads).collect::<Vec<_>>());
    }

    #[test]
    fn test_unicode_handling() {
        let unicode_string = "こんにちは世界!";
        let secret = SecretString::new(unicode_string);

        secret.as_str(|s| {
            assert_eq!(s, unicode_string);
            assert_eq!(s.chars().count(), 8); // 7 Unicode characters + 1 ASCII
        });
    }

    #[test]
    fn test_special_characters_handling() {
        let special_chars = "!@#$%^&*()_+{}|:<>?~`-=[]\\;',./";
        let secret = SecretString::new(special_chars);

        secret.as_str(|s| {
            assert_eq!(s, special_chars);
        });
    }

    #[test]
    fn test_very_long_string() {
        // Create a long string (100,000 characters)
        let long_string = "a".repeat(100_000);
        let secret = SecretString::new(&long_string);

        secret.as_str(|s| {
            assert_eq!(s.len(), 100_000);
            assert_eq!(s, long_string);
        });

        assert_eq!(secret.0.lock().unwrap().len(), 100_000);
    }

    #[test]
    fn test_has_minimum_length() {
        let empty = SecretString::new("");
        let short = SecretString::new("abc");
        let medium = SecretString::new("abcdefghij"); // 10 bytes
        let long = SecretString::new("abcdefghijklmnopqrst"); // 20 bytes

        // Minimum length 0 is always satisfied.
        assert!(empty.has_minimum_length(0));
        assert!(short.has_minimum_length(0));
        assert!(medium.has_minimum_length(0));
        assert!(long.has_minimum_length(0));

        // Minimum length 1 rejects only the empty string.
        assert!(!empty.has_minimum_length(1));
        assert!(short.has_minimum_length(1));
        assert!(medium.has_minimum_length(1));
        assert!(long.has_minimum_length(1));

        // Exact length matches.
        assert!(short.has_minimum_length(3));
        assert!(medium.has_minimum_length(10));
        assert!(long.has_minimum_length(20));

        // One past the length.
        assert!(!short.has_minimum_length(4));
        assert!(!medium.has_minimum_length(11));
        assert!(!long.has_minimum_length(21));

        // Significantly larger minimum length.
        assert!(!short.has_minimum_length(100));
        assert!(!medium.has_minimum_length(100));
        assert!(!long.has_minimum_length(100));

        // Lengths are measured in UTF-8 bytes, not characters.
        let multibyte = SecretString::new("こんにちは"); // 5 characters, 15 bytes
        assert!(multibyte.has_minimum_length(15));
        assert!(!multibyte.has_minimum_length(16));
    }
}