diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 6e8f104dc5b..f3208d37649 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -61,6 +61,18 @@ pub trait FlightSqlService: )) } + /// Implementors may override to handle additional calls to do_get() + async fn do_get_fallback( + &self, + _request: Request, + message: prost_types::Any, + ) -> Result::DoGetStream>, Status> { + Err(Status::unimplemented(format!( + "do_get: The defined request is invalid: {}", + message.type_url + ))) + } + /// Get a FlightInfo for executing a SQL query. async fn get_flight_info_statement( &self, @@ -301,18 +313,18 @@ where &self, request: Request, ) -> Result, Status> { - let any: prost_types::Any = + let message: prost_types::Any = Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?; - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_statement(token, request).await; } - if any.is::() { - let handle = any + if message.is::() { + let handle = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -320,64 +332,64 @@ where .get_flight_info_prepared_statement(handle, request) .await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_catalogs(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_schemas(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_tables(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_table_types(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_sql_info(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_primary_keys(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_exported_keys(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.get_flight_info_imported_keys(token, request).await; } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -385,8 +397,8 @@ where } Err(Status::unimplemented(format!( - "get_flight_info: The defined request is invalid: {:?}", - String::from_utf8(any.encode_to_vec()).unwrap() + "get_flight_info: The defined request is invalid: {}", + message.type_url ))) } @@ -401,91 +413,50 @@ where &self, request: Request, ) -> Result, Status> { - let any: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) + let msg: prost_types::Any = prost::Message::decode(&*request.get_ref().ticket) .map_err(decode_error_to_status)?; - if any.is::() { - let token = any - .unpack() + fn unpack(msg: prost_types::Any) -> Result { + msg.unpack() .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_statement(token, request).await; + .ok_or_else(|| Status::internal("Expected a command, but found none.")) } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_prepared_statement(token, request).await; + + if msg.is::() { + return self.do_get_statement(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_catalogs(token, request).await; + if msg.is::() { + return self.do_get_prepared_statement(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_schemas(token, request).await; + if msg.is::() { + return self.do_get_catalogs(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_tables(token, request).await; + if msg.is::() { + return self.do_get_schemas(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_table_types(token, request).await; + if msg.is::() { + return self.do_get_tables(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_sql_info(token, request).await; + if msg.is::() { + return self.do_get_table_types(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_primary_keys(token, request).await; + if msg.is::() { + return self.do_get_sql_info(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_exported_keys(token, request).await; + if msg.is::() { + return self.do_get_primary_keys(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_imported_keys(token, request).await; + if msg.is::() { + return self.do_get_exported_keys(unpack(msg)?, request).await; } - if any.is::() { - let token = any - .unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"); - return self.do_get_cross_reference(token, request).await; + if msg.is::() { + return self.do_get_imported_keys(unpack(msg)?, request).await; + } + if msg.is::() { + return self.do_get_cross_reference(unpack(msg)?, request).await; } - Err(Status::unimplemented(format!( - "do_get: The defined request is invalid: {:?}", - String::from_utf8(request.get_ref().ticket.clone()).unwrap() - ))) + self.do_get_fallback(request, msg).await } async fn do_put( @@ -493,11 +464,11 @@ where mut request: Request>, ) -> Result, Status> { let cmd = request.get_mut().message().await?.unwrap(); - let any: prost_types::Any = + let message: prost_types::Any = prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -508,15 +479,15 @@ where })]); return Ok(Response::new(Box::pin(output))); } - if any.is::() { - let token = any + if message.is::() { + let token = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); return self.do_put_prepared_statement_query(token, request).await; } - if any.is::() { - let handle = any + if message.is::() { + let handle = message .unpack() .map_err(arrow_error_to_status)? .expect("unreachable"); @@ -531,8 +502,8 @@ where } Err(Status::invalid_argument(format!( - "do_put: The defined request is invalid: {:?}", - String::from_utf8(any.encode_to_vec()).unwrap() + "do_put: The defined request is invalid: {}", + message.type_url ))) }