Skip to content

Commit

Permalink
Make MockRequestDispatcher::with_request_checker to own request
Browse files Browse the repository at this point in the history
Currently MockRequestDispatcher::with_request_checker takes a Fn
that passes a non-mutable reference to the checker function. This
causes issues for users who wish to check a streamed payload's
contents because it is immutable. By making it owned it will no
longer be a problem.
  • Loading branch information
allada committed Apr 18, 2021
1 parent 31cf150 commit 4c23423
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

- MockRequestDispatcher::with_request_checker now takes owned object callback
- Update to `serde_urlencoded` 0.7
- Update to `rustc_version` 0.3
- Replace `time`-related types in `rusoto_signature` with `chrono` types, to
Expand Down
6 changes: 3 additions & 3 deletions mock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub struct MockRequestDispatcher {
outcome: RequestOutcome,
body: Vec<u8>,
headers: HeaderMap<String>,
request_checker: Option<Box<dyn Fn(&SignedRequest) + Send + Sync>>,
request_checker: Option<Box<dyn Fn(SignedRequest) + Send + Sync>>,
}

enum RequestOutcome {
Expand Down Expand Up @@ -125,7 +125,7 @@ impl MockRequestDispatcher {
/// to AWS
pub fn with_request_checker<F>(mut self, checker: F) -> MockRequestDispatcher
where
F: Fn(&SignedRequest) + Send + Sync + 'static,
F: Fn(SignedRequest) + Send + Sync + 'static,
{
self.request_checker = Some(Box::new(checker));
self
Expand All @@ -146,7 +146,7 @@ impl DispatchSignedRequest for MockRequestDispatcher {
_timeout: Option<Duration>,
) -> rusoto_core::request::DispatchSignedRequestFuture {
if self.request_checker.is_some() {
self.request_checker.as_ref().unwrap()(&request);
self.request_checker.as_ref().unwrap()(request);
}
match self.outcome {
RequestOutcome::Performed(ref status) => futures::future::ready(Ok(HttpResponse {
Expand Down
2 changes: 1 addition & 1 deletion rusoto/services/cloudformation/src/custom/custom_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async fn should_serialize_list_parameters_in_request_body() {
</ResponseMetadata>
</ListStacksResponse>"#,
)
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);
assert_eq!("/", request.path);
if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down
2 changes: 1 addition & 1 deletion rusoto/services/cloudwatch/src/custom/custom_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use serde_urlencoded;
async fn should_serialize_complex_metric_data_params() {
let mock = MockRequestDispatcher::with_status(200)
.with_body("")
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);
assert_eq!("/", request.path);
if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down
2 changes: 1 addition & 1 deletion rusoto/services/lambda/src/custom/custom_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async fn should_parse_invocation_response() {
.with_body(r#"{"arbitrary":"json"}"#)
.with_header("X-Amz-Function-Error", "Handled")
.with_header("X-Amz-Log-Result", "foo bar baz")
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);
if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
assert_eq!(b"raw payload", buffer.as_ref());
Expand Down
8 changes: 4 additions & 4 deletions rusoto/services/s3/src/custom/custom_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async fn list_multipart_upload_parts_happy_path() {
<Size>5242880</Size>
</Part>
</ListPartsResult>"#)
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!(request.method, "GET");
assert_eq!(request.path, "/rusoto1440826511/testfile.zip");
assert!(request.payload.is_none());
Expand Down Expand Up @@ -318,7 +318,7 @@ fn bench_parse_list_buckets_response(b: &mut Bencher) {
</ListAllMyBucketsResult>
"#,
)
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!(request.method, "GET");
assert_eq!(request.path, "/");
assert!(request.payload.is_none());
Expand Down Expand Up @@ -355,7 +355,7 @@ async fn should_parse_sample_list_buckets_response() {
</ListAllMyBucketsResult>
"#,
)
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!(request.method, "GET");
assert_eq!(request.path, "/");
assert!(request.payload.is_none());
Expand Down Expand Up @@ -421,7 +421,7 @@ async fn should_serialize_complicated_request() {

let mock = MockRequestDispatcher::with_status(200)
.with_body("")
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!(request.method, "GET");
assert_eq!(request.path, "/bucket/key");
assert_eq!(
Expand Down
10 changes: 5 additions & 5 deletions rusoto/services/sns/src/custom/custom_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rusoto_core::Region;
#[tokio::test]
async fn should_serialize_map_parameters_in_create_platform_application_request_body() {
let mock =
MockRequestDispatcher::with_status(200).with_request_checker(|request: &SignedRequest| {
MockRequestDispatcher::with_status(200).with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);

if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down Expand Up @@ -63,7 +63,7 @@ async fn should_serialize_map_parameters_in_create_platform_application_request_
#[tokio::test]
async fn should_serialize_map_parameters_in_create_platform_endpoint_request_body() {
let mock =
MockRequestDispatcher::with_status(200).with_request_checker(|request: &SignedRequest| {
MockRequestDispatcher::with_status(200).with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);

if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down Expand Up @@ -97,7 +97,7 @@ async fn should_serialize_map_parameters_in_create_platform_endpoint_request_bod
#[tokio::test]
async fn should_serialize_map_parameters_in_set_platform_application_attributes_request_body() {
let mock =
MockRequestDispatcher::with_status(200).with_request_checker(|request: &SignedRequest| {
MockRequestDispatcher::with_status(200).with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);

if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down Expand Up @@ -135,7 +135,7 @@ async fn should_serialize_map_parameters_in_set_platform_application_attributes_
#[tokio::test]
async fn should_serialize_map_parameters_in_set_endpoint_attributes_request_body() {
let mock =
MockRequestDispatcher::with_status(200).with_request_checker(|request: &SignedRequest| {
MockRequestDispatcher::with_status(200).with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);

if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down Expand Up @@ -198,7 +198,7 @@ async fn should_serialize_map_parameters_in_get_sms_attributes_request_body() {
#[tokio::test]
async fn should_serialize_map_parameters_in_set_sms_attributes_request_body() {
let mock =
MockRequestDispatcher::with_status(200).with_request_checker(|request: &SignedRequest| {
MockRequestDispatcher::with_status(200).with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);

if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down
18 changes: 7 additions & 11 deletions rusoto/services/sqs/src/custom/custom_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async fn should_serialize_map_parameters_in_request_body() {
</ResponseMetadata>
</SendMessageResponse>"#,
)
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
println!("{:#?}", request.params);

assert_eq!("POST", request.method);
Expand Down Expand Up @@ -115,7 +115,7 @@ async fn should_fix_issue_323() {
</ResponseMetadata>
</ReceiveMessageResponse>"#,
)
.with_request_checker(|request: &SignedRequest| {
.with_request_checker(|request: SignedRequest| {
assert_eq!("POST", request.method);
assert_eq!("/", request.path);
if let Some(SignedRequestPayload::Buffer(ref buffer)) = request.payload {
Expand Down Expand Up @@ -182,9 +182,8 @@ async fn test_parse_queue_does_not_exist_error() {

#[tokio::test]
async fn should_deserialize_map_parameters_in_response_body() {
let mock = MockRequestDispatcher::with_status(200)
.with_body(
r#"<?xml version="1.0" encoding="UTF-8"?>
let mock = MockRequestDispatcher::with_status(200).with_body(
r#"<?xml version="1.0" encoding="UTF-8"?>
<ReceiveMessageResponse>
<ReceiveMessageResult>
<Message>
Expand Down Expand Up @@ -215,7 +214,7 @@ async fn should_deserialize_map_parameters_in_response_body() {
</RequestId>
</ResponseMetadata>
</ReceiveMessageResponse>"#,
);
);

let request = ReceiveMessageRequest {
queue_url: "foo".to_owned(),
Expand All @@ -240,8 +239,5 @@ async fn should_deserialize_map_parameters_in_response_body() {
},
);

assert_eq!(
message_attributes,
message.message_attributes.unwrap(),
);
}
assert_eq!(message_attributes, message.message_attributes.unwrap(),);
}

0 comments on commit 4c23423

Please sign in to comment.