diff --git a/kube-client/src/api/portforward.rs b/kube-client/src/api/portforward.rs index 6dee12b4c..4a009b7be 100644 --- a/kube-client/src/api/portforward.rs +++ b/kube-client/src/api/portforward.rs @@ -60,6 +60,10 @@ pub enum Error { #[error("failed to complete the background task: {0}")] Spawn(#[source] tokio::task::JoinError), + + /// Failed to shutdown a pod writer channel. + #[error("failed to shutdown write to Pod channel: {0}")] + Shutdown(#[source] std::io::Error), } type ErrorReceiver = oneshot::Receiver; @@ -69,6 +73,8 @@ type ErrorSender = oneshot::Sender; enum Message { FromPod(u8, Bytes), ToPod(u8, Bytes), + FromPodClose, + ToPodClose(u8), } /// Manages port-forwarded streams. @@ -139,7 +145,16 @@ impl Portforwarder { /// Waits for port forwarding task to complete. pub async fn join(self) -> Result<(), Error> { - self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e))) + let Self { + mut ports, + mut errors, + task, + } = self; + // Start by terminating any streams that have not yet been taken + // since they would otherwise keep the connection open indefinitely + ports.clear(); + errors.clear(); + task.await.unwrap_or_else(|e| Err(Error::Spawn(e))) } } @@ -192,6 +207,10 @@ async fn to_pod_loop( .map_err(Error::ForwardToPod)?; } } + sender + .send(Message::ToPodClose(ch)) + .await + .map_err(Error::ForwardToPod)?; Ok(()) } @@ -217,6 +236,13 @@ where .await .map_err(Error::ForwardFromPod)?; } + message if message.is_close() => { + sender + .send(Message::FromPodClose) + .await + .map_err(Error::ForwardFromPod)?; + break; + } // REVIEW should we error on unexpected websocket message? _ => {} } @@ -238,19 +264,25 @@ async fn forwarder_loop( where S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static, { - // Keep track if the channel has received the initialization frame. - let mut initialized = vec![false; 2 * ports.len()]; + #[derive(Default, Clone)] + struct ChannelState { + // Keep track if the channel has received the initialization frame. + initialized: bool, + // Keep track if the channel has shutdown. + shutdown: bool, + } + let mut chan_state = vec![ChannelState::default(); 2 * ports.len()]; + let mut closed_ports = 0; + let mut socket_shutdown = false; while let Some(msg) = receiver.next().await { match msg { Message::FromPod(ch, mut bytes) => { let ch = ch as usize; - if ch >= initialized.len() { - return Err(Error::InvalidChannel(ch)); - } + let channel = chan_state.get_mut(ch).ok_or(Error::InvalidChannel(ch))?; let port_index = ch / 2; // Initialization - if !initialized[ch] { + if !channel.initialized { // The initial message must be 3 bytes including the channel prefix. if bytes.len() != 2 { return Err(Error::InvalidInitialFrameSize); @@ -264,7 +296,7 @@ where }); } - initialized[ch] = true; + channel.initialized = true; continue; } @@ -276,7 +308,7 @@ where .map_err(Error::InvalidErrorMessage)?; sender.send(s).map_err(Error::ForwardErrorMessage)?; } - } else { + } else if !channel.shutdown { writers[port_index] .write_all(&bytes) .await @@ -293,6 +325,31 @@ where .await .map_err(Error::SendWebSocketMessage)?; } + Message::ToPodClose(ch) => { + let ch = ch as usize; + let channel = chan_state.get_mut(ch).ok_or(Error::InvalidChannel(ch))?; + let port_index = ch / 2; + + if !channel.shutdown { + writers[port_index].shutdown().await.map_err(Error::Shutdown)?; + channel.shutdown = true; + + closed_ports += 1; + } + } + Message::FromPodClose => { + for writer in &mut writers { + writer.shutdown().await.map_err(Error::Shutdown)?; + } + } + } + + if closed_ports == ports.len() && !socket_shutdown { + ws_sink + .send(ws::Message::Close(None)) + .await + .map_err(Error::SendWebSocketMessage)?; + socket_shutdown = true; } } Ok(())