diff --git a/quinn-proto/src/connection/streams/state.rs b/quinn-proto/src/connection/streams/state.rs index 55071c073..e60a3fc6a 100644 --- a/quinn-proto/src/connection/streams/state.rs +++ b/quinn-proto/src/connection/streams/state.rs @@ -812,7 +812,7 @@ impl StreamsState { /// again. pub(super) fn add_read_credits(&mut self, credits: u64) -> ShouldTransmit { if credits > self.receive_window_shrink_debt { - let net_credits = credits.saturating_sub(self.receive_window_shrink_debt); + let net_credits = credits - self.receive_window_shrink_debt; self.local_max_data = self.local_max_data.saturating_add(net_credits); self.receive_window_shrink_debt = 0; } else { @@ -1585,9 +1585,9 @@ mod tests { #[test] fn expand_receive_window() { let mut server = make(Side::Server); - let new_receive_window = 2 * 1024 * 1024u32; - let explanded = server.set_receive_window(new_receive_window.into()); - assert!(explanded); + let new_receive_window = 2 * server.receive_window as u32; + let expanded = server.set_receive_window(new_receive_window.into()); + assert!(expanded); assert_eq!(server.receive_window, new_receive_window as u64); assert_eq!(server.local_max_data, new_receive_window as u64); assert_eq!(server.receive_window_shrink_debt, 0); @@ -1595,63 +1595,56 @@ mod tests { // credit, expecting all of them added to local_max_data let credits = 1024u64; - let shuould_transmit = server.add_read_credits(credits); + let should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, prev_local_max_data + credits); - assert!(shuould_transmit.should_transmit()); + assert!(should_transmit.should_transmit()); } #[test] fn shrink_receive_window() { let mut server = make(Side::Server); - let new_receive_window = 1024 * 512u32; + let new_receive_window = server.receive_window as u32 / 2; let prev_local_max_data = server.local_max_data; // shrink the receive_winbow, local_max_data is not expected to be changed let shrink_diff = server.receive_window - new_receive_window as u64; - let explanded = server.set_receive_window(new_receive_window.into()); - assert!(!explanded); + let expanded = server.set_receive_window(new_receive_window.into()); + assert!(!expanded); assert_eq!(server.receive_window, new_receive_window as u64); assert_eq!(server.local_max_data, prev_local_max_data); assert_eq!(server.receive_window_shrink_debt, shrink_diff); let prev_local_max_data = server.local_max_data; - // credit, local_max_data does not change as it is absorbed by receive_window_shrink_debt + // credit twice, local_max_data does not change as it is absorbed by receive_window_shrink_debt let credits = 1024u64; - let expected_receive_window_shrink_debt = server.receive_window_shrink_debt - credits; - let shuould_transmit = server.add_read_credits(credits); - assert_eq!( - server.receive_window_shrink_debt, - expected_receive_window_shrink_debt - ); - assert_eq!(server.local_max_data, prev_local_max_data); - assert!(!shuould_transmit.should_transmit()); - - // credit again, local_max_data does not change as it is absorbed by receive_window_shrink_debt - let expected_receive_window_shrink_debt = server.receive_window_shrink_debt - credits; - let shuould_transmit = server.add_read_credits(credits); - assert_eq!( - server.receive_window_shrink_debt, - expected_receive_window_shrink_debt - ); - assert_eq!(server.local_max_data, prev_local_max_data); - assert!(!shuould_transmit.should_transmit()); + for _ in 0..2 { + let expected_receive_window_shrink_debt = server.receive_window_shrink_debt - credits; + let should_transmit = server.add_read_credits(credits); + assert_eq!( + server.receive_window_shrink_debt, + expected_receive_window_shrink_debt + ); + assert_eq!(server.local_max_data, prev_local_max_data); + assert!(!should_transmit.should_transmit()); + } // credit again which exceeds all remaining expected_receive_window_shrink_debt let credits = 1024 * 512; + let prev_local_max_data = server.local_max_data; let expected_local_max_data = server.local_max_data + (credits - server.receive_window_shrink_debt); - let shuould_transmit = server.add_read_credits(credits); + let _should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, expected_local_max_data); - assert!(!shuould_transmit.should_transmit()); + assert!(server.local_max_data > prev_local_max_data); // credit again, all should be added to local_max_data let credits = 1024 * 512; let expected_local_max_data = server.local_max_data + credits; - let shuould_transmit = server.add_read_credits(credits); + let should_transmit = server.add_read_credits(credits); assert_eq!(server.receive_window_shrink_debt, 0); assert_eq!(server.local_max_data, expected_local_max_data); - assert!(shuould_transmit.should_transmit()); + assert!(should_transmit.should_transmit()); } }