openzeppelin_relayer/services/aws_kms/
mod.rs

//! # AWS KMS Service Module
//!
//! This module provides integration with AWS KMS for secure key management
//! and cryptographic operations such as public key retrieval and message signing.
//!
//! Currently only EVM is supported.
//!
//! ## Features
//!
//! - Authentication through the AWS SDK's shared configuration and credential providers
//! - Public key retrieval from KMS
//! - Message signing via KMS
//!
//! ## Architecture
//!
//! ```text
//! AwsKmsService (implements AwsKmsEvmService)
//!   ├── Authentication (via AwsKmsClient)
//!   ├── Public Key Retrieval (via AwsKmsClient)
//!   └── Message Signing (via AwsKmsClient)
//! ```
//!
//! `AwsKmsService` is built on top of:
//!
//! ```text
//! AwsKmsClient (implements AwsKmsK256)
//!   ├── Authentication (via shared credentials)
//!   ├── Public Key Retrieval in DER Encoding
//!   └── Message Digest Signing in DER Encoding
//! ```
//!
//! `AwsKmsK256` is mocked with `mockall` for unit testing
//! and injected into `AwsKmsService`.
//!
32
33use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37    primitives::Blob,
38    types::{MessageType, SigningAlgorithmSpec},
39    Client,
40};
41use serde::Serialize;
42
43use crate::{
44    models::{Address, AwsKmsSignerConfig},
45    utils::{self, derive_ethereum_address_from_der, extract_public_key_from_der},
46};
47
48#[cfg(test)]
49use mockall::{automock, mock};
50
51#[derive(Clone, Debug, thiserror::Error, Serialize)]
52pub enum AwsKmsError {
53    #[error("AWS KMS response parse error: {0}")]
54    ParseError(String),
55    #[error("AWS KMS config error: {0}")]
56    ConfigError(String),
57    #[error("AWS KMS get error: {0}")]
58    GetError(String),
59    #[error("AWS KMS signing error: {0}")]
60    SignError(String),
61    #[error("AWS KMS permissions error: {0}")]
62    PermissionError(String),
63    #[error("AWS KMS public key error: {0}")]
64    RecoveryError(#[from] utils::Secp256k1Error),
65    #[error("AWS KMS conversion error: {0}")]
66    ConvertError(String),
67    #[error("AWS KMS Other error: {0}")]
68    Other(String),
69}
70
71pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
72
/// High-level EVM operations backed by a key stored in AWS KMS.
///
/// Under `cfg(test)`, `automock` generates `MockAwsKmsEvmService` for use in tests.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsEvmService: Send + Sync {
    /// Returns the EVM address derived from the configured public key.
    async fn get_evm_address(&self) -> AwsKmsResult<Address>;
    /// Signs a payload using the EVM signing scheme.
    /// Pre-hashes the message with keccak-256.
    ///
    /// The `AwsKmsService` implementation returns the 64-byte `r || s` signature
    /// followed by a single legacy `v` byte.
    async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
}
82
/// Low-level secp256k1 operations against AWS KMS that return raw DER-encoded data.
///
/// The explicit lifetimes on the methods are mirrored by the `mock!` definition
/// of `MockAwsKmsClient`; keep the two signatures in sync.
#[async_trait]
#[cfg_attr(test, automock)]
pub trait AwsKmsK256: Send + Sync {
    /// Fetches the DER-encoded public key from AWS KMS.
    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
    /// Signs a precomputed 32-byte digest using the EcdsaSha256 spec.
    /// Returns the DER-encoded signature.
    async fn sign_digest<'a, 'b>(
        &'a self,
        key_id: &'b str,
        digest: [u8; 32],
    ) -> AwsKmsResult<Vec<u8>>;
}
95
// Test-only mock of the concrete `AwsKmsClient`, generated as `MockAwsKmsClient`.
// Unlike the `automock`-generated `MockAwsKmsK256`, this mock also implements
// `Clone`, which `AwsKmsService<T>` requires of its client type.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }
    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    // Must match the `AwsKmsK256` trait signatures exactly, lifetimes included.
    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }

}
114
115#[derive(Debug, Clone)]
116pub struct AwsKmsClient {
117    inner: Client,
118}
119
120#[async_trait]
121impl AwsKmsK256 for AwsKmsClient {
122    async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
123        let get_output = self
124            .inner
125            .get_public_key()
126            .key_id(key_id)
127            .send()
128            .await
129            .map_err(|e| AwsKmsError::GetError(e.to_string()))?;
130
131        let der_pk_blob = get_output
132            .public_key
133            .ok_or(AwsKmsError::GetError(
134                "No public key blob found".to_string(),
135            ))?
136            .into_inner();
137
138        Ok(der_pk_blob)
139    }
140
141    async fn sign_digest<'a, 'b>(
142        &'a self,
143        key_id: &'b str,
144        digest: [u8; 32],
145    ) -> AwsKmsResult<Vec<u8>> {
146        // Sign the digest with the AWS KMS
147        let sign_result = self
148            .inner
149            .sign()
150            .key_id(key_id)
151            .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
152            .message_type(MessageType::Digest)
153            .message(Blob::new(digest))
154            .send()
155            .await;
156
157        // Process the result, extract DER signature
158        let der_signature = sign_result
159            .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
160            .signature
161            .ok_or(AwsKmsError::SignError(
162                "Signature not found in response".to_string(),
163            ))?
164            .into_inner();
165
166        Ok(der_signature)
167    }
168}
169
170#[derive(Debug, Clone)]
171pub struct AwsKmsService<T: AwsKmsK256 + Clone = AwsKmsClient> {
172    pub kms_key_id: String,
173    client: T,
174}
175
176impl AwsKmsService<AwsKmsClient> {
177    pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
178        let region_provider =
179            RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
180
181        let auth_config = aws_config::defaults(BehaviorVersion::latest())
182            .region(region_provider)
183            .load()
184            .await;
185        let client = AwsKmsClient {
186            inner: Client::new(&auth_config),
187        };
188
189        Ok(Self {
190            kms_key_id: config.key_id,
191            client,
192        })
193    }
194}
195
#[cfg(test)]
impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
    /// Builds a service around an injected client (typically a mock).
    ///
    /// Only `config.key_id` is used; the region is ignored.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
205
206impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
207    /// Signs a bytes with the private key stored in AWS KMS.
208    ///
209    /// Pre-hashes the message with keccak256.
210    pub async fn sign_bytes_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
211        // Create a digest of a message payload
212        let digest = keccak256(bytes).0;
213
214        // Sign the digest with the AWS KMS
215        // Process the result, extract DER signature
216        let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
217
218        // Parse DER into Secp256k1 format
219        let rs = k256::ecdsa::Signature::from_der(&der_signature)
220            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
221
222        let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
223
224        // Extract public key from AWS KMS and convert it to an uncompressed 64 pk
225        let pk = extract_public_key_from_der(&der_pk)
226            .map_err(|e| AwsKmsError::ConvertError(e.to_string()))?;
227
228        // Extract v value from the public key recovery
229        let v = utils::recover_public_key(&pk, &rs, bytes)?;
230
231        // Adjust v value for Ethereum legacy transaction.
232        let eth_v = 27 + v;
233
234        // Append `v` to a signature bytes
235        let mut sig_bytes = rs.to_vec();
236        sig_bytes.push(eth_v);
237
238        Ok(sig_bytes)
239    }
240}
241
242#[async_trait]
243impl<T: AwsKmsK256 + Clone> AwsKmsEvmService for AwsKmsService<T> {
244    async fn get_evm_address(&self) -> AwsKmsResult<Address> {
245        let der = self.client.get_der_public_key(&self.kms_key_id).await?;
246        let eth_address = derive_ethereum_address_from_der(&der)
247            .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
248        Ok(Address::Evm(eth_address))
249    }
250
251    async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
252        self.sign_bytes_evm(message).await
253    }
254}
255
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::{RecoveryId, Signature, SigningKey, VerifyingKey},
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Returns a mocked KMS client plus the signing key it simulates.
    ///
    /// Calls with key id `"test-key-id"` succeed (public key lookup and signing
    /// with the returned key). Any other key id fails with `GetError` / `SignError`.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Builds the service under test around `client`, configured with `key_id`.
    fn service_with_key_id(
        client: MockAwsKmsClient,
        key_id: &str,
    ) -> AwsKmsService<MockAwsKmsClient> {
        AwsKmsService::new_for_testing(
            client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        )
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = service_with_key_id(mock_client, "test-key-id");

        let evm_address = match kms.get_evm_address().await {
            Ok(Address::Evm(address)) => address,
            _ => panic!("expected an EVM address for a known key id"),
        };
        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();
        assert_eq!(expected_address, evm_address);
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key_id(mock_client, "invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (mock_client, key) = setup_mock_kms_client();
        let kms = service_with_key_id(mock_client, "test-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let signature = kms
            .sign_payload_evm(&message_eip)
            .await
            .expect("signing with a known key id should succeed");

        // 64 bytes of `r || s` followed by the legacy `v` byte.
        assert_eq!(signature.len(), 65);
        let v = signature[64];
        assert!(v == 27 || v == 28, "unexpected v value: {v}");

        // The signature must be in low-S form and recover to the signing key.
        let rs = Signature::from_slice(&signature[..64]).unwrap();
        assert!(rs.normalize_s().is_none(), "signature must be low-S");
        let recovery_id = RecoveryId::from_byte(v - 27).unwrap();
        let recovered =
            VerifyingKey::recover_from_prehash(&keccak256(&message_eip).0, &rs, recovery_id)
                .unwrap();
        assert_eq!(&recovered, key.verifying_key());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (mock_client, _) = setup_mock_kms_client();
        let kms = service_with_key_id(mock_client, "invalid-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }
}