1use anonymiser_lib::process_package;
21use aws_config::meta::region::RegionProviderChain;
22use aws_config::{BehaviorVersion, SdkConfig};
23use aws_lambda_events::sqs::SqsMessage;
24use aws_sdk_s3::primitives::ByteStream;
25use aws_sdk_s3::Client as S3Client;
26use aws_sdk_sqs::Client as SQSClient;
27use lambda_runtime::Error;
28use serde::{Deserialize, Serialize};
29use std::fs;
30use std::fs::File;
31use std::io::Write;
32use std::path::{Path, PathBuf};
33
34#[derive(Deserialize, Serialize)]
36struct MessageBody {
37 parameters: S3Details,
38}
39
40#[derive(Deserialize, Serialize)]
41#[serde(rename_all = "camelCase")]
42struct S3Details {
43 status: String,
44 reference: String,
45 s3_bucket: String,
46 s3_key: String,
47}
48
49pub async fn process_record(
53 message: &SqsMessage,
54 working_directory: PathBuf,
55 s3_endpoint_url: Option<&str>,
56 sqs_endpoint_url: Option<&str>,
57) -> Result<PathBuf, Error> {
58 let body = message
59 .body
60 .as_ref()
61 .ok_or("No body found in the SQS message")?;
62 let s3_client = create_s3_client(s3_endpoint_url).await;
63 let sqs_client = create_sqs_client(sqs_endpoint_url).await;
64
65 let message_body: MessageBody = serde_json::from_str(body)?;
66 let parameters = message_body.parameters;
67 let input_file_path = download(
68 &s3_client,
69 parameters.s3_bucket,
70 parameters.s3_key,
71 &working_directory,
72 )
73 .await?;
74 let output_path = &working_directory.join(PathBuf::from("output"));
75 fs::create_dir_all(output_path)?;
76 let output_tar_path = process_package(output_path, &input_file_path)?;
77 let file_name = output_tar_path
78 .file_name()
79 .and_then(|file_name_as_os_string| file_name_as_os_string.to_str())
80 .expect("Cannot parse file name from output path");
81
82 let output_bucket = std::env::var("OUTPUT_BUCKET")?;
83 upload(&s3_client, &output_tar_path, &output_bucket, file_name).await?;
84
85 let output_queue = std::env::var("OUTPUT_QUEUE")?;
86 let reference = parameters.reference.replace("TDR", "TST");
87 let status = parameters.status;
88 let output_message_body = MessageBody {
89 parameters: S3Details {
90 s3_bucket: output_bucket,
91 s3_key: file_name.to_string(),
92 status,
93 reference,
94 },
95 };
96 let message_string = serde_json::to_string(&output_message_body)?;
97 let _ = sqs_client
98 .send_message()
99 .queue_url(&output_queue)
100 .message_body(message_string)
101 .send()
102 .await?;
103 Ok(output_path.clone())
104}
105
106async fn upload(
110 client: &S3Client,
111 body_path: &PathBuf,
112 bucket: &str,
113 key: &str,
114) -> Result<(), Error> {
115 let body = ByteStream::from_path(body_path).await?;
116 client
117 .put_object()
118 .bucket(bucket)
119 .key(key)
120 .body(body)
121 .send()
122 .await?;
123 Ok(())
124}
125
126async fn download(
130 client: &S3Client,
131 bucket: String,
132 key: String,
133 working_directory: &Path,
134) -> Result<PathBuf, Error> {
135 let destination = working_directory.join(PathBuf::from(&key));
136 let mut destination_path = destination.clone();
137 destination_path.pop();
138 fs::create_dir_all(&destination_path)?;
139
140 let mut file = File::create(&destination)?;
141
142 let mut object = client.get_object().bucket(bucket).key(&key).send().await?;
143
144 while let Some(bytes) = object.body.try_next().await? {
145 file.write_all(&bytes)?;
146 }
147
148 Ok(destination)
149}
150
151async fn create_sqs_client(potential_endpoint_url: Option<&str>) -> SQSClient {
153 let config = aws_config("sqs", potential_endpoint_url).await;
154 SQSClient::new(&config)
155}
156
157async fn aws_config(service: &str, potential_endpoint_url: Option<&str>) -> SdkConfig {
159 let default_endpoint = format!("https://{service}.eu-west-2.amazonaws.com");
160 let endpoint_url = potential_endpoint_url.unwrap_or(default_endpoint.as_str());
161 let region_provider = RegionProviderChain::default_provider().or_else("eu-west-2");
162
163 aws_config::defaults(BehaviorVersion::latest())
164 .region(region_provider)
165 .endpoint_url(endpoint_url)
166 .load()
167 .await
168}
169
170async fn create_s3_client(potential_endpoint_url: Option<&str>) -> S3Client {
172 let config = aws_config("s3", potential_endpoint_url).await;
173 S3Client::new(&config)
174}
175
#[cfg(test)]
mod test {
    use super::{aws_config, create_s3_client};

    /// Without an explicit region, the client falls back to eu-west-2.
    ///
    /// NOTE(review): this assumes no region is configured in the test
    /// environment (for example via `AWS_REGION` or a profile), because the
    /// default provider chain takes precedence over the fallback.
    #[tokio::test]
    async fn test_create_client_with_default_region() {
        let s3 = create_s3_client(None).await;
        let region = s3
            .config()
            .region()
            .expect("a region should always be resolved")
            .to_string();

        assert_eq!(region, "eu-west-2");
    }

    /// The endpoint defaults to the eu-west-2 service URL, and an explicit
    /// URL replaces it.
    #[tokio::test]
    async fn test_aws_config_endpoint_url() {
        let cases = [
            (None, "https://test.eu-west-2.amazonaws.com"),
            (Some("https://example.com"), "https://example.com"),
        ];

        for (override_url, expected) in cases {
            let config = aws_config("test", override_url).await;
            assert_eq!(
                config
                    .endpoint_url()
                    .expect("an endpoint url should always be set"),
                expected
            );
        }
    }
}