Skip to content

Commit

Permalink
Generate a Client method for Dropshot websocket channels (#183)
Browse files Browse the repository at this point in the history
Generated methods return `ResponseValue<reqwest::Upgrade`, which may be
passed to a websocket protocol implementation such as
`tokio_tungstenite::WebSocketStream::from_raw_stream(rv.into_inner(), ...)`
for the purpose of implementing against the raw websocket connection, but
may later be extended as a generic to allow higher-level channel message
definitions.

Per the changelog, consumers will need to depend on reqwest 0.11.12 or
newer for HTTP Upgrade support, as well as base64 and rand if any
endpoints are websocket channels:
```
[dependencies]
reqwest = { version = "0.11.12" features = ["json", "stream"] }
base64 = "0.13"
rand = "0.8"
```

Co-authored-by: lif <>
  • Loading branch information
lifning committed Sep 28, 2022
1 parent fd1ae2b commit 4e2dcc5
Show file tree
Hide file tree
Showing 16 changed files with 5,510 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.adoc
Expand Up @@ -24,6 +24,7 @@ https://github.com/oxidecomputer/progenitor/compare/v0.1.1\...v0.2.0[Full list o
* Derive `Debug` for `Client` and builders for the various operations (#145)
* Builders for `struct` types (#171)
* Add a prelude that include the `Client` and any extension traits (#176)
* Add support for upgrading connections, which requires a version bump to reqwest. (#183)

== 0.1.1 (released 2022-05-13)

Expand Down
40 changes: 40 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions README.md
Expand Up @@ -54,6 +54,13 @@ Similarly if there is a `format` field set to `uuid`:
+uuid = { version = "1.0.0", features = ["serde", "v4"] }
```

And if there are any websocket channel endpoints:
```diff
[dependencies]
+base64 = "0.13"
+rand = "0.8"
```

The macro has some additional fancy options to control the generated code:

```rust
Expand Down Expand Up @@ -116,7 +123,7 @@ You'll need to add add the following to `Cargo.toml`:
+serde_json = "1.0"
```

(`chrono` and `uuid` as above)
(`chrono`, `uuid`, `base64`, and `rand` as above)

Note that `progenitor` is used by `build.rs`, but the generated code required
`progenitor-client`.
Expand Down Expand Up @@ -290,4 +297,4 @@ let result = client
```

Consumers do not need to specify parameters and struct properties that are not
required or for which the API specifies defaults. Neat!
required or for which the API specifies defaults. Neat!
4 changes: 3 additions & 1 deletion example-build/Cargo.toml
Expand Up @@ -7,7 +7,9 @@ edition = "2021"
[dependencies]
chrono = { version = "0.4", features = ["serde"] }
progenitor-client = { path = "../progenitor-client" }
reqwest = { version = "0.11", features = ["json", "stream"] }
reqwest = { version = "0.11.12", features = ["json", "stream"] }
base64 = "0.13"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.0", features = ["serde", "v4"] }

Expand Down
2 changes: 1 addition & 1 deletion example-macro/Cargo.toml
Expand Up @@ -7,7 +7,7 @@ edition = "2021"
[dependencies]
chrono = { version = "0.4", features = ["serde"] }
progenitor = { path = "../progenitor" }
reqwest = { version = "0.11", features = ["json", "stream"] }
reqwest = { version = "0.11.12", features = ["json", "stream"] }
schemars = { version = "0.8.10", features = ["uuid1"] }
serde = { version = "1.0", features = ["derive"] }
uuid = { version = "1.0", features = ["serde", "v4"] }
2 changes: 1 addition & 1 deletion progenitor-client/Cargo.toml
Expand Up @@ -10,7 +10,7 @@ description = "An OpenAPI client generator - client support"
bytes = "1.2.1"
futures-core = "0.3.24"
percent-encoding = "2.2"
reqwest = { version = "0.11", default-features = false, features = ["json", "stream"] }
reqwest = { version = "0.11.12", default-features = false, features = ["json", "stream"] }
serde = "1.0"
serde_json = "1.0"
serde_urlencoded = "0.7.1"
24 changes: 24 additions & 0 deletions progenitor-client/src/progenitor_client.rs
Expand Up @@ -76,6 +76,30 @@ impl<T: DeserializeOwned> ResponseValue<T> {
}
}

impl ResponseValue<reqwest::Upgraded> {
#[doc(hidden)]
pub async fn upgrade<E: std::fmt::Debug>(
response: reqwest::Response,
) -> Result<Self, Error<E>> {
let status = response.status();
let headers = response.headers().clone();
if status == reqwest::StatusCode::SWITCHING_PROTOCOLS {
let inner = response
.upgrade()
.await
.map_err(Error::InvalidResponsePayload)?;

Ok(Self {
inner,
status,
headers,
})
} else {
Err(Error::UnexpectedResponse(response))
}
}
}

impl ResponseValue<ByteStream> {
#[doc(hidden)]
pub fn stream(response: reqwest::Response) -> Self {
Expand Down
11 changes: 10 additions & 1 deletion progenitor-impl/src/lib.rs
Expand Up @@ -26,6 +26,8 @@ pub enum Error {
UnexpectedFormat(String),
#[error("invalid operation path {0}")]
InvalidPath(String),
#[error("invalid dropshot extension use: {0}")]
InvalidExtension(String),
#[error("internal error {0}")]
InternalError(String),
}
Expand All @@ -36,6 +38,7 @@ pub struct Generator {
type_space: TypeSpace,
settings: GenerationSettings,
uses_futures: bool,
uses_websockets: bool,
}

#[derive(Default, Clone)]
Expand Down Expand Up @@ -116,6 +119,7 @@ impl Default for Generator {
),
settings: Default::default(),
uses_futures: Default::default(),
uses_websockets: Default::default(),
}
}
}
Expand All @@ -133,6 +137,7 @@ impl Generator {
type_space: TypeSpace::new(&type_settings),
settings: settings.clone(),
uses_futures: false,
uses_websockets: false,
}
}

Expand Down Expand Up @@ -426,7 +431,7 @@ impl Generator {
"bytes = \"1.1\"",
"futures-core = \"0.3\"",
"percent-encoding = \"2.1\"",
"reqwest = { version = \"0.11\", features = [\"json\", \"stream\"] }",
"reqwest = { version = \"0.11.12\", features = [\"json\", \"stream\"] }",
"serde = { version = \"1.0\", features = [\"derive\"] }",
"serde_urlencoded = \"0.7\"",
];
Expand All @@ -444,6 +449,10 @@ impl Generator {
if self.uses_futures {
deps.push("futures = \"0.3\"")
}
if self.uses_websockets {
deps.push("base64 = \"0.13\"");
deps.push("rand = \"0.8\"");
}
if self.type_space.uses_serde_json() {
deps.push("serde_json = \"1.0\"")
}
Expand Down
65 changes: 62 additions & 3 deletions progenitor-impl/src/method.rs
Expand Up @@ -29,6 +29,7 @@ pub(crate) struct OperationMethod {
params: Vec<OperationParameter>,
responses: Vec<OperationResponse>,
dropshot_paginated: Option<DropshotPagination>,
dropshot_websocket: bool,
}

enum HttpMethod {
Expand Down Expand Up @@ -189,6 +190,7 @@ impl OperationResponseStatus {
matches!(
self,
OperationResponseStatus::Default
| OperationResponseStatus::Code(101)
| OperationResponseStatus::Code(200..=299)
| OperationResponseStatus::Range(2)
)
Expand Down Expand Up @@ -225,6 +227,7 @@ enum OperationResponseType {
Type(TypeId),
None,
Raw,
Upgrade,
}

impl Generator {
Expand Down Expand Up @@ -338,6 +341,12 @@ impl Generator {
})
.collect::<Result<Vec<_>>>()?;

let dropshot_websocket =
operation.extensions.get("x-dropshot-websocket").is_some();
if dropshot_websocket {
self.uses_websockets = true;
}

if let Some(body_param) = self.get_body_param(operation, components)? {
params.push(body_param);
}
Expand Down Expand Up @@ -378,9 +387,10 @@ impl Generator {
let (status_code, response) = v?;

// We categorize responses as "typed" based on the
// "application/json" content type, "raw" if there's any other
// response content type (we don't investigate further), or
// "none" if there is no content.
// "application/json" content type, "upgrade" if it's a
// websocket channel without a meaningful content-type,
// "raw" if there's any other response content type (we don't
// investigate further), or "none" if there is no content.
// TODO if there are multiple response content types we could
// treat those like different response types and create an
// enum; the generated client method would check for the
Expand All @@ -407,6 +417,8 @@ impl Generator {
};

OperationResponseType::Type(typ)
} else if dropshot_websocket {
OperationResponseType::Upgrade
} else if response.content.first().is_some() {
OperationResponseType::Raw
} else {
Expand Down Expand Up @@ -449,9 +461,25 @@ impl Generator {
});
}

// Must accept HTTP 101 Switching Protocols
if dropshot_websocket {
responses.push(OperationResponse {
status_code: OperationResponseStatus::Code(101),
typ: OperationResponseType::Upgrade,
description: None,
})
}

let dropshot_paginated =
self.dropshot_pagination_data(operation, &params, &responses);

if dropshot_websocket && dropshot_paginated.is_some() {
return Err(Error::InvalidExtension(format!(
"conflicting extensions in {:?}",
operation_id
)));
}

Ok(OperationMethod {
operation_id: sanitize(operation_id, Case::Snake),
tags: operation.tags.clone(),
Expand All @@ -465,6 +493,7 @@ impl Generator {
params,
responses,
dropshot_paginated,
dropshot_websocket,
})
}

Expand Down Expand Up @@ -705,6 +734,20 @@ impl Generator {
(query_build, query_use)
};

let websock_hdrs = if method.dropshot_websocket {
quote! {
.header(reqwest::header::CONNECTION, "Upgrade")
.header(reqwest::header::UPGRADE, "websocket")
.header(reqwest::header::SEC_WEBSOCKET_VERSION, "13")
.header(
reqwest::header::SEC_WEBSOCKET_KEY,
base64::encode(rand::random::<[u8; 16]>()),
)
}
} else {
quote! {}
};

// Generate the path rename map; then use it to generate code for
// assigning the path parameters to the `url` variable.
let url_renames = method
Expand Down Expand Up @@ -791,6 +834,11 @@ impl Generator {
Ok(ResponseValue::stream(response))
}
}
OperationResponseType::Upgrade => {
quote! {
ResponseValue::upgrade(response).await
}
}
};

quote! { #pat => { #decode } }
Expand Down Expand Up @@ -842,6 +890,13 @@ impl Generator {
))
}
}
OperationResponseType::Upgrade => {
if response.status_code == OperationResponseStatus::Default {
return quote! { } // catch-all handled below
} else {
todo!("non-default error response handling for upgrade requests is not yet implemented");
}
}
};

quote! { #pat => { #decode } }
Expand Down Expand Up @@ -879,6 +934,7 @@ impl Generator {
. #method_func (url)
#(#body_func)*
#query_use
#websock_hdrs
.build()?;
#pre_hook
let result = #client.client
Expand Down Expand Up @@ -988,6 +1044,9 @@ impl Generator {
OperationResponseType::Raw => {
quote! { ByteStream }
}
OperationResponseType::Upgrade => {
quote! { reqwest::Upgraded }
}
})
// TODO should this be a bytestream?
.unwrap_or_else(|| quote! { () });
Expand Down

0 comments on commit 4e2dcc5

Please sign in to comment.