From 5a89e7bcf16fbe62be461e01d87d6c76ee193a9c Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 23 Aug 2022 09:47:27 -0600 Subject: [PATCH 1/4] Allow overriding of do_get & export useful macro --- arrow-flight/src/sql/mod.rs | 1 + arrow-flight/src/sql/server.rs | 175 ++++++++++++++------------------- 2 files changed, 73 insertions(+), 103 deletions(-) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index cd198a1401d..6b5aab4e502 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -69,6 +69,7 @@ pub trait ProstMessageExt: prost::Message + Default { fn as_any(&self) -> prost_types::Any; } +#[macro_export] macro_rules! prost_message_ext { ($($name:ty,)*) => { $( diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 6e8f104dc5b..a115ba4119d 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -61,6 +61,17 @@ pub trait FlightSqlService: )) } + /// Implementors may override to handle additional calls to do_get() + async fn custom_do_get( + &self, + request: Request, + message: prost_types::Any, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", message.type_url + ))) + } + /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, @@ -301,18 +312,18 @@ where &self, request: Request, ) -> Result, Status> { - let any: prost_types::Any = + let message: prost_types::Any = Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_statement(token, request).await; } - if any.is::() { - let handle = any + if message.is::() { + let handle = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -320,64 +331,64 @@ where .get_flight_info_prepared_statement(handle, request) .await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_catalogs(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_schemas(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_tables(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_table_types(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_sql_info(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_primary_keys(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_exported_keys(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_imported_keys(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -385,8 +396,7 @@ where } Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {:?}", - String::from_utf8(any.encode_to_vec()).unwrap() + "get_flight_info: The defined request is invalid: {}", message.type_url ))) } @@ -401,91 +411,51 @@ where &self, request: Request, ) -> Result, Status> { - let any: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) + let msg: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) .map_err(decode_error_to_status)?; - if any.is::() { - let token = any + fn unpack(msg: prost_types::Any) -> Result { + msg .unpack() .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_statement(token, request).await; + .ok_or(Status::internal("Expected a command, but found none.".to_string())) } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_prepared_statement(token, request).await; + + if msg.is::() { + return self.do_get_statement(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_catalogs(token, request).await; + if msg.is::() { + return self.do_get_prepared_statement(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_schemas(token, request).await; + if msg.is::() { + return self.do_get_catalogs(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_tables(token, request).await; + if msg.is::() { + return self.do_get_schemas(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_table_types(token, request).await; + if msg.is::() { + return self.do_get_tables(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_sql_info(token, request).await; + if msg.is::() { + return self.do_get_table_types(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_primary_keys(token, request).await; + if msg.is::() { + return self.do_get_sql_info(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_exported_keys(token, request).await; + if msg.is::() { + return self.do_get_primary_keys(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_imported_keys(token, request).await; + if msg.is::() { + return self.do_get_exported_keys(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_cross_reference(token, request).await; + if msg.is::() { + return self.do_get_imported_keys(unpack(msg)?, request).await; + } + if msg.is::() { + return self.do_get_cross_reference(unpack(msg)?, request).await; } - Err(Status::unimplemented(format!( - "do_get: The defined request is invalid: {:?}", - String::from_utf8(request.get_ref().ticket.clone()).unwrap() - ))) + self.custom_do_get(request, msg).await } async fn do_put( @@ -493,11 +463,11 @@ where mut request: Request>, ) -> Result, Status> { let cmd = request.get_mut().message().await?.unwrap(); - let any: prost_types::Any = + let message: prost_types::Any = prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -508,15 +478,15 @@ where })]); return Ok(Response::new(Box::pin(output))); } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.do_put_prepared_statement_query(token, request).await; } - if any.is::() { - let handle = any + if message.is::() { + let handle = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -531,8 +501,7 @@ where } Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {:?}", - String::from_utf8(any.encode_to_vec()).unwrap() + "do_put: The defined request is invalid: {}", message.type_url ))) } From 73955c9576c7cd3a5c14b0271613347a0db0d1d1 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 24 Aug 2022 13:47:55 -0600 Subject: [PATCH 2/4] All hail clippy --- arrow-flight/src/sql/server.rs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index a115ba4119d..1d118b04f62 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -64,11 +64,12 @@ pub trait FlightSqlService: /// Implementors may override to handle additional calls to do_get() async fn custom_do_get( &self, - request: Request, + _request: Request, message: prost_types::Any, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented(format!( - "do_get: The defined request is invalid: {}", message.type_url + "do_get: The defined request is invalid: {}", + message.type_url ))) } @@ -396,7 +397,8 @@ where } Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {}", message.type_url + "get_flight_info: The defined request is invalid: {}", + message.type_url ))) } @@ -415,10 +417,9 @@ where .map_err(decode_error_to_status)?; fn unpack(msg: prost_types::Any) -> Result { - msg - .unpack() + msg.unpack() .map_err(arrow_error_to_status)? - .ok_or(Status::internal("Expected a command, but found none.".to_string())) + .ok_or_else(|| Status::internal("Expected a command, but found none.")) } if msg.is::() { @@ -501,7 +502,8 @@ where } Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {}", message.type_url + "do_put: The defined request is invalid: {}", + message.type_url ))) } From 82c0dbce327f4df49bb052a1bdb498cd9a617461 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 25 Aug 2022 12:33:13 -0600 Subject: [PATCH 3/4] Remove macro export --- arrow-flight/src/sql/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 6b5aab4e502..cd198a1401d 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -69,7 +69,6 @@ pub trait ProstMessageExt: prost::Message + Default { fn as_any(&self) -> prost_types::Any; } -#[macro_export] macro_rules! prost_message_ext { ($($name:ty,)*) => { $( From 46544182f14c6aebcdb869e9ec429beb4563d618 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Thu, 25 Aug 2022 12:41:47 -0600 Subject: [PATCH 4/4] Rename function --- arrow-flight/src/sql/server.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 1d118b04f62..f3208d37649 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -62,7 +62,7 @@ pub trait FlightSqlService: } /// Implementors may override to handle additional calls to do_get() - async fn custom_do_get( + async fn do_get_fallback( &self, _request: Request, message: prost_types::Any, @@ -456,7 +456,7 @@ where return self.do_get_cross_reference(unpack(msg)?, request).await; } - self.custom_do_get(request, msg).await + self.do_get_fallback(request, msg).await } async fn do_put(