diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index a7e15fe24cc..f5f28f683a9 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::task::Poll; +use std::{pin::Pin, task::Poll}; use crate::{ decode::FlightRecordBatchStream, @@ -28,6 +28,7 @@ use crate::{ use arrow_schema::Schema; use bytes::Bytes; use futures::{ + channel::oneshot::{Receiver, Sender}, future::ready, ready, stream::{self, BoxStream}, @@ -364,33 +365,18 @@ impl FlightClient { &mut self, request: S, ) -> Result>> { - let (sender, mut receiver) = futures::channel::oneshot::channel(); + let (sender, receiver) = futures::channel::oneshot::channel(); // Intercepts client errors and sends them to the oneshot channel above - let mut request = Box::pin(request); // Pin to heap - let mut sender = Some(sender); // Wrap into Option so can be taken - let request_stream = futures::stream::poll_fn(move |cx| { - Poll::Ready(match ready!(request.poll_next_unpin(cx)) { - Some(Ok(data)) => Some(data), - Some(Err(e)) => { - let _ = sender.take().unwrap().send(e); - None - } - None => None, - }) - }); + let request = Box::pin(request); // Pin to heap + let request_stream = FallibleRequestStream::new(sender, request); let request = self.make_request(request_stream); - let mut response_stream = self.inner.do_put(request).await?.into_inner(); + let response_stream = self.inner.do_put(request).await?.into_inner(); // Forwards errors from the error oneshot with priority over responses from server - let error_stream = futures::stream::poll_fn(move |cx| { - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - } - let next = ready!(response_stream.poll_next_unpin(cx)); - Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) - }); + let response_stream = Box::pin(response_stream); + let error_stream = FallibleTonicResponseStream::new(receiver, response_stream); // combine the response from the server and any error from the client Ok(error_stream.boxed()) @@ -433,33 +419,17 @@ impl FlightClient { &mut self, request: S, ) -> Result { - let (sender, mut receiver) = futures::channel::oneshot::channel(); + let (sender, receiver) = futures::channel::oneshot::channel(); + let request = Box::pin(request); // Intercepts client errors and sends them to the oneshot channel above - let mut request = Box::pin(request); // Pin to heap - let mut sender = Some(sender); // Wrap into Option so can be taken - let request_stream = futures::stream::poll_fn(move |cx| { - Poll::Ready(match ready!(request.poll_next_unpin(cx)) { - Some(Ok(data)) => Some(data), - Some(Err(e)) => { - let _ = sender.take().unwrap().send(e); - None - } - None => None, - }) - }); + let request_stream = FallibleRequestStream::new(sender, request); let request = self.make_request(request_stream); - let mut response_stream = self.inner.do_exchange(request).await?.into_inner(); + let response_stream = self.inner.do_exchange(request).await?.into_inner(); - // Forwards errors from the error oneshot with priority over responses from server - let error_stream = futures::stream::poll_fn(move |cx| { - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - } - let next = ready!(response_stream.poll_next_unpin(cx)); - Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) - }); + let response_stream = Box::pin(response_stream); + let error_stream = FallibleTonicResponseStream::new(receiver, response_stream); // combine the response from the server and any error from the client Ok(FlightRecordBatchStream::new_from_flight_data(error_stream)) @@ -704,3 +674,99 @@ impl FlightClient { request } } + +/// Wrapper around fallible stream such that when +/// it encounters an error it uses the oneshot sender to +/// notify the error and stop any further streaming. See `do_put` or +/// `do_exchange` for it's uses. +struct FallibleRequestStream { + /// sender to notify error + sender: Option>, + /// fallible stream + fallible_stream: Pin> + Send + 'static>>, +} + +impl FallibleRequestStream { + fn new( + sender: Sender, + fallible_stream: Pin> + Send + 'static>>, + ) -> Self { + Self { + sender: Some(sender), + fallible_stream, + } + } +} + +impl Stream for FallibleRequestStream { + type Item = T; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pinned = self.get_mut(); + let mut request_streams = pinned.fallible_stream.as_mut(); + match ready!(request_streams.poll_next_unpin(cx)) { + Some(Ok(data)) => Poll::Ready(Some(data)), + Some(Err(e)) => { + // unwrap() here is safe, ownership of sender will + // be moved only once as this stream will not be polled + // again + let _ = pinned.sender.take().unwrap().send(e); + Poll::Ready(None) + } + None => Poll::Ready(None), + } + } +} + +/// Wrapper for a tonic response stream that can produce a tonic +/// error. This is tied to a oneshot receiver which can be notified +/// of other errors. When it receives an error through receiver +/// end, it prioritises that error to be sent back. See `do_put` or +/// `do_exchange` for it's uses +struct FallibleTonicResponseStream { + /// Receiver for FlightError + receiver: Receiver, + /// Tonic response stream + response_stream: + Pin> + Send + 'static>>, +} + +impl FallibleTonicResponseStream { + fn new( + receiver: Receiver, + response_stream: Pin< + Box> + Send + 'static>, + >, + ) -> Self { + Self { + receiver, + response_stream, + } + } +} + +impl Stream for FallibleTonicResponseStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let pinned = self.get_mut(); + let receiver = &mut pinned.receiver; + // Prioritise sending the error that's been notified over + // polling the response_stream + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + }; + + match ready!(pinned.response_stream.poll_next_unpin(cx)) { + Some(Ok(res)) => Poll::Ready(Some(Ok(res))), + Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))), + None => Poll::Ready(None), + } + } +}