Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: portforward connection cleanup #973

Merged
merged 6 commits into from Sep 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
75 changes: 66 additions & 9 deletions kube-client/src/api/portforward.rs
Expand Up @@ -60,6 +60,10 @@ pub enum Error {

#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),

/// Failed to shutdown a pod writer channel.
#[error("failed to shutdown write to Pod channel: {0}")]
Shutdown(#[source] std::io::Error),
}

type ErrorReceiver = oneshot::Receiver<String>;
Expand All @@ -69,6 +73,8 @@ type ErrorSender = oneshot::Sender<String>;
enum Message {
FromPod(u8, Bytes),
ToPod(u8, Bytes),
FromPodClose,
ToPodClose(u8),
}

/// Manages port-forwarded streams.
Expand Down Expand Up @@ -139,7 +145,16 @@ impl Portforwarder {

/// Waits for port forwarding task to complete.
pub async fn join(self) -> Result<(), Error> {
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
let Self {
mut ports,
mut errors,
task,
} = self;
// Start by terminating any streams that have not yet been taken
// since they would otherwise keep the connection open indefinitely
ports.clear();
errors.clear();
task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
}

Expand Down Expand Up @@ -192,6 +207,10 @@ async fn to_pod_loop(
.map_err(Error::ForwardToPod)?;
}
}
sender
.send(Message::ToPodClose(ch))
.await
.map_err(Error::ForwardToPod)?;
Ok(())
}

Expand All @@ -217,6 +236,13 @@ where
.await
.map_err(Error::ForwardFromPod)?;
}
message if message.is_close() => {
sender
.send(Message::FromPodClose)
.await
.map_err(Error::ForwardFromPod)?;
kazk marked this conversation as resolved.
Show resolved Hide resolved
break;
}
// REVIEW should we error on unexpected websocket message?
_ => {}
}
Expand All @@ -238,19 +264,25 @@ async fn forwarder_loop<S>(
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
// Keep track if the channel has received the initialization frame.
let mut initialized = vec![false; 2 * ports.len()];
#[derive(Default, Clone)]
struct ChannelState {
// Keep track if the channel has received the initialization frame.
initialized: bool,
// Keep track if the channel has shutdown.
shutdown: bool,
}
let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
let mut closed_ports = 0;
let mut socket_shutdown = false;
while let Some(msg) = receiver.next().await {
match msg {
Message::FromPod(ch, mut bytes) => {
let ch = ch as usize;
if ch >= initialized.len() {
return Err(Error::InvalidChannel(ch));
}
let channel = chan_state.get_mut(ch).ok_or(Error::InvalidChannel(ch))?;

let port_index = ch / 2;
// Initialization
if !initialized[ch] {
if !channel.initialized {
// The initial message must be 3 bytes including the channel prefix.
if bytes.len() != 2 {
return Err(Error::InvalidInitialFrameSize);
Expand All @@ -264,7 +296,7 @@ where
});
}

initialized[ch] = true;
channel.initialized = true;
continue;
}

Expand All @@ -276,7 +308,7 @@ where
.map_err(Error::InvalidErrorMessage)?;
sender.send(s).map_err(Error::ForwardErrorMessage)?;
}
} else {
} else if !channel.shutdown {
writers[port_index]
.write_all(&bytes)
.await
Expand All @@ -293,6 +325,31 @@ where
.await
.map_err(Error::SendWebSocketMessage)?;
}
Message::ToPodClose(ch) => {
let ch = ch as usize;
let channel = chan_state.get_mut(ch).ok_or(Error::InvalidChannel(ch))?;
let port_index = ch / 2;

if !channel.shutdown {
writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
channel.shutdown = true;

closed_ports += 1;
}
}
Message::FromPodClose => {
for writer in &mut writers {
writer.shutdown().await.map_err(Error::Shutdown)?;
}
kazk marked this conversation as resolved.
Show resolved Hide resolved
}
}

if closed_ports == ports.len() && !socket_shutdown {
kazk marked this conversation as resolved.
Show resolved Hide resolved
ws_sink
.send(ws::Message::Close(None))
.await
.map_err(Error::SendWebSocketMessage)?;
socket_shutdown = true;
}
}
Ok(())
Expand Down