Skip to content

Commit

Permalink
Merge pull request kube-rs#973 from tiagolobocastro/portforward-fix
Browse files Browse the repository at this point in the history
fix: portforward connections
  • Loading branch information
nightkr committed Sep 7, 2022
2 parents 05f62c0 + 3be17bf commit 95397a5
Showing 1 changed file with 66 additions and 9 deletions.
75 changes: 66 additions & 9 deletions kube-client/src/api/portforward.rs
Expand Up @@ -60,6 +60,10 @@ pub enum Error {

#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),

/// Failed to shutdown a pod writer channel.
#[error("failed to shutdown write to Pod channel: {0}")]
Shutdown(#[source] std::io::Error),
}

type ErrorReceiver = oneshot::Receiver<String>;
Expand All @@ -69,6 +73,8 @@ type ErrorSender = oneshot::Sender<String>;
enum Message {
FromPod(u8, Bytes),
ToPod(u8, Bytes),
FromPodClose,
ToPodClose(u8),
}

/// Manages port-forwarded streams.
Expand Down Expand Up @@ -139,7 +145,16 @@ impl Portforwarder {

/// Waits for port forwarding task to complete.
pub async fn join(self) -> Result<(), Error> {
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
let Self {
mut ports,
mut errors,
task,
} = self;
// Start by terminating any streams that have not yet been taken
// since they would otherwise keep the connection open indefinitely
ports.clear();
errors.clear();
task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}
}

Expand Down Expand Up @@ -192,6 +207,10 @@ async fn to_pod_loop(
.map_err(Error::ForwardToPod)?;
}
}
sender
.send(Message::ToPodClose(ch))
.await
.map_err(Error::ForwardToPod)?;
Ok(())
}

Expand All @@ -217,6 +236,13 @@ where
.await
.map_err(Error::ForwardFromPod)?;
}
message if message.is_close() => {
sender
.send(Message::FromPodClose)
.await
.map_err(Error::ForwardFromPod)?;
break;
}
// REVIEW should we error on unexpected websocket message?
_ => {}
}
Expand All @@ -238,19 +264,25 @@ async fn forwarder_loop<S>(
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
// Keep track if the channel has received the initialization frame.
let mut initialized = vec![false; 2 * ports.len()];
#[derive(Default, Clone)]
struct ChannelState {
// Keep track if the channel has received the initialization frame.
initialized: bool,
// Keep track if the channel has shutdown.
shutdown: bool,
}
let mut chan_state = vec![ChannelState::default(); 2 * ports.len()];
let mut closed_ports = 0;
let mut socket_shutdown = false;
while let Some(msg) = receiver.next().await {
match msg {
Message::FromPod(ch, mut bytes) => {
let ch = ch as usize;
if ch >= initialized.len() {
return Err(Error::InvalidChannel(ch));
}
let channel = chan_state.get_mut(ch).ok_or(Error::InvalidChannel(ch))?;

let port_index = ch / 2;
// Initialization
if !initialized[ch] {
if !channel.initialized {
// The initial message must be 3 bytes including the channel prefix.
if bytes.len() != 2 {
return Err(Error::InvalidInitialFrameSize);
Expand All @@ -264,7 +296,7 @@ where
});
}

initialized[ch] = true;
channel.initialized = true;
continue;
}

Expand All @@ -276,7 +308,7 @@ where
.map_err(Error::InvalidErrorMessage)?;
sender.send(s).map_err(Error::ForwardErrorMessage)?;
}
} else {
} else if !channel.shutdown {
writers[port_index]
.write_all(&bytes)
.await
Expand All @@ -293,6 +325,31 @@ where
.await
.map_err(Error::SendWebSocketMessage)?;
}
Message::ToPodClose(ch) => {
let ch = ch as usize;
let channel = chan_state.get_mut(ch).ok_or(Error::InvalidChannel(ch))?;
let port_index = ch / 2;

if !channel.shutdown {
writers[port_index].shutdown().await.map_err(Error::Shutdown)?;
channel.shutdown = true;

closed_ports += 1;
}
}
Message::FromPodClose => {
for writer in &mut writers {
writer.shutdown().await.map_err(Error::Shutdown)?;
}
}
}

if closed_ports == ports.len() && !socket_shutdown {
ws_sink
.send(ws::Message::Close(None))
.await
.map_err(Error::SendWebSocketMessage)?;
socket_shutdown = true;
}
}
Ok(())
Expand Down

0 comments on commit 95397a5

Please sign in to comment.