Skip to content

Commit

Permalink
完善udf插件化代码 (apache#2)
Browse files Browse the repository at this point in the history
完善udf插件化代码
  • Loading branch information
EricJoy2048 committed Feb 24, 2022
1 parent e97600a commit 7570fd3
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 111 deletions.
14 changes: 3 additions & 11 deletions ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1080,11 +1080,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
}
// argo engine add start
ExprType::AggregateUdfExpr(expr) => {
let gpm = global_plugin_manager("").lock().unwrap();
let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap();
if let Some(udf_plugin_manager) =
plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
{
if let Some(udf_plugin_manager) = get_udf_plugin_manager("") {
let fun = udf_plugin_manager
.aggregate_udfs
.get(expr.fun_name.as_str())
Expand All @@ -1106,11 +1102,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
}
}
ExprType::ScalarUdfProtoExpr(expr) => {
let gpm = global_plugin_manager("").lock().unwrap();
let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap();
if let Some(udf_plugin_manager) =
plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
{
if let Some(udf_plugin_manager) = get_udf_plugin_manager("") {
let fun = udf_plugin_manager
.scalar_udfs
.get(expr.fun_name.as_str())
Expand Down Expand Up @@ -1335,7 +1327,7 @@ impl TryInto<Field> for &protobuf::Field {
use crate::serde::protobuf::ColumnStats;
use datafusion::physical_plan::{aggregates, windows};
use datafusion::plugin::plugin_manager::global_plugin_manager;
use datafusion::plugin::udf::UDFPluginManager;
use datafusion::plugin::udf::{get_udf_plugin_manager, UDFPluginManager};
use datafusion::plugin::PluginEnum;
use datafusion::prelude::{
array, date_part, date_trunc, length, lower, ltrim, md5, rtrim, sha224, sha256,
Expand Down
13 changes: 3 additions & 10 deletions ballista/rust/core/src/serde/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ use datafusion::physical_plan::{
AggregateExpr, ColumnStatistics, ExecutionPlan, PhysicalExpr, Statistics, WindowExpr,
};
use datafusion::plugin::plugin_manager::global_plugin_manager;
use datafusion::plugin::udf::UDFPluginManager;
use datafusion::plugin::udf::{get_udf_plugin_manager, UDFPluginManager};
use datafusion::plugin::PluginEnum;
use datafusion::prelude::CsvReadOptions;
use log::debug;
Expand Down Expand Up @@ -320,10 +320,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
ExprType::AggregateUdfExpr(agg_node) => {
let name = agg_node.fun_name.as_str();
let udaf_fun_name = &name[0..name.find('(').unwrap()];
let gpm = global_plugin_manager("").lock().unwrap();
let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap();
if let Some(udf_plugin_manager) = plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
{
if let Some(udf_plugin_manager) = get_udf_plugin_manager("") {
let fun = udf_plugin_manager.aggregate_udfs.get(udaf_fun_name).ok_or_else(|| {
proto_error(format!(
"can not get udaf:{} from plugins!",
Expand Down Expand Up @@ -584,11 +581,7 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
}
// argo engine add.
ExprType::ScalarUdfProtoExpr(e) => {
let gpm = global_plugin_manager("").lock().unwrap();
let plugin_registrar = gpm.plugin_managers.get(&PluginEnum::UDF).unwrap();
if let Some(udf_plugin_manager) =
plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
{
if let Some(udf_plugin_manager) = get_udf_plugin_manager("") {
let fun = udf_plugin_manager
.scalar_udfs
.get(&e.fun_name)
Expand Down
10 changes: 2 additions & 8 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::PhysicalPlanner;
use crate::plugin::plugin_manager::global_plugin_manager;
use crate::plugin::udf::UDFPluginManager;
use crate::plugin::PluginEnum;
use crate::plugin::udf::get_udf_plugin_manager;
use crate::sql::{
parser::{DFParser, FileType},
planner::{ContextProvider, SqlToRel},
Expand Down Expand Up @@ -198,13 +196,9 @@ impl ExecutionContext {
})),
};

let gpm = global_plugin_manager(config.plugin_dir.as_str());

// register udf
let gpm_guard = gpm.lock().unwrap();
let plugin_registrar = gpm_guard.plugin_managers.get(&PluginEnum::UDF).unwrap();
if let Some(udf_plugin_manager) =
plugin_registrar.as_any().downcast_ref::<UDFPluginManager>()
get_udf_plugin_manager(config.plugin_dir.as_str())
{
udf_plugin_manager
.scalar_udfs
Expand Down
44 changes: 25 additions & 19 deletions datafusion/src/plugin/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,26 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::error::Result;
use crate::plugin::udf::UDFPluginManager;
use libloading::Library;
use std::any::Any;
use std::env;
use std::sync::Arc;

/// plugin manager
pub mod plugin_manager;
Expand Down Expand Up @@ -47,17 +66,15 @@ pub struct PluginDeclaration {

/// One of PluginEnum
pub plugin_type: unsafe extern "C" fn() -> PluginEnum,

/// `register` is a function which impl PluginRegistrar. It will be call when plugin load.
pub register: unsafe extern "C" fn(&mut Box<dyn PluginRegistrar>),
}

/// Plugin Registrar , Every plugin need implement this trait
pub trait PluginRegistrar: Send + Sync + 'static {
/// The implementer of the plug-in needs to call this interface to report his own information to the plug-in manager
fn register_plugin(&mut self, plugin: Box<dyn Plugin>) -> Result<()>;
/// # Safety
/// load plugin from library
unsafe fn load(&mut self, library: Arc<Library>) -> Result<()>;

/// Returns the plugin registrar as [`Any`](std::any::Any) so that it can be
/// Returns the plugin as [`Any`](std::any::Any) so that it can be
/// downcast to a specific implementation.
fn as_any(&self) -> &dyn Any;
}
Expand All @@ -66,22 +83,12 @@ pub trait PluginRegistrar: Send + Sync + 'static {
///
/// # Notes
///
/// This works by automatically generating an `extern "C"` function with a
/// This works by automatically generating an `extern "C"` function named `get_plugin_type` with a
/// pre-defined signature and symbol name. And then generating a PluginDeclaration.
/// Therefore you will only be able to declare one plugin per library.
#[macro_export]
macro_rules! declare_plugin {
($plugin_type:expr, $curr_plugin_type:ty, $constructor:path) => {
#[no_mangle]
pub extern "C" fn register_plugin(
registrar: &mut Box<dyn $crate::plugin::PluginRegistrar>,
) {
// make sure the constructor is the correct type.
let constructor: fn() -> $curr_plugin_type = $constructor;
let object = constructor();
registrar.register_plugin(Box::new(object)).unwrap();
}

($plugin_type:expr) => {
#[no_mangle]
pub extern "C" fn get_plugin_type() -> $crate::plugin::PluginEnum {
$plugin_type
Expand All @@ -93,7 +100,6 @@ macro_rules! declare_plugin {
rustc_version: $crate::plugin::RUSTC_VERSION,
core_version: $crate::plugin::CORE_VERSION,
plugin_type: get_plugin_type,
register: register_plugin,
};
};
}
Expand Down
61 changes: 43 additions & 18 deletions datafusion/src/plugin/plugin_manager.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::error::{DataFusionError, Result};
use crate::plugin::{PluginDeclaration, CORE_VERSION, RUSTC_VERSION};
use crate::plugin::{PluginEnum, PluginRegistrar};
Expand All @@ -13,11 +29,14 @@ use once_cell::sync::OnceCell;
/// To prevent the library from being loaded multiple times, we use once_cell defines a Arc<Mutex<GlobalPluginManager>>
/// Because datafusion is a library, not a service, users may not need to load all plug-ins in the process.
/// So fn global_plugin_manager return Arc<Mutex<GlobalPluginManager>>. In this way, users can load the required library through the load method of GlobalPluginManager when needed
static INSTANCE: OnceCell<Arc<Mutex<GlobalPluginManager>>> = OnceCell::new();

/// global_plugin_manager
pub fn global_plugin_manager(
plugin_path: &str,
) -> &'static Arc<Mutex<GlobalPluginManager>> {
static INSTANCE: OnceCell<Arc<Mutex<GlobalPluginManager>>> = OnceCell::new();
INSTANCE.get_or_init(move || unsafe {
println!("====================init===================");
let mut gpm = GlobalPluginManager::default();
gpm.load(plugin_path).unwrap();
Arc::new(Mutex::new(gpm))
Expand All @@ -38,6 +57,9 @@ impl GlobalPluginManager {
/// # Safety
/// find plugin file from `plugin_path` and load it .
unsafe fn load(&mut self, plugin_path: &str) -> Result<()> {
if "".eq(plugin_path) {
return Ok(());
}
// find library file from udaf_plugin_path
info!("load plugin from dir:{}", plugin_path);
println!("load plugin from dir:{}", plugin_path);
Expand All @@ -54,18 +76,19 @@ impl GlobalPluginManager {

let library = Arc::new(library);

// get a pointer to the plugin_declaration symbol.
let dec = library
.get::<*mut PluginDeclaration>(b"plugin_declaration\0")
.map_err(|e| {
DataFusionError::IoError(io::Error::new(
io::ErrorKind::Other,
format!("not found plugin_declaration in the library: {}", e),
))
})?
.read();

// version checks to prevent accidental ABI incompatibilities
let dec = library.get::<*mut PluginDeclaration>(b"plugin_declaration\0");
if dec.is_err() {
info!(
"not found plugin_declaration in the library: {}",
plugin_file.path().to_str().unwrap()
);
return Ok(());
}

let dec = dec.unwrap().read();

// ersion checks to prevent accidental ABI incompatibilities

if dec.rustc_version != RUSTC_VERSION || dec.core_version != CORE_VERSION {
return Err(DataFusionError::IoError(io::Error::new(
io::ErrorKind::Other,
Expand All @@ -82,8 +105,7 @@ impl GlobalPluginManager {
}
Some(manager) => manager,
};

(dec.register)(curr_plugin_manager);
curr_plugin_manager.load(library)?;
self.plugin_files
.push(plugin_file.path().to_str().unwrap().to_string());
}
Expand Down Expand Up @@ -112,17 +134,20 @@ impl GlobalPluginManager {
if let Some(path) = item.path().extension() {
if let Some(suffix) = path.to_str() {
if suffix == "dylib" || suffix == "so" || suffix == "dll" {
info!("load plugin from library file:{}", path.to_str().unwrap());
info!(
"load plugin from library file:{}",
item.path().to_str().unwrap()
);
println!(
"load plugin from library file:{}",
path.to_str().unwrap()
item.path().to_str().unwrap()
);
return Some(item);
}
}
}

return None;
None
}) {
plugin_files.push(entry);
}
Expand Down

0 comments on commit 7570fd3

Please sign in to comment.