Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: portforward connection cleanup #973

Merged
merged 6 commits into from Sep 7, 2022
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
40 changes: 27 additions & 13 deletions kube-client/src/api/portforward.rs
Expand Up @@ -144,9 +144,17 @@ impl Portforwarder {
}

/// Waits for port forwarding task to complete.
pub async fn join(mut self) -> Result<(), Error> {
self.ports.clear();
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
pub async fn join(self) -> Result<(), Error> {
let Self {
mut ports,
mut errors,
task,
} = self;
// Start by terminating any streams that have not yet been taken
// since they would otherwise keep the connection open indefinitely
ports.clear();
errors.clear();
task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
}

Expand Down Expand Up @@ -233,6 +241,7 @@ where
.send(Message::FromPodClose)
.await
.map_err(Error::ForwardFromPod)?;
kazk marked this conversation as resolved.
Show resolved Hide resolved
break;
}
// REVIEW should we error on unexpected websocket message?
_ => {}
Expand All @@ -255,22 +264,27 @@ async fn forwarder_loop<S>(
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
// Keep track if the channel has received the initialization frame.
let mut initialized = vec![false; 2 * ports.len()];
let mut shutdown = vec![false; 2 * ports.len()];
#[derive(Default, Clone)]
struct ChannelState {
// Keep track if the channel has received the initialization frame.
initialized: bool,
// Keep track if the channel has shutdown.
shutdown: bool,
}
let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
let mut closed_ports = 0;
let mut socket_shutdown = false;
while let Some(msg) = receiver.next().await {
match msg {
Message::FromPod(ch, mut bytes) => {
let ch = ch as usize;
if ch >= initialized.len() {
if ch >= chan_state.len() {
return Err(Error::InvalidChannel(ch));
}
tiagolobocastro marked this conversation as resolved.
Show resolved Hide resolved

let port_index = ch / 2;
// Initialization
if !initialized[ch] {
if !chan_state[ch].initialized {
// The initial message must be 3 bytes including the channel prefix.
if bytes.len() != 2 {
return Err(Error::InvalidInitialFrameSize);
Expand All @@ -284,7 +298,7 @@ where
});
}

initialized[ch] = true;
chan_state[ch].initialized = true;
continue;
}

Expand All @@ -297,7 +311,7 @@ where
sender.send(s).map_err(Error::ForwardErrorMessage)?;
}
} else {
if !shutdown[port_index] {
if !chan_state[port_index].shutdown {
writers[port_index]
.write_all(&bytes)
.await
Expand All @@ -317,13 +331,13 @@ where
}
Message::ToPodClose(ch) => {
let ch = ch as usize;
if ch >= initialized.len() {
if ch >= chan_state.len() {
return Err(Error::InvalidChannel(ch));
}
let port_index = ch / 2;
if !shutdown[port_index] {
if !chan_state[port_index].shutdown {
writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
shutdown[port_index] = true;
chan_state[port_index].shutdown = true;

closed_ports += 1;
}
Expand Down