Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP FlightSQL Integration Branch #3316

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions arrow-flight/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ default = []
flight-sql-experimental = ["prost-types"]

[dev-dependencies]
arrow = { version = "28.0.0", path = "../arrow", features = ["prettyprint"] }
tempfile = "3.3"
tokio-stream = { version = "0.1", features = ["net"] }
tower = "0.4.13"

[build-dependencies]
tonic-build = { version = "0.8", default-features = false, features = ["transport", "prost"] }
Expand Down
229 changes: 206 additions & 23 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,27 @@
// specific language governing permissions and limitations
// under the License.

use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo};
use arrow_flight::{Action, FlightData, HandshakeRequest, HandshakeResponse, Ticket};
use futures::Stream;
use arrow_array::builder::StringBuilder;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_flight::sql::{ActionCreatePreparedStatementResult, ProstMessageExt, SqlInfo};
use arrow_flight::{
Action, FlightData, FlightEndpoint, HandshakeRequest, HandshakeResponse, IpcMessage,
Location, SchemaAsIpc, Ticket,
};
use futures::{stream, Stream};
use prost_types::Any;
use std::fs;
use std::pin::Pin;
use tonic::transport::Server;
use std::sync::Arc;
use tempfile::NamedTempFile;
use tokio::net::{UnixListener, UnixStream};
use tokio_stream::wrappers::UnixListenerStream;
use tonic::transport::{Endpoint, Server};
use tonic::{Request, Response, Status, Streaming};

use arrow_flight::flight_descriptor::DescriptorType;
use arrow_flight::sql::client::FlightSqlServiceClient;
use arrow_flight::utils::batches_to_flight_data;
use arrow_flight::{
flight_service_server::FlightService,
flight_service_server::FlightServiceServer,
Expand All @@ -36,10 +50,28 @@ use arrow_flight::{
},
FlightDescriptor, FlightInfo,
};
use arrow_ipc::writer::IpcWriteOptions;
use arrow_schema::{ArrowError, DataType, Field, Schema};

macro_rules! status {
($desc:expr, $err:expr) => {
Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!()))
};
}

#[derive(Clone)]
pub struct FlightSqlServiceImpl {}

impl FlightSqlServiceImpl {
fn fake_result() -> Result<RecordBatch, ArrowError> {
let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8, false)]);
let mut builder = StringBuilder::new();
builder.append_value("Hello, FlightSQL!");
let cols = vec![Arc::new(builder.finish()) as ArrayRef];
RecordBatch::try_new(Arc::new(schema), cols)
}
}

#[tonic::async_trait]
impl FlightSqlService for FlightSqlServiceImpl {
type FlightService = FlightSqlServiceImpl;
Expand All @@ -57,7 +89,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
.get("authorization")
.ok_or(Status::invalid_argument("authorization field not present"))?
.to_str()
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
.map_err(|e| status!("authorization not parsable", e))?;
if !authorization.starts_with(basic) {
Err(Status::invalid_argument(format!(
"Auth type not implemented: {}",
Expand All @@ -66,20 +98,20 @@ impl FlightSqlService for FlightSqlServiceImpl {
}
let base64 = &authorization[basic.len()..];
let bytes = base64::decode(base64)
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
.map_err(|e| status!("authorization not decodable", e))?;
let str = String::from_utf8(bytes)
.map_err(|_| Status::invalid_argument("authorization not parsable"))?;
.map_err(|e| status!("authorization not parsable", e))?;
let parts: Vec<_> = str.split(":").collect();
if parts.len() != 2 {
Err(Status::invalid_argument(format!(
"Invalid authorization header"
)))?;
}
let user = parts[0];
let pass = parts[1];
if user != "admin" || pass != "password" {
let (user, pass) = match parts.as_slice() {
[user, pass] => (user, pass),
_ => Err(Status::invalid_argument(
"Invalid authorization header".to_string(),
))?,
};
if user != &"admin" || pass != &"password" {
Err(Status::unauthenticated("Invalid credentials!"))?
}

let result = HandshakeResponse {
protocol_version: 0,
payload: "random_uuid_token".as_bytes().to_vec(),
Expand All @@ -89,7 +121,26 @@ impl FlightSqlService for FlightSqlServiceImpl {
return Ok(Response::new(Box::pin(output)));
}

// get_flight_info
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
_message: prost_types::Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let batch =
Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
let schema = (*batch.schema()).clone();
let batches = vec![batch];
let flight_data = batches_to_flight_data(schema, batches)
.map_err(|e| status!("Could not convert batches", e))?
.into_iter()
.map(Ok);

let stream: Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send>> =
Box::pin(stream::iter(flight_data));
let resp = Response::new(stream);
Ok(resp)
}

async fn get_flight_info_statement(
&self,
_query: CommandStatementQuery,
Expand All @@ -102,12 +153,49 @@ impl FlightSqlService for FlightSqlServiceImpl {

async fn get_flight_info_prepared_statement(
&self,
_query: CommandPreparedStatementQuery,
cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(
"get_flight_info_prepared_statement not implemented",
))
let handle = String::from_utf8(cmd.prepared_statement_handle)
.map_err(|e| status!("Unable to parse handle", e))?;
let batch =
Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
let schema = (*batch.schema()).clone();
let num_rows = batch.num_rows();
let num_bytes = batch.get_array_memory_size();
let loc = Location {
uri: "grpc+tcp://127.0.0.1".to_string(),
};
let fetch = FetchResults {
handle: handle.to_string(),
};
let buf = ::prost::Message::encode_to_vec(&fetch.as_any());
let ticket = Ticket { ticket: buf };
let endpoint = FlightEndpoint {
ticket: Some(ticket),
location: vec![loc],
};
let endpoints = vec![endpoint];

let message = SchemaAsIpc::new(&schema, &IpcWriteOptions::default())
.try_into()
.map_err(|e| status!("Unable to serialize schema", e))?;
let IpcMessage(schema_bytes) = message;

let flight_desc = FlightDescriptor {
r#type: DescriptorType::Cmd.into(),
cmd: vec![],
path: vec![],
};
let info = FlightInfo {
schema: schema_bytes,
flight_descriptor: Some(flight_desc),
endpoint: endpoints,
total_records: num_rows as i64,
total_bytes: num_bytes as i64,
};
let resp = Response::new(info);
Ok(resp)
}

async fn get_flight_info_catalogs(
Expand Down Expand Up @@ -328,20 +416,33 @@ impl FlightSqlService for FlightSqlServiceImpl {
))
}

// do_action
async fn do_action_create_prepared_statement(
&self,
_query: ActionCreatePreparedStatementRequest,
_request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status> {
Err(Status::unimplemented("Not yet implemented"))
let handle = "some_uuid";
let schema = Self::fake_result()
.map_err(|e| status!("Error getting result schema", e))?
.schema();
let message = SchemaAsIpc::new(&schema, &IpcWriteOptions::default())
.try_into()
.map_err(|e| status!("Unable to serialize schema", e))?;
let IpcMessage(schema_bytes) = message;
let res = ActionCreatePreparedStatementResult {
prepared_statement_handle: handle.as_bytes().to_vec(),
dataset_schema: schema_bytes,
parameter_schema: vec![], // TODO: parameters
};
Ok(res)
}

async fn do_action_close_prepared_statement(
&self,
_query: ActionClosePreparedStatementRequest,
_request: Request<Action>,
) {
unimplemented!("Not yet implemented")
unimplemented!("Implement do_action_close_prepared_statement")
}

async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
Expand All @@ -360,3 +461,85 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

Ok(())
}

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct FetchResults {
#[prost(string, tag = "1")]
pub handle: ::prost::alloc::string::String,
}

impl ProstMessageExt for FetchResults {
fn type_url() -> &'static str {
"type.googleapis.com/arrow.flight.protocol.sql.FetchResults"
}

fn as_any(&self) -> Any {
prost_types::Any {
type_url: FetchResults::type_url().to_string(),
value: ::prost::Message::encode_to_vec(self),
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures::TryStreamExt;

use arrow::util::pretty::pretty_format_batches;
use arrow_flight::utils::flight_data_to_batches;
use tower::service_fn;

async fn client_with_uds(path: String) -> FlightSqlServiceClient {
let connector = service_fn(move |_| UnixStream::connect(path.clone()));
let channel = Endpoint::try_from("https://example.com")
.unwrap()
.connect_with_connector(connector)
.await
.unwrap();
FlightSqlServiceClient::new(channel)
}

#[tokio::test]
async fn test_select_1() {
let file = NamedTempFile::new().unwrap();
let path = file.into_temp_path().to_str().unwrap().to_string();
let _ = fs::remove_file(path.clone());

let uds = UnixListener::bind(path.clone()).unwrap();
let stream = UnixListenerStream::new(uds);

// We would just listen on TCP, but it seems impossible to know when tonic is ready to serve
let service = FlightSqlServiceImpl {};
let serve_future = Server::builder()
.add_service(FlightServiceServer::new(service))
.serve_with_incoming(stream);

let request_future = async {
let mut client = client_with_uds(path).await;
let token = client.handshake("admin", "password").await.unwrap();
println!("Auth succeeded with token: {:?}", token);
let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
let flight_info = stmt.execute().await.unwrap();
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
let flight_data = client.do_get(ticket).await.unwrap();
let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
let batches = flight_data_to_batches(&flight_data).unwrap();
let res = pretty_format_batches(batches.as_slice()).unwrap();
let expected = r#"
+-------------------+
| salutation |
+-------------------+
| Hello, FlightSQL! |
+-------------------+"#
.trim()
.to_string();
assert_eq!(res.to_string(), expected);
};

tokio::select! {
_ = serve_future => panic!("server returned first"),
_ = request_future => println!("Client finished!"),
}
}
}
18 changes: 9 additions & 9 deletions arrow-flight/examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,63 +58,63 @@ impl FlightService for FlightServiceImpl {
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<Response<Self::HandshakeStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement handshake"))
}

async fn list_flights(
&self,
_request: Request<Criteria>,
) -> Result<Response<Self::ListFlightsStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement list_flights"))
}

async fn get_flight_info(
&self,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement get_flight_info"))
}

async fn get_schema(
&self,
_request: Request<FlightDescriptor>,
) -> Result<Response<SchemaResult>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement get_schema"))
}

async fn do_get(
&self,
_request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement do_get"))
}

async fn do_put(
&self,
_request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement do_put"))
}

async fn do_action(
&self,
_request: Request<Action>,
) -> Result<Response<Self::DoActionStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement do_action"))
}

async fn list_actions(
&self,
_request: Request<Empty>,
) -> Result<Response<Self::ListActionsStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement list_actions"))
}

async fn do_exchange(
&self,
_request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoExchangeStream>, Status> {
Err(Status::unimplemented("Not yet implemented"))
Err(Status::unimplemented("Implement do_exchange"))
}
}

Expand Down