1use async_trait::async_trait;
24use core::fmt;
25use log::debug;
26use once_cell::sync::Lazy;
27use serde::Serialize;
28use std::collections::HashMap;
29use std::hash::Hash;
30use std::sync::Arc;
31use std::time::{Duration, Instant};
32use thiserror::Error;
33use tokio::sync::RwLock;
34use vaultrs::{
35 auth::approle::login,
36 client::{VaultClient, VaultClientSettingsBuilder},
37 kv2, transit,
38};
39use zeroize::{Zeroize, ZeroizeOnDrop};
40
/// Errors produced by the Vault integration in this module.
///
/// Every variant carries a free-form message; the `Display` text generated by
/// `thiserror` prefixes that message with the category given in the variant's
/// `#[error(...)]` attribute. The enum is also `Serialize`, so it can be returned
/// in serialized responses.
#[derive(Error, Debug, Serialize)]
pub enum VaultError {
    /// An error returned by the Vault client while reading a KV v2 secret
    /// (in this module: any failure of `kv2::read`, including non-success HTTP status).
    #[error("Vault client error: {0}")]
    ClientError(String),

    /// The secret was read but did not contain a string `value` field.
    #[error("Secret not found: {0}")]
    SecretNotFound(String),

    /// The AppRole login request failed.
    #[error("Authentication failed: {0}")]
    AuthenticationFailed(String),

    /// Building Vault client settings or constructing a `VaultClient` failed.
    #[error("Configuration error: {0}")]
    ConfigError(String),

    /// The transit `sign` request failed.
    #[error("Signing error: {0}")]
    SigningError(String),
}
58
59#[derive(Clone, Debug, PartialEq, Eq, Hash, Zeroize, ZeroizeOnDrop)]
61struct VaultCacheKey {
62 address: String,
63 role_id: String,
64 namespace: Option<String>,
65}
66
67impl fmt::Display for VaultCacheKey {
68 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69 write!(
70 f,
71 "{}|{}|{}",
72 self.address,
73 self.role_id,
74 self.namespace.as_deref().unwrap_or("")
75 )
76 }
77}
78
79struct TokenCache {
80 client: Arc<VaultClient>,
81 expiry: Instant,
82}
83
84static TOKEN_CACHE: Lazy<RwLock<HashMap<VaultCacheKey, TokenCache>>> =
86 Lazy::new(|| RwLock::new(HashMap::new()));
87
88#[cfg(test)]
89use mockall::automock;
90
91use crate::models::SecretString;
92use crate::utils::base64_encode;
93
94#[derive(Clone)]
95pub struct VaultConfig {
96 pub address: String,
97 pub namespace: Option<String>,
98 pub role_id: SecretString,
99 pub secret_id: SecretString,
100 pub mount_path: String,
101 pub token_ttl: Option<u64>,
103}
104
105impl VaultConfig {
106 pub fn new(
107 address: String,
108 role_id: SecretString,
109 secret_id: SecretString,
110 namespace: Option<String>,
111 mount_path: String,
112 token_ttl: Option<u64>,
113 ) -> Self {
114 Self {
115 address,
116 role_id,
117 secret_id,
118 namespace,
119 mount_path,
120 token_ttl,
121 }
122 }
123
124 fn cache_key(&self) -> VaultCacheKey {
125 VaultCacheKey {
126 address: self.address.clone(),
127 role_id: self.role_id.to_str().to_string(),
128 namespace: self.namespace.clone(),
129 }
130 }
131}
132
/// Vault operations the rest of the application depends on.
///
/// Implementations must be `Send + Sync` so they can be shared across tasks. In test
/// builds `mockall` generates a mock implementation of this trait (`automock`).
#[async_trait]
#[cfg_attr(test, automock)]
pub trait VaultServiceTrait: Send + Sync {
    /// Fetches the secret stored under `key_name` and returns it as a string.
    async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError>;
    /// Signs `message` with the key named `key_name` and returns the signature string.
    async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError>;
}

/// Vault-backed implementation of [`VaultServiceTrait`].
///
/// Authenticated clients are cached process-wide in `TOKEN_CACHE`, so cloning or
/// recreating a service with an equivalent config reuses the same client.
#[derive(Clone)]
pub struct VaultService {
    /// Connection and AppRole authentication settings.
    pub config: VaultConfig,
}
144
145impl VaultService {
146 pub fn new(config: VaultConfig) -> Self {
147 Self { config }
148 }
149
150 async fn get_client(&self) -> Result<Arc<VaultClient>, VaultError> {
152 let cache_key = self.config.cache_key();
153
154 {
156 let cache = TOKEN_CACHE.read().await;
157 if let Some(cached) = cache.get(&cache_key) {
158 if Instant::now() < cached.expiry {
159 return Ok(Arc::clone(&cached.client));
160 }
161 }
162 }
163
164 let mut cache = TOKEN_CACHE.write().await;
166 if let Some(cached) = cache.get(&cache_key) {
168 if Instant::now() < cached.expiry {
169 return Ok(Arc::clone(&cached.client));
170 }
171 }
172
173 let client = self.create_authenticated_client().await?;
175
176 let ttl = Duration::from_secs(self.config.token_ttl.unwrap_or(45 * 60));
178
179 cache.insert(
181 cache_key,
182 TokenCache {
183 client: client.clone(),
184 expiry: Instant::now() + ttl,
185 },
186 );
187
188 Ok(client)
189 }
190
191 async fn create_authenticated_client(&self) -> Result<Arc<VaultClient>, VaultError> {
193 let mut auth_settings_builder = VaultClientSettingsBuilder::default();
194 let address = &self.config.address;
195 auth_settings_builder.address(address).verify(true);
196
197 if let Some(namespace) = &self.config.namespace {
198 auth_settings_builder.namespace(Some(namespace.clone()));
199 }
200
201 let auth_settings = auth_settings_builder.build().map_err(|e| {
202 VaultError::ConfigError(format!("Failed to build Vault client settings: {}", e))
203 })?;
204
205 let client = VaultClient::new(auth_settings).map_err(|e| {
206 VaultError::ConfigError(format!("Failed to create Vault client: {}", e))
207 })?;
208
209 let token = login(
210 &client,
211 "approle",
212 &self.config.role_id.to_str(),
213 &self.config.secret_id.to_str(),
214 )
215 .await
216 .map_err(|e| VaultError::AuthenticationFailed(e.to_string()))?;
217
218 let mut transit_settings_builder = VaultClientSettingsBuilder::default();
219
220 transit_settings_builder
221 .address(self.config.address.clone())
222 .token(token.client_token.clone())
223 .verify(true);
224
225 if let Some(namespace) = &self.config.namespace {
226 transit_settings_builder.namespace(Some(namespace.clone()));
227 }
228
229 let transit_settings = transit_settings_builder.build().map_err(|e| {
230 VaultError::ConfigError(format!("Failed to build Vault client settings: {}", e))
231 })?;
232
233 let client = Arc::new(VaultClient::new(transit_settings).map_err(|e| {
234 VaultError::ConfigError(format!(
235 "Failed to create authenticated Vault client: {}",
236 e
237 ))
238 })?);
239
240 Ok(client)
241 }
242}
243
244#[async_trait]
245impl VaultServiceTrait for VaultService {
246 async fn retrieve_secret(&self, key_name: &str) -> Result<String, VaultError> {
247 let client = self.get_client().await?;
248
249 let secret: serde_json::Value = kv2::read(&*client, &self.config.mount_path, key_name)
250 .await
251 .map_err(|e| VaultError::ClientError(e.to_string()))?;
252
253 let value = secret["value"]
254 .as_str()
255 .ok_or_else(|| {
256 VaultError::SecretNotFound(format!("Secret value invalid for key: {}", key_name))
257 })?
258 .to_string();
259
260 Ok(value)
261 }
262
263 async fn sign(&self, key_name: &str, message: &[u8]) -> Result<String, VaultError> {
264 let client = self.get_client().await?;
265
266 let vault_signature = transit::data::sign(
267 &*client,
268 &self.config.mount_path,
269 key_name,
270 &base64_encode(message),
271 None,
272 )
273 .await
274 .map_err(|e| VaultError::SigningError(format!("Failed to sign with Vault: {}", e)))?;
275
276 let vault_signature_str = &vault_signature.signature;
277
278 debug!("vault_signature_str: {}", vault_signature_str);
279
280 Ok(vault_signature_str.clone())
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{body_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Credentials the mocked AppRole login accepts, and the token it hands back.
    const ROLE_ID: &str = "test-role-id";
    const SECRET_ID: &str = "test-secret-id";
    const TOKEN: &str = "test-token";

    /// Config aimed at `server` with the given credentials, no namespace, the
    /// `test-mount` mount and a 60-second token TTL.
    fn config_for(server: &MockServer, role_id: &str, secret_id: &str) -> VaultConfig {
        VaultConfig::new(
            server.uri(),
            SecretString::new(role_id),
            SecretString::new(secret_id),
            None,
            "test-mount".to_string(),
            Some(60),
        )
    }

    /// Starts a mock server whose AppRole login accepts `ROLE_ID`/`SECRET_ID` and
    /// returns `TOKEN`.
    async fn server_with_login() -> MockServer {
        let server = MockServer::start().await;
        setup_mock_approle_login(&server, ROLE_ID, SECRET_ID, TOKEN).await;
        server
    }

    /// 200 response for a KV v2 read whose secret data is `{"value": value}`.
    fn kv2_secret_response(value: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "data": {
                    "value": value
                },
                "metadata": {
                    "created_time": "2024-01-01T00:00:00Z",
                    "deletion_time": "",
                    "destroyed": false,
                    "version": 1
                }
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        }))
    }

    /// Serves `response` for a token-authenticated GET of `test-mount/data/my-secret`.
    async fn mount_secret_read(server: &MockServer, response: ResponseTemplate) {
        Mock::given(method("GET"))
            .and(path("/v1/test-mount/data/my-secret"))
            .and(header("X-Vault-Token", TOKEN))
            .respond_with(response)
            .mount(server)
            .await;
    }

    /// Serves `response` for a token-authenticated transit sign of `message` with
    /// `my-signing-key`.
    async fn mount_sign(server: &MockServer, message: &[u8], response: ResponseTemplate) {
        Mock::given(method("POST"))
            .and(path("/v1/test-mount/sign/my-signing-key"))
            .and(header("X-Vault-Token", TOKEN))
            .and(body_json(json!({
                "input": base64_encode(message)
            })))
            .respond_with(response)
            .mount(server)
            .await;
    }

    #[test]
    fn test_vault_config_new() {
        let config = VaultConfig::new(
            "https://vault.example.com".to_string(),
            SecretString::new("test-role-id"),
            SecretString::new("test-secret-id"),
            Some("test-namespace".to_string()),
            "test-mount-path".to_string(),
            Some(60),
        );

        assert_eq!(config.address, "https://vault.example.com");
        assert_eq!(config.role_id.to_str().as_str(), "test-role-id");
        assert_eq!(config.secret_id.to_str().as_str(), "test-secret-id");
        assert_eq!(config.namespace.as_deref(), Some("test-namespace"));
        assert_eq!(config.mount_path, "test-mount-path");
        assert_eq!(config.token_ttl, Some(60));
    }

    #[test]
    fn test_vault_cache_key() {
        let config_at = |address: &str, mount_path: &str| VaultConfig {
            address: address.to_string(),
            namespace: Some("namespace1".to_string()),
            role_id: SecretString::new("role1"),
            secret_id: SecretString::new("secret1"),
            mount_path: mount_path.to_string(),
            token_ttl: None,
        };

        let base = config_at("https://vault1.example.com", "transit");
        let other_mount = config_at("https://vault1.example.com", "different-mount");
        let other_address = config_at("https://vault2.example.com", "transit");

        assert_eq!(base.cache_key(), base.cache_key());
        assert_eq!(base.cache_key(), other_mount.cache_key());
        assert_ne!(base.cache_key(), other_address.cache_key());
    }

    #[test]
    fn test_vault_cache_key_display() {
        let key = |namespace: Option<&str>| VaultCacheKey {
            address: "https://vault.example.com".to_string(),
            role_id: "role-123".to_string(),
            namespace: namespace.map(String::from),
        };

        assert_eq!(
            key(Some("my-namespace")).to_string(),
            "https://vault.example.com|role-123|my-namespace"
        );
        assert_eq!(key(None).to_string(), "https://vault.example.com|role-123|");
    }

    /// Mounts an AppRole login endpoint that accepts exactly `role_id`/`secret_id`
    /// and responds with `token` as the client token.
    async fn setup_mock_approle_login(
        mock_server: &MockServer,
        role_id: &str,
        secret_id: &str,
        token: &str,
    ) {
        let expected_request = json!({
            "role_id": role_id,
            "secret_id": secret_id
        });
        let login_response = json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": null,
            "wrap_info": null,
            "warnings": null,
            "auth": {
                "client_token": token,
                "accessor": "test-accessor",
                "policies": ["default"],
                "token_policies": ["default"],
                "metadata": {
                    "role_name": "test-role"
                },
                "lease_duration": 3600,
                "renewable": true,
                "entity_id": "test-entity-id",
                "token_type": "service",
                "orphan": true
            }
        });

        Mock::given(method("POST"))
            .and(path("/v1/auth/approle/login"))
            .and(body_json(expected_request))
            .respond_with(ResponseTemplate::new(200).set_body_json(login_response))
            .mount(mock_server)
            .await;
    }

    #[tokio::test]
    async fn test_vault_service_auth_failure() {
        let server = server_with_login().await;
        mount_secret_read(&server, kv2_secret_response("super-secret-value")).await;

        let service = VaultService::new(config_for(
            &server,
            "test-role-id-fake",
            "test-secret-id-fake",
        ));
        let err = service.retrieve_secret("my-secret").await.unwrap_err();

        assert!(matches!(err, VaultError::AuthenticationFailed(_)));
        assert!(err.to_string().contains("An error occurred with the request"));
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_success() {
        let server = server_with_login().await;
        mount_secret_read(&server, kv2_secret_response("super-secret-value")).await;

        let service = VaultService::new(config_for(&server, ROLE_ID, SECRET_ID));
        let secret = service.retrieve_secret("my-secret").await.unwrap();

        assert_eq!(secret, "super-secret-value");
    }

    #[tokio::test]
    async fn test_vault_service_sign_success() {
        let server = server_with_login().await;
        let message: &[u8] = b"hello world";

        let sign_response = ResponseTemplate::new(200).set_body_json(json!({
            "request_id": "test-request-id",
            "lease_id": "",
            "renewable": false,
            "lease_duration": 0,
            "data": {
                "signature": "vault:v1:fake-signature",
                "key_version": 1
            },
            "wrap_info": null,
            "warnings": null,
            "auth": null
        }));
        mount_sign(&server, message, sign_response).await;

        let service = VaultService::new(config_for(&server, ROLE_ID, SECRET_ID));
        let signature = service.sign("my-signing-key", message).await.unwrap();

        assert_eq!(signature, "vault:v1:fake-signature");
    }

    #[tokio::test]
    async fn test_vault_service_retrieve_secret_failure() {
        let server = server_with_login().await;
        let not_found = ResponseTemplate::new(404).set_body_json(json!({
            "errors": ["secret not found:"]
        }));
        mount_secret_read(&server, not_found).await;

        let service = VaultService::new(config_for(&server, ROLE_ID, SECRET_ID));
        let err = service.retrieve_secret("my-secret").await.unwrap_err();

        assert!(matches!(err, VaultError::ClientError(_)));
        assert!(err
            .to_string()
            .contains("The Vault server returned an error (status code 404)"));
    }

    #[tokio::test]
    async fn test_vault_service_sign_failure() {
        let server = server_with_login().await;
        let message: &[u8] = b"hello world";

        let bad_request = ResponseTemplate::new(400).set_body_json(json!({
            "errors": ["1 error occurred:\n\t* signing key not found"]
        }));
        mount_sign(&server, message, bad_request).await;

        let service = VaultService::new(config_for(&server, ROLE_ID, SECRET_ID));
        let err = service.sign("my-signing-key", message).await.unwrap_err();

        assert!(matches!(err, VaultError::SigningError(_)));
        assert!(err.to_string().contains("Failed to sign with Vault"));
    }
}