// openzeppelin_relayer/services/aws_kms/mod.rs
use alloy::primitives::keccak256;
34use async_trait::async_trait;
35use aws_config::{meta::region::RegionProviderChain, BehaviorVersion, Region};
36use aws_sdk_kms::{
37 primitives::Blob,
38 types::{MessageType, SigningAlgorithmSpec},
39 Client,
40};
41use serde::Serialize;
42
43use crate::{
44 models::{Address, AwsKmsSignerConfig},
45 utils::{self, derive_ethereum_address_from_der, extract_public_key_from_der},
46};
47
48#[cfg(test)]
49use mockall::{automock, mock};
50
51#[derive(Clone, Debug, thiserror::Error, Serialize)]
52pub enum AwsKmsError {
53 #[error("AWS KMS response parse error: {0}")]
54 ParseError(String),
55 #[error("AWS KMS config error: {0}")]
56 ConfigError(String),
57 #[error("AWS KMS get error: {0}")]
58 GetError(String),
59 #[error("AWS KMS signing error: {0}")]
60 SignError(String),
61 #[error("AWS KMS permissions error: {0}")]
62 PermissionError(String),
63 #[error("AWS KMS public key error: {0}")]
64 RecoveryError(#[from] utils::Secp256k1Error),
65 #[error("AWS KMS conversion error: {0}")]
66 ConvertError(String),
67 #[error("AWS KMS Other error: {0}")]
68 Other(String),
69}
70
71pub type AwsKmsResult<T> = Result<T, AwsKmsError>;
72
73#[async_trait]
74#[cfg_attr(test, automock)]
75pub trait AwsKmsEvmService: Send + Sync {
76 async fn get_evm_address(&self) -> AwsKmsResult<Address>;
78 async fn sign_payload_evm(&self, payload: &[u8]) -> AwsKmsResult<Vec<u8>>;
81}
82
83#[async_trait]
84#[cfg_attr(test, automock)]
85pub trait AwsKmsK256: Send + Sync {
86 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
88 async fn sign_digest<'a, 'b>(
90 &'a self,
91 key_id: &'b str,
92 digest: [u8; 32],
93 ) -> AwsKmsResult<Vec<u8>>;
94}
95
// Test double for `AwsKmsClient`. It implements both `Clone` and `AwsKmsK256`, so it can
// stand in for the real client in `AwsKmsService<T>`.
#[cfg(test)]
mock! {
    pub AwsKmsClient { }

    impl Clone for AwsKmsClient {
        fn clone(&self) -> Self;
    }

    // Must mirror the `AwsKmsK256` trait declaration exactly, including lifetimes.
    #[async_trait]
    impl AwsKmsK256 for AwsKmsClient {
        async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>>;
        async fn sign_digest<'a, 'b>(
            &'a self,
            key_id: &'b str,
            digest: [u8; 32],
        ) -> AwsKmsResult<Vec<u8>>;
    }
}
114
115#[derive(Debug, Clone)]
116pub struct AwsKmsClient {
117 inner: Client,
118}
119
120#[async_trait]
121impl AwsKmsK256 for AwsKmsClient {
122 async fn get_der_public_key<'a, 'b>(&'a self, key_id: &'b str) -> AwsKmsResult<Vec<u8>> {
123 let get_output = self
124 .inner
125 .get_public_key()
126 .key_id(key_id)
127 .send()
128 .await
129 .map_err(|e| AwsKmsError::GetError(e.to_string()))?;
130
131 let der_pk_blob = get_output
132 .public_key
133 .ok_or(AwsKmsError::GetError(
134 "No public key blob found".to_string(),
135 ))?
136 .into_inner();
137
138 Ok(der_pk_blob)
139 }
140
141 async fn sign_digest<'a, 'b>(
142 &'a self,
143 key_id: &'b str,
144 digest: [u8; 32],
145 ) -> AwsKmsResult<Vec<u8>> {
146 let sign_result = self
148 .inner
149 .sign()
150 .key_id(key_id)
151 .signing_algorithm(SigningAlgorithmSpec::EcdsaSha256)
152 .message_type(MessageType::Digest)
153 .message(Blob::new(digest))
154 .send()
155 .await;
156
157 let der_signature = sign_result
159 .map_err(|e| AwsKmsError::PermissionError(e.to_string()))?
160 .signature
161 .ok_or(AwsKmsError::SignError(
162 "Signature not found in response".to_string(),
163 ))?
164 .into_inner();
165
166 Ok(der_signature)
167 }
168}
169
170#[derive(Debug, Clone)]
171pub struct AwsKmsService<T: AwsKmsK256 + Clone = AwsKmsClient> {
172 pub kms_key_id: String,
173 client: T,
174}
175
176impl AwsKmsService<AwsKmsClient> {
177 pub async fn new(config: AwsKmsSignerConfig) -> AwsKmsResult<Self> {
178 let region_provider =
179 RegionProviderChain::first_try(config.region.map(Region::new)).or_default_provider();
180
181 let auth_config = aws_config::defaults(BehaviorVersion::latest())
182 .region(region_provider)
183 .load()
184 .await;
185 let client = AwsKmsClient {
186 inner: Client::new(&auth_config),
187 };
188
189 Ok(Self {
190 kms_key_id: config.key_id,
191 client,
192 })
193 }
194}
195
#[cfg(test)]
impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
    /// Test-only constructor that injects an arbitrary backend.
    ///
    /// Only `config.key_id` is used; `config.region` is ignored.
    pub fn new_for_testing(client: T, config: AwsKmsSignerConfig) -> Self {
        let kms_key_id = config.key_id;
        Self { kms_key_id, client }
    }
}
205
206impl<T: AwsKmsK256 + Clone> AwsKmsService<T> {
207 pub async fn sign_bytes_evm(&self, bytes: &[u8]) -> AwsKmsResult<Vec<u8>> {
211 let digest = keccak256(bytes).0;
213
214 let der_signature = self.client.sign_digest(&self.kms_key_id, digest).await?;
217
218 let rs = k256::ecdsa::Signature::from_der(&der_signature)
220 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
221
222 let der_pk = self.client.get_der_public_key(&self.kms_key_id).await?;
223
224 let pk = extract_public_key_from_der(&der_pk)
226 .map_err(|e| AwsKmsError::ConvertError(e.to_string()))?;
227
228 let v = utils::recover_public_key(&pk, &rs, bytes)?;
230
231 let eth_v = 27 + v;
233
234 let mut sig_bytes = rs.to_vec();
236 sig_bytes.push(eth_v);
237
238 Ok(sig_bytes)
239 }
240}
241
242#[async_trait]
243impl<T: AwsKmsK256 + Clone> AwsKmsEvmService for AwsKmsService<T> {
244 async fn get_evm_address(&self) -> AwsKmsResult<Address> {
245 let der = self.client.get_der_public_key(&self.kms_key_id).await?;
246 let eth_address = derive_ethereum_address_from_der(&der)
247 .map_err(|e| AwsKmsError::ParseError(e.to_string()))?;
248 Ok(Address::Evm(eth_address))
249 }
250
251 async fn sign_payload_evm(&self, message: &[u8]) -> AwsKmsResult<Vec<u8>> {
252 self.sign_bytes_evm(message).await
253 }
254}
255
#[cfg(test)]
pub mod tests {
    use super::*;

    use alloy::primitives::utils::eip191_message;
    use k256::{
        ecdsa::{RecoveryId, Signature, SigningKey, VerifyingKey},
        elliptic_curve::rand_core::OsRng,
        pkcs8::{der::Encode, EncodePublicKey},
    };
    use mockall::predicate::{eq, ne};

    /// Builds a mock KMS client backed by a fresh random secp256k1 key.
    ///
    /// Only `"test-key-id"` is treated as existing. Any other key id makes
    /// `get_der_public_key` return `GetError` and `sign_digest` return `SignError`.
    /// Returns the mock together with the signing key it uses.
    pub fn setup_mock_kms_client() -> (MockAwsKmsClient, SigningKey) {
        let mut client = MockAwsKmsClient::new();
        let signing_key = SigningKey::random(&mut OsRng);
        let s = signing_key
            .verifying_key()
            .to_public_key_der()
            .unwrap()
            .to_der()
            .unwrap();

        client
            .expect_get_der_public_key()
            .with(eq("test-key-id"))
            .return_const(Ok(s));
        client
            .expect_get_der_public_key()
            .with(ne("test-key-id"))
            .return_const(Err(AwsKmsError::GetError("Key does not exist".to_string())));

        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.ne("test-key-id"))
            .return_const(Err(AwsKmsError::SignError(
                "Key does not exist".to_string(),
            )));

        let key = signing_key.clone();
        client
            .expect_sign_digest()
            .withf(|key_id, _| key_id.eq("test-key-id"))
            .returning(move |_, digest| {
                let (signature, _) = signing_key
                    .sign_prehash_recoverable(&digest)
                    .map_err(|e| AwsKmsError::SignError(e.to_string()))?;
                let der_signature = signature.to_der().as_bytes().to_vec();
                Ok(der_signature)
            });

        client.expect_clone().return_once(MockAwsKmsClient::new);

        (client, key)
    }

    /// Wraps a fresh mock client in an `AwsKmsService` configured with `key_id`.
    fn service_with_key(key_id: &str) -> (AwsKmsService<MockAwsKmsClient>, SigningKey) {
        let (client, key) = setup_mock_kms_client();
        let service = AwsKmsService::new_for_testing(
            client,
            AwsKmsSignerConfig {
                region: Some("us-east-1".to_string()),
                key_id: key_id.to_string(),
            },
        );
        (service, key)
    }

    #[tokio::test]
    async fn test_get_public_key() {
        let (kms, key) = service_with_key("test-key-id");

        // `let else` instead of `if let`: a non-EVM or error result must fail the test
        // rather than silently skip the comparison.
        let Ok(Address::Evm(evm_address)) = kms.get_evm_address().await else {
            panic!("expected an EVM address for an existing key");
        };

        let expected_address = derive_ethereum_address_from_der(
            key.verifying_key().to_public_key_der().unwrap().as_bytes(),
        )
        .unwrap();
        assert_eq!(expected_address, evm_address);
    }

    #[tokio::test]
    async fn test_get_public_key_fail() {
        let (kms, _) = service_with_key("invalid-key-id");

        let result = kms.get_evm_address().await;
        assert!(matches!(result, Err(AwsKmsError::GetError(_))));
    }

    #[tokio::test]
    async fn test_sign_digest() {
        let (kms, key) = service_with_key("test-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let sig = kms
            .sign_payload_evm(&message_eip)
            .await
            .expect("signing with an existing key should succeed");

        // Layout: r (32) || s (32) || v (1), with v in legacy {27, 28} form.
        assert_eq!(sig.len(), 65, "expected r || s || v");
        let v = sig[64];
        assert!(v == 27 || v == 28, "unexpected v = {v}");

        let signature = Signature::from_slice(&sig[..64]).expect("valid r || s");
        assert!(signature.normalize_s().is_none(), "s must be in low form");

        // The signature must recover to the signer's key over keccak256(payload).
        let recovery_id = RecoveryId::from_byte(v - 27).expect("valid recovery id");
        let recovered = VerifyingKey::recover_from_prehash(
            &keccak256(&message_eip).0,
            &signature,
            recovery_id,
        )
        .expect("public key recovery should succeed");
        assert_eq!(&recovered, key.verifying_key());
    }

    #[tokio::test]
    async fn test_sign_digest_fail() {
        let (kms, _) = service_with_key("invalid-key-id");

        let message_eip = eip191_message(b"Hello World!");
        let result = kms.sign_payload_evm(&message_eip).await;
        assert!(matches!(result, Err(AwsKmsError::SignError(_))));
    }
}