Skip to content

Commit

Permalink
feat!: Response.filter_header(), with_headers(), HeaderError
Browse files Browse the repository at this point in the history
In tiny-http#209 it has been requested to provide a feature to remove the Server HTTP header from responses.

You can add now a filter to the response with filter_header(HeaderField).
This prevents the header to be sent with the response.

All methods adding header stuff to the response now return a Result with HeaderError if the header got not added.

For adding multiple headers at once, there is now add_headers(Vec) and with_headers(Vec).
  • Loading branch information
kolbma committed Jan 17, 2024
1 parent 027e514 commit a4187cc
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 58 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Expand Up @@ -14,6 +14,11 @@
- 431 Request Header Fields Too Large when line is over 2048 bytes
- complete header is limited to 8192 bytes

* BREAKING CHANGE: Response-header-methods return a Result with HeaderError if the header got not added

Affected methods are add_header(), add_headers(), filter_header(), with_header(), with_headers().
Response.filter_header(HeaderField) prevents the HeaderField to be sent in the Response.

* New feature _native-tls_

_native-tls_ is a crate that will pick the platforms native TLS implementation depending on the chosen build target.
Expand Down
10 changes: 6 additions & 4 deletions examples/serve-root.rs
Expand Up @@ -39,10 +39,12 @@ fn main() {
if let Ok(file) = file {
let response = tiny_http::Response::from_file(file);

let response = response.with_header(tiny_http::Header {
field: "Content-Type".parse().unwrap(),
value: AsciiString::from_ascii(get_content_type(path)).unwrap(),
});
let response = response
.with_header(tiny_http::Header {
field: "Content-Type".parse().unwrap(),
value: AsciiString::from_ascii(get_content_type(path)).unwrap(),
})
.unwrap();

if let Err(err) = rq.respond(response) {
eprintln!("{err:#?}");
Expand Down
12 changes: 6 additions & 6 deletions examples/websockets.rs
Expand Up @@ -36,6 +36,7 @@ fn home_page(port: u16) -> tiny_http::Response<Cursor<Vec<u8>>> {
.parse::<tiny_http::Header>()
.unwrap(),
)
.unwrap()
}

/// Turns a Sec-WebSocket-Key into a Sec-WebSocket-Accept.
Expand Down Expand Up @@ -102,18 +103,17 @@ fn main() {

// building the "101 Switching Protocols" response
let response = tiny_http::Response::new_empty(tiny_http::StatusCode(101))
.with_header("Upgrade: websocket".parse::<tiny_http::Header>().unwrap())
.with_header("Connection: Upgrade".parse::<tiny_http::Header>().unwrap())
.with_header(
.with_headers(Vec::from([
"Upgrade: websocket".parse::<tiny_http::Header>().unwrap(),
"Connection: Upgrade".parse::<tiny_http::Header>().unwrap(),
"Sec-WebSocket-Protocol: ping"
.parse::<tiny_http::Header>()
.unwrap(),
)
.with_header(
format!("Sec-WebSocket-Accept: {}", convert_key(key.as_str()))
.parse::<tiny_http::Header>()
.unwrap(),
);
]))
.unwrap();

//
let mut stream = request.upgrade("websocket", response);
Expand Down
61 changes: 48 additions & 13 deletions src/common.rs
Expand Up @@ -2,6 +2,7 @@ use ascii::{AsAsciiStr, AsciiChar, AsciiStr, AsciiString, FromAsciiError};
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::fmt::{self, Display, Formatter};
use std::hash::Hash;
use std::str::FromStr;

/// Status code of a request or response.
Expand Down Expand Up @@ -180,16 +181,19 @@ impl Header {
}

impl FromStr for Header {
type Err = ();
type Err = HeaderError;

fn from_str(input: &str) -> Result<Header, ()> {
fn from_str(input: &str) -> Result<Header, HeaderError> {
let mut elems = input.splitn(2, ':');

let field = elems.next().and_then(|f| f.parse().ok()).ok_or(())?;
let field = elems
.next()
.and_then(|f| f.parse().ok())
.ok_or(HeaderError)?;
let value = elems
.next()
.and_then(|v| AsciiString::from_ascii(v.trim()).ok())
.ok_or(())?;
.ok_or(HeaderError)?;

Ok(Header { field, value })
}
Expand Down Expand Up @@ -240,13 +244,15 @@ impl HeaderField {
}

impl FromStr for HeaderField {
type Err = ();
type Err = HeaderError;

fn from_str(s: &str) -> Result<HeaderField, ()> {
fn from_str(s: &str) -> Result<HeaderField, HeaderError> {
if s.contains(char::is_whitespace) {
Err(())
Err(HeaderError)
} else {
AsciiString::from_ascii(s).map(HeaderField).map_err(|_| ())
AsciiString::from_ascii(s)
.map(HeaderField)
.map_err(|_| HeaderError)
}
}
}
Expand Down Expand Up @@ -278,6 +284,28 @@ impl PartialEq for HeaderField {
}
}

impl Hash for HeaderField {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.to_ascii_lowercase().hash(state);
}
}

// Needs to be lower-case!!!
pub(crate) const HEADER_FORBIDDEN: &[&str] =
&["connection", "trailer", "transfer-encoding", "upgrade"];

/// Header was not added
#[derive(Debug)]
pub struct HeaderError;

impl std::error::Error for HeaderError {}

impl std::fmt::Display for HeaderError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str("header not allowed")
}
}

/// HTTP request methods
///
/// As per [RFC 7231](https://tools.ietf.org/html/rfc7231#section-4.1) and
Expand Down Expand Up @@ -450,13 +478,13 @@ mod test {
use ascii::AsAsciiStr;
use httpdate::HttpDate;

use super::Header;
use super::{Header, HEADER_FORBIDDEN};

#[test]
fn test_parse_header() {
let header: Header = "Content-Type: text/html".parse().unwrap();

assert!(header.field.equiv(&"content-type"));
assert!(header.field.equiv("content-type"));
assert!(header.value.as_str() == "text/html");

assert!("hello world".parse::<Header>().is_err());
Expand All @@ -467,7 +495,7 @@ mod test {
let header: Header =
Header::try_from("Content-Type: text/html".as_ascii_str().unwrap()).unwrap();

assert!(header.field.equiv(&"content-type"));
assert!(header.field.equiv("content-type"));
assert!(header.value.as_str() == "text/html");
}

Expand All @@ -482,15 +510,15 @@ mod test {
fn test_parse_header_with_doublecolon() {
let header: Header = "Time: 20: 34".parse().unwrap();

assert!(header.field.equiv(&"time"));
assert!(header.field.equiv("time"));
assert!(header.value.as_str() == "20: 34");
}

#[test]
fn test_header_with_doublecolon_try_from_ascii() {
let header: Header = Header::try_from("Time: 20: 34".as_ascii_str().unwrap()).unwrap();

assert!(header.field.equiv(&"time"));
assert!(header.field.equiv("time"));
assert!(header.value.as_str() == "20: 34");
}

Expand Down Expand Up @@ -539,4 +567,11 @@ mod test {
);
}
}

#[test]
fn test_header_forbidden_lc() {
for h in HEADER_FORBIDDEN {
assert_eq!(h, &h.to_lowercase());
}
}
}

0 comments on commit a4187cc

Please sign in to comment.