lambda/
lib.rs

//! Anonymiser lambda
//!
//! This lambda converts incoming packages from the TRE production bucket into anonymised packages.
//!
//! The body of each SQS message must be JSON of the following form (all four fields are required):
//! ```json
//! {
//!   "parameters": {
//!     "status": "...",
//!     "reference": "TDR-2023-ABC",
//!     "s3Bucket": "input-bucket",
//!     "s3Key": "TRE-TDR-2023-ABC.tar.gz"
//!   }
//! }
//! ```
//! The lambda will:
//! * Download the file from S3 to local disk
//! * Anonymise it using the anonymise library
//! * Upload it to the S3 bucket named by the `OUTPUT_BUCKET` environment variable
//! * Send a message of the same shape to the queue whose URL is in the `OUTPUT_QUEUE` environment
//!   variable. The message points at the uploaded file and replaces `TDR` with `TST` in the
//!   reference.
19
20use anonymiser_lib::process_package;
21use aws_config::meta::region::RegionProviderChain;
22use aws_config::{BehaviorVersion, SdkConfig};
23use aws_lambda_events::sqs::SqsMessage;
24use aws_sdk_s3::primitives::ByteStream;
25use aws_sdk_s3::Client as S3Client;
26use aws_sdk_sqs::Client as SQSClient;
27use lambda_runtime::Error;
28use serde::{Deserialize, Serialize};
29use std::fs;
30use std::fs::File;
31use std::io::Write;
32use std::path::{Path, PathBuf};
33
34/// The bucket and key for the file we are processing
35#[derive(Deserialize, Serialize)]
36struct MessageBody {
37    parameters: S3Details,
38}
39
40#[derive(Deserialize, Serialize)]
41#[serde(rename_all = "camelCase")]
42struct S3Details {
43    status: String,
44    reference: String,
45    s3_bucket: String,
46    s3_key: String,
47}
48
49/// # Processes the SQS message.
50///
51/// This will download the file specified in the message body, anonymise it, upload it to S3 and send the message on to the output queue.
52pub async fn process_record(
53    message: &SqsMessage,
54    working_directory: PathBuf,
55    s3_endpoint_url: Option<&str>,
56    sqs_endpoint_url: Option<&str>,
57) -> Result<PathBuf, Error> {
58    let body = message
59        .body
60        .as_ref()
61        .ok_or("No body found in the SQS message")?;
62    let s3_client = create_s3_client(s3_endpoint_url).await;
63    let sqs_client = create_sqs_client(sqs_endpoint_url).await;
64
65    let message_body: MessageBody = serde_json::from_str(body)?;
66    let parameters = message_body.parameters;
67    let input_file_path = download(
68        &s3_client,
69        parameters.s3_bucket,
70        parameters.s3_key,
71        &working_directory,
72    )
73    .await?;
74    let output_path = &working_directory.join(PathBuf::from("output"));
75    fs::create_dir_all(output_path)?;
76    let output_tar_path = process_package(output_path, &input_file_path)?;
77    let file_name = output_tar_path
78        .file_name()
79        .and_then(|file_name_as_os_string| file_name_as_os_string.to_str())
80        .expect("Cannot parse file name from output path");
81
82    let output_bucket = std::env::var("OUTPUT_BUCKET")?;
83    upload(&s3_client, &output_tar_path, &output_bucket, file_name).await?;
84
85    let output_queue = std::env::var("OUTPUT_QUEUE")?;
86    let reference = parameters.reference.replace("TDR", "TST");
87    let status = parameters.status;
88    let output_message_body = MessageBody {
89        parameters: S3Details {
90            s3_bucket: output_bucket,
91            s3_key: file_name.to_string(),
92            status,
93            reference,
94        },
95    };
96    let message_string = serde_json::to_string(&output_message_body)?;
97    let _ = sqs_client
98        .send_message()
99        .queue_url(&output_queue)
100        .message_body(message_string)
101        .send()
102        .await?;
103    Ok(output_path.clone())
104}
105
106/// # Uploads the specified file
107///
108/// This will upload the contents of the file in `body_path` to the `bucket` with the specified `key`
109async fn upload(
110    client: &S3Client,
111    body_path: &PathBuf,
112    bucket: &str,
113    key: &str,
114) -> Result<(), Error> {
115    let body = ByteStream::from_path(body_path).await?;
116    client
117        .put_object()
118        .bucket(bucket)
119        .key(key)
120        .body(body)
121        .send()
122        .await?;
123    Ok(())
124}
125
126/// # Downloads the specified file
127///
128/// This downloads the contents of the file in the S3 `bucket` with the specified `key` into the `working_directory`
129async fn download(
130    client: &S3Client,
131    bucket: String,
132    key: String,
133    working_directory: &Path,
134) -> Result<PathBuf, Error> {
135    let destination = working_directory.join(PathBuf::from(&key));
136    let mut destination_path = destination.clone();
137    destination_path.pop();
138    fs::create_dir_all(&destination_path)?;
139
140    let mut file = File::create(&destination)?;
141
142    let mut object = client.get_object().bucket(bucket).key(&key).send().await?;
143
144    while let Some(bytes) = object.body.try_next().await? {
145        file.write_all(&bytes)?;
146    }
147
148    Ok(destination)
149}
150
151/// # Creates an SQS client
152async fn create_sqs_client(potential_endpoint_url: Option<&str>) -> SQSClient {
153    let config = aws_config("sqs", potential_endpoint_url).await;
154    SQSClient::new(&config)
155}
156
157/// # Creates an AWS SDK config object
158async fn aws_config(service: &str, potential_endpoint_url: Option<&str>) -> SdkConfig {
159    let default_endpoint = format!("https://{service}.eu-west-2.amazonaws.com");
160    let endpoint_url = potential_endpoint_url.unwrap_or(default_endpoint.as_str());
161    let region_provider = RegionProviderChain::default_provider().or_else("eu-west-2");
162
163    aws_config::defaults(BehaviorVersion::latest())
164        .region(region_provider)
165        .endpoint_url(endpoint_url)
166        .load()
167        .await
168}
169
170/// # Creates an S3 client
171async fn create_s3_client(potential_endpoint_url: Option<&str>) -> S3Client {
172    let config = aws_config("s3", potential_endpoint_url).await;
173    S3Client::new(&config)
174}
175
#[cfg(test)]
mod test {
    use crate::{aws_config, create_s3_client};

    /// With no region supplied by the environment, the client should fall back to `eu-west-2`.
    /// NOTE(review): this depends on AWS_REGION / profile settings being absent when tests run.
    #[tokio::test]
    async fn test_create_client_with_default_region() {
        let region = create_s3_client(None)
            .await
            .config()
            .region()
            .unwrap()
            .to_string();

        assert_eq!(region, "eu-west-2");
    }

    /// The default endpoint is derived from the service name; an explicit override wins.
    #[tokio::test]
    async fn test_aws_config_endpoint_url() {
        let cases = [
            (None, "https://test.eu-west-2.amazonaws.com"),
            (Some("https://example.com"), "https://example.com"),
        ];

        for (override_url, expected) in cases {
            let config = aws_config("test", override_url).await;
            assert_eq!(config.endpoint_url().unwrap(), expected);
        }
    }
}