From 4e28be7f9b393862a2e02f8b7ff4f569622d09dc Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Wed, 3 Aug 2022 19:03:20 -0600 Subject: [PATCH] Pass metadata to impl --- arrow-flight/examples/flight_sql_server.rs | 28 ++ arrow-flight/src/sql/server.rs | 357 ++++++++++----------- 2 files changed, 195 insertions(+), 190 deletions(-) diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 7e2a759c559..3973be246c6 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -19,6 +19,7 @@ use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo}; use arrow_flight::{FlightData, HandshakeRequest, HandshakeResponse}; use futures::Stream; use std::pin::Pin; +use tonic::metadata::MetadataMap; use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; @@ -94,6 +95,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandStatementQuery, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -101,6 +103,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementQuery, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -108,6 +111,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetCatalogs, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -115,6 +119,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetDbSchemas, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -122,6 +127,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetTables, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -129,6 +135,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetTableTypes, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -136,6 +143,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetSqlInfo, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -143,6 +151,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetPrimaryKeys, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -150,6 +159,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetExportedKeys, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -157,6 +167,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetImportedKeys, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -164,6 +175,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandGetCrossReference, _request: FlightDescriptor, + _metadata: MetadataMap, ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -171,6 +183,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_get_statement( &self, _ticket: TicketStatementQuery, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -178,60 +191,70 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_get_prepared_statement( &self, _query: CommandPreparedStatementQuery, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_catalogs( &self, _query: CommandGetCatalogs, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_schemas( &self, _query: CommandGetDbSchemas, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_tables( &self, _query: CommandGetTables, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_table_types( &self, _query: CommandGetTableTypes, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_sql_info( &self, _query: CommandGetSqlInfo, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_primary_keys( &self, _query: CommandGetPrimaryKeys, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_exported_keys( &self, _query: CommandGetExportedKeys, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_imported_keys( &self, _query: CommandGetImportedKeys, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } async fn do_get_cross_reference( &self, _query: CommandGetCrossReference, + _metadata: MetadataMap, ) -> Result::DoGetStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -239,6 +262,7 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_put_statement_update( &self, _ticket: CommandStatementUpdate, + _metadata: MetadataMap, ) -> Result { Err(Status::unimplemented("Not yet implemented")) } @@ -246,6 +270,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementQuery, _request: Streaming, + _metadata: MetadataMap, ) -> Result::DoPutStream>, Status> { Err(Status::unimplemented("Not yet implemented")) } @@ -253,6 +278,7 @@ impl FlightSqlService for FlightSqlServiceImpl { &self, _query: CommandPreparedStatementUpdate, _request: Streaming, + _metadata: MetadataMap, ) -> Result { Err(Status::unimplemented("Not yet implemented")) } @@ -260,12 +286,14 @@ impl FlightSqlService for FlightSqlServiceImpl { async fn do_action_create_prepared_statement( &self, _query: ActionCreatePreparedStatementRequest, + _metadata: MetadataMap, ) -> Result { Err(Status::unimplemented("Not yet implemented")) } async fn do_action_close_prepared_statement( &self, _query: ActionClosePreparedStatementRequest, + _metadata: MetadataMap, ) { unimplemented!("Not yet implemented") } diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index 2d9d8863858..b2f8db5c176 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -19,6 +19,7 @@ use std::pin::Pin; use futures::Stream; use prost::Message; +use tonic::metadata::MetadataMap; use tonic::{Request, Response, Status, Streaming}; use super::{ @@ -66,6 +67,7 @@ pub trait FlightSqlService: &self, query: CommandStatementQuery, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo for executing an already created prepared statement. @@ -73,6 +75,7 @@ pub trait FlightSqlService: &self, query: CommandPreparedStatementQuery, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo for listing catalogs. @@ -80,6 +83,7 @@ pub trait FlightSqlService: &self, query: CommandGetCatalogs, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo for listing schemas. @@ -87,6 +91,7 @@ pub trait FlightSqlService: &self, query: CommandGetDbSchemas, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo for listing tables. @@ -94,6 +99,7 @@ pub trait FlightSqlService: &self, query: CommandGetTables, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo to extract information about the table types. @@ -101,6 +107,7 @@ pub trait FlightSqlService: &self, query: CommandGetTableTypes, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo for retrieving other information (See SqlInfo). @@ -108,6 +115,7 @@ pub trait FlightSqlService: &self, query: CommandGetSqlInfo, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo to extract information about primary and foreign keys. @@ -115,6 +123,7 @@ pub trait FlightSqlService: &self, query: CommandGetPrimaryKeys, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo to extract information about exported keys. @@ -122,6 +131,7 @@ pub trait FlightSqlService: &self, query: CommandGetExportedKeys, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo to extract information about imported keys. @@ -129,6 +139,7 @@ pub trait FlightSqlService: &self, query: CommandGetImportedKeys, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; /// Get a FlightInfo to extract information about cross reference. @@ -136,6 +147,7 @@ pub trait FlightSqlService: &self, query: CommandGetCrossReference, request: FlightDescriptor, + metadata: MetadataMap, ) -> Result, Status>; // do_get @@ -144,66 +156,77 @@ pub trait FlightSqlService: async fn do_get_statement( &self, ticket: TicketStatementQuery, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the prepared statement query results. async fn do_get_prepared_statement( &self, query: CommandPreparedStatementQuery, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of catalogs. async fn do_get_catalogs( &self, query: CommandGetCatalogs, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of schemas. async fn do_get_schemas( &self, query: CommandGetDbSchemas, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of tables. async fn do_get_tables( &self, query: CommandGetTables, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the table types. async fn do_get_table_types( &self, query: CommandGetTableTypes, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the list of SqlInfo results. async fn do_get_sql_info( &self, query: CommandGetSqlInfo, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the primary and foreign keys. async fn do_get_primary_keys( &self, query: CommandGetPrimaryKeys, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the exported keys. async fn do_get_exported_keys( &self, query: CommandGetExportedKeys, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the imported keys. async fn do_get_imported_keys( &self, query: CommandGetImportedKeys, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; /// Get a FlightDataStream containing the data related to the cross reference. async fn do_get_cross_reference( &self, query: CommandGetCrossReference, + metadata: MetadataMap, ) -> Result::DoGetStream>, Status>; // do_put @@ -212,6 +235,7 @@ pub trait FlightSqlService: async fn do_put_statement_update( &self, ticket: CommandStatementUpdate, + metadata: MetadataMap, ) -> Result; /// Bind parameters to given prepared statement. @@ -219,6 +243,7 @@ pub trait FlightSqlService: &self, query: CommandPreparedStatementQuery, request: Streaming, + metadata: MetadataMap, ) -> Result::DoPutStream>, Status>; /// Execute an update SQL prepared statement. @@ -226,6 +251,7 @@ pub trait FlightSqlService: &self, query: CommandPreparedStatementUpdate, request: Streaming, + metadata: MetadataMap, ) -> Result; // do_action @@ -234,12 +260,14 @@ pub trait FlightSqlService: async fn do_action_create_prepared_statement( &self, query: ActionCreatePreparedStatementRequest, + metadata: MetadataMap, ) -> Result; /// Close a prepared statement. async fn do_action_close_prepared_statement( &self, query: ActionClosePreparedStatementRequest, + metadata: MetadataMap, ); /// Register a new SqlInfo result, making it available when calling GetSqlInfo. @@ -285,120 +313,92 @@ where async fn get_flight_info( &self, - request: Request, + tonic_request: Request, ) -> Result, Status> { - let request = request.into_inner(); + let md = tonic_request.metadata().clone(); + let request = tonic_request.into_inner(); let any: prost_types::Any = prost::Message::decode(&*request.cmd).map_err(decode_error_to_status)?; if any.is::() { - return self - .get_flight_info_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_statement(token, request, md).await; } if any.is::() { + let handle = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); return self - .get_flight_info_prepared_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .get_flight_info_prepared_statement(handle, request, md) .await; } if any.is::() { - return self - .get_flight_info_catalogs( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_catalogs(token, request, md).await; } if any.is::() { - return self - .get_flight_info_schemas( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_schemas(token, request, md).await; } if any.is::() { - return self - .get_flight_info_tables( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_tables(token, request, md).await; } if any.is::() { - return self - .get_flight_info_table_types( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_table_types(token, request, md).await; } if any.is::() { - return self - .get_flight_info_sql_info( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_sql_info(token, request, md).await; } if any.is::() { - return self - .get_flight_info_primary_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_primary_keys(token, request, md).await; } if any.is::() { - return self - .get_flight_info_exported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_exported_keys(token, request, md).await; } if any.is::() { - return self - .get_flight_info_imported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.get_flight_info_imported_keys(token, request, md).await; } if any.is::() { + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); return self - .get_flight_info_cross_reference( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .get_flight_info_cross_reference(token, request, md) .await; } @@ -417,110 +417,89 @@ where async fn do_get( &self, - request: Request, + tonic_request: Request, ) -> Result, Status> { - let request = request.into_inner(); + let md = tonic_request.metadata().clone(); + let request = tonic_request.into_inner(); let any: prost_types::Any = prost::Message::decode(&*request.ticket).map_err(decode_error_to_status)?; if any.is::() { - return self - .do_get_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_statement(token, md).await; } if any.is::() { - return self - .do_get_prepared_statement( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_prepared_statement(token, md).await; } if any.is::() { - return self - .do_get_catalogs( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_catalogs(token, md).await; } if any.is::() { - return self - .do_get_schemas( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_schemas(token, md).await; } if any.is::() { - return self - .do_get_tables( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_tables(token, md).await; } if any.is::() { - return self - .do_get_table_types( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_table_types(token, md).await; } if any.is::() { - return self - .do_get_sql_info( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_sql_info(token, md).await; } if any.is::() { - return self - .do_get_primary_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_primary_keys(token, md).await; } if any.is::() { - return self - .do_get_exported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_exported_keys(token, md).await; } if any.is::() { - return self - .do_get_imported_keys( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_imported_keys(token, md).await; } if any.is::() { - return self - .do_get_cross_reference( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + return self.do_get_cross_reference(token, md).await; } Err(Status::unimplemented(format!( @@ -531,21 +510,20 @@ where async fn do_put( &self, - request: Request>, + tonic_request: Request>, ) -> Result, Status> { - let mut request = request.into_inner(); + let md = tonic_request.metadata().clone(); + let mut request = tonic_request.into_inner(); let cmd = request.message().await?.unwrap(); let any: prost_types::Any = prost::Message::decode(&*cmd.flight_descriptor.unwrap().cmd) .map_err(decode_error_to_status)?; if any.is::() { - let record_count = self - .do_put_statement_update( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - ) - .await?; + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); + let record_count = self.do_put_statement_update(token, md).await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { app_metadata: result.as_any().encode_to_vec(), @@ -553,23 +531,21 @@ where return Ok(Response::new(Box::pin(output))); } if any.is::() { + let token = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); return self - .do_put_prepared_statement_query( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .do_put_prepared_statement_query(token, request, md) .await; } if any.is::() { + let handle = any + .unpack() + .map_err(arrow_error_to_status)? + .expect("unreachable"); let record_count = self - .do_put_prepared_statement_update( - any.unpack() - .map_err(arrow_error_to_status)? - .expect("unreachable"), - request, - ) + .do_put_prepared_statement_update(handle, request, md) .await?; let result = DoPutUpdateResult { record_count }; let output = futures::stream::iter(vec![Ok(super::super::gen::PutResult { @@ -612,9 +588,10 @@ where async fn do_action( &self, - request: Request, + tonic_request: Request, ) -> Result, Status> { - let request = request.into_inner(); + let md = tonic_request.metadata().clone(); + let request = tonic_request.into_inner(); if request.r#type == CREATE_PREPARED_STATEMENT { let any: prost_types::Any = @@ -628,7 +605,7 @@ where "Unable to unpack ActionCreatePreparedStatementRequest.", ) })?; - let stmt = self.do_action_create_prepared_statement(cmd).await?; + let stmt = self.do_action_create_prepared_statement(cmd, md).await?; let output = futures::stream::iter(vec![Ok(super::super::gen::Result { body: stmt.as_any().encode_to_vec(), })]); @@ -646,7 +623,7 @@ where "Unable to unpack ActionClosePreparedStatementRequest.", ) })?; - self.do_action_close_prepared_statement(cmd).await; + self.do_action_close_prepared_statement(cmd, md).await; return Ok(Response::new(Box::pin(futures::stream::empty()))); }