Skip to content

Commit

Permalink
Make remote commands cancellable and remove panics (#861)
Browse files Browse the repository at this point in the history
  • Loading branch information
kazk committed Mar 31, 2022
1 parent e3aeb76 commit 8a33aac
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 94 deletions.
6 changes: 2 additions & 4 deletions examples/pod_attach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ async fn separate_outputs(mut attached: AttachedProcess) {
});

join!(stdouts, stderrs);

if let Some(status) = attached.await {
if let Some(status) = attached.take_status().unwrap().await {
println!("{:?}", status);
}
}
Expand All @@ -109,8 +108,7 @@ async fn combined_output(mut attached: AttachedProcess) {
}
});
outputs.await;

if let Some(status) = attached.await {
if let Some(status) = attached.take_status().unwrap().await {
println!("{:?}", status);
}
}
7 changes: 4 additions & 3 deletions examples/pod_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,11 @@ async fn main() -> anyhow::Result<()> {
println!("{}", stdout);
assert_eq!(stdout, "test string 1\n");

// AttachedProcess resolves with status object.
// AttachedProcess provides access to a future that resolves with a status object.
let status = attached.take_status().unwrap();
// Send `exit 1` to get a failure status.
stdin_writer.write(b"exit 1\n").await?;
if let Some(status) = attached.await {
if let Some(status) = status.await {
println!("{:?}", status);
assert_eq!(status.status, Some("Failure".to_owned()));
assert_eq!(status.reason, Some("NonZeroExitCode".to_owned()));
Expand All @@ -122,6 +123,6 @@ async fn get_output(mut attached: AttachedProcess) -> String {
.collect::<Vec<_>>()
.await
.join("");
attached.await;
attached.join().await.unwrap();
out
}
2 changes: 1 addition & 1 deletion examples/pod_shell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async fn main() -> anyhow::Result<()> {
tokio::io::copy(&mut stdout_reader, &mut stdout).await.unwrap();
});
// When done, type `exit\n` to end it, so the pod is deleted.
let status = attached.await;
let status = attached.take_status().unwrap().await;
println!("{:?}", status);

// Delete it
Expand Down
179 changes: 95 additions & 84 deletions kube-client/src/api/remote_command.rs
Original file line number Diff line number Diff line change
@@ -1,50 +1,84 @@
use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use std::future::Future;

use k8s_openapi::apimachinery::pkg::apis::meta::v1::Status;

use futures::{
channel::oneshot,
future::{
select,
Either::{Left, Right},
},
SinkExt, StreamExt,
FutureExt, SinkExt, StreamExt,
};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, DuplexStream};
use tokio_tungstenite::{tungstenite as ws, WebSocketStream};

use super::AttachParams;

// Internal state of an attached process
struct AttachedProcessState {
waker: Option<Waker>,
finished: bool,
status: Option<Status>,
stdin_writer: Option<DuplexStream>,
stdout_reader: Option<DuplexStream>,
stderr_reader: Option<DuplexStream>,
type StatusReceiver = oneshot::Receiver<Status>;
type StatusSender = oneshot::Sender<Status>;

/// Errors from attaching to a pod.
#[derive(Debug, Error)]
pub enum Error {
/// Failed to read from stdin
#[error("failed to read from stdin: {0}")]
ReadStdin(#[source] std::io::Error),

/// Failed to send stdin data to the pod
#[error("failed to send a stdin data: {0}")]
SendStdin(#[source] ws::Error),

/// Failed to write to stdout
#[error("failed to write to stdout: {0}")]
WriteStdout(#[source] std::io::Error),

/// Failed to write to stderr
#[error("failed to write to stderr: {0}")]
WriteStderr(#[source] std::io::Error),

/// Failed to receive a WebSocket message from the server.
#[error("failed to receive a WebSocket message: {0}")]
ReceiveWebSocketMessage(#[source] ws::Error),

// Failed to complete the background task
#[error("failed to complete the background task: {0}")]
Spawn(#[source] tokio::task::JoinError),

/// Failed to send close message.
#[error("failed to send a WebSocket close message: {0}")]
SendClose(#[source] ws::Error),

/// Failed to deserialize status object
#[error("failed to deserialize status object: {0}")]
DeserializeStatus(#[source] serde_json::Error),

/// Failed to send status object
#[error("failed to send status object")]
SendStatus,
}

const MAX_BUF_SIZE: usize = 1024;

/// Represents an attached process in a container for [`attach`] and [`exec`].
///
/// Resolves when the connection terminates with an optional [`Status`].
/// Provides access to `stdin`, `stdout`, and `stderr` if attached.
///
/// Use [`AttachedProcess::join`] to wait for the process to terminate.
///
/// [`attach`]: crate::Api::attach
/// [`exec`]: crate::Api::exec
/// [`Status`]: k8s_openapi::apimachinery::pkg::apis::meta::v1::Status
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct AttachedProcess {
has_stdin: bool,
has_stdout: bool,
has_stderr: bool,
state: Arc<Mutex<AttachedProcessState>>,
stdin_writer: Option<DuplexStream>,
stdout_reader: Option<DuplexStream>,
stderr_reader: Option<DuplexStream>,
status_rx: Option<StatusReceiver>,
task: tokio::task::JoinHandle<Result<(), Error>>,
}

impl AttachedProcess {
Expand All @@ -67,32 +101,25 @@ impl AttachedProcess {
} else {
(None, None)
};
let (status_tx, status_rx) = oneshot::channel();

let state = Arc::new(Mutex::new(AttachedProcessState {
waker: None,
finished: false,
status: None,
stdin_writer: Some(stdin_writer),
stdout_reader,
stderr_reader,
}));
let shared_state = state.clone();
tokio::spawn(async move {
let status = start_message_loop(stream, stdin_reader, stdout_writer, stderr_writer).await;

let mut shared = shared_state.lock().unwrap();
shared.finished = true;
shared.status = status;
if let Some(waker) = shared.waker.take() {
waker.wake()
}
});
let task = tokio::spawn(start_message_loop(
stream,
stdin_reader,
stdout_writer,
stderr_writer,
status_tx,
));

AttachedProcess {
has_stdin: ap.stdin,
has_stdout: ap.stdout,
has_stderr: ap.stderr,
state,
task,
stdin_writer: Some(stdin_writer),
stdout_reader,
stderr_reader,
status_rx: Some(status_rx),
}
}

Expand All @@ -106,9 +133,7 @@ impl AttachedProcess {
if !self.has_stdin {
return None;
}

let mut state = self.state.lock().unwrap();
state.stdin_writer.take()
self.stdin_writer.take()
}

/// Async reader for stdout outputs.
Expand All @@ -121,8 +146,7 @@ impl AttachedProcess {
if !self.has_stdout {
return None;
}
let mut state = self.state.lock().unwrap();
state.stdout_reader.take()
self.stdout_reader.take()
}

/// Async reader for stderr outputs.
Expand All @@ -135,30 +159,25 @@ impl AttachedProcess {
if !self.has_stderr {
return None;
}

let mut state = self.state.lock().unwrap();
state.stderr_reader.take()
self.stderr_reader.take()
}
}

impl Future for AttachedProcess {
type Output = Option<Status>;
/// Abort the background task, causing remote command to fail.
#[inline]
pub fn abort(&self) {
self.task.abort();
}

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut state = self.state.lock().unwrap();
if state.finished {
Poll::Ready(state.status.take())
} else {
// Update waker if necessary
if let Some(waker) = &state.waker {
if waker.will_wake(cx.waker()) {
return Poll::Pending;
}
}
/// Waits for the remote command task to complete.
pub async fn join(self) -> Result<(), Error> {
self.task.await.unwrap_or_else(|e| Err(Error::Spawn(e)))
}

state.waker = Some(cx.waker().clone());
Poll::Pending
}
/// Take a future that resolves with any status object or when the sender is dropped.
///
/// Returns `None` if called more than once.
pub fn take_status(&mut self) -> Option<impl Future<Output = Option<Status>>> {
self.status_rx.take().map(|recv| recv.map(|res| res.ok()))
}
}

Expand All @@ -174,7 +193,8 @@ async fn start_message_loop<S>(
stdin: impl AsyncRead + Unpin,
mut stdout: Option<impl AsyncWrite + Unpin>,
mut stderr: Option<impl AsyncWrite + Unpin>,
) -> Option<Status>
status_tx: StatusSender,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin + Sized + Send + 'static,
{
Expand All @@ -184,7 +204,6 @@ where
let mut server_recv = raw_server_recv.filter_map(filter_message).boxed();
let mut server_msg = server_recv.next();
let mut next_stdin = stdin_stream.next();
let mut status: Option<Status> = None;

loop {
match select(server_msg, next_stdin).await {
Expand All @@ -193,31 +212,25 @@ where
match message {
Ok(Message::Stdout(bin)) => {
if let Some(stdout) = stdout.as_mut() {
stdout
.write_all(&bin[1..])
.await
.expect("stdout pipe is writable");
stdout.write_all(&bin[1..]).await.map_err(Error::WriteStdout)?;
}
}

Ok(Message::Stderr(bin)) => {
if let Some(stderr) = stderr.as_mut() {
stderr
.write_all(&bin[1..])
.await
.expect("stderr pipe is writable");
stderr.write_all(&bin[1..]).await.map_err(Error::WriteStderr)?;
}
}

Ok(Message::Status(bin)) => {
if let Ok(s) = serde_json::from_slice::<Status>(&bin[1..]) {
status = Some(s);
}
let status =
serde_json::from_slice::<Status>(&bin[1..]).map_err(Error::DeserializeStatus)?;
status_tx.send(status).map_err(|_| Error::SendStatus)?;
break;
}

// Fatal error
Err(err) => {
panic!("AttachedProcess: fatal WebSocket error: {:?}", err);
return Err(Error::ReceiveWebSocketMessage(err));
}
}
server_msg = server_recv.next();
Expand All @@ -238,28 +251,26 @@ where
server_send
.send(ws::Message::binary(vec))
.await
.expect("send stdin");
.map_err(Error::SendStdin)?;
}
server_msg = p_server_msg;
next_stdin = stdin_stream.next();
}

Right((Some(Err(err)), _)) => {
server_send.close().await.expect("send close message");
panic!("AttachedProcess: failed to read from stdin pipe: {:?}", err);
return Err(Error::ReadStdin(err));
}

Right((None, _)) => {
// Stdin closed (writer half dropped).
// Let the server know and disconnect.
// REVIEW warn?
server_send.close().await.expect("send close message");
server_send.close().await.map_err(Error::SendClose)?;
break;
}
}
}

status
Ok(())
}

/// Channeled messages from the server.
Expand Down
5 changes: 3 additions & 2 deletions kube-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ mod test {
.collect::<Vec<_>>()
.await
.join("");
attached.await;
attached.join().await.unwrap();
assert_eq!(out.lines().count(), 3);
assert_eq!(out, "1\n2\n3\n");
}
Expand All @@ -362,7 +362,8 @@ mod test {
// AttachedProcess resolves with status object.
// Send `exit 1` to get a failure status.
stdin_writer.write(b"exit 1\n").await?;
if let Some(status) = attached.await {
let status = attached.take_status().unwrap();
if let Some(status) = status.await {
println!("{:?}", status);
assert_eq!(status.status, Some("Failure".to_owned()));
assert_eq!(status.reason, Some("NonZeroExitCode".to_owned()));
Expand Down

0 comments on commit 8a33aac

Please sign in to comment.