Host filter #567

Merged: 6 commits, Oct 4, 2022
1 change: 1 addition & 0 deletions scylla/src/lib.rs
@@ -116,6 +116,7 @@ pub use transport::query_result::QueryResult;
pub use transport::session::{IntoTypedRows, Session, SessionConfig};
pub use transport::session_builder::SessionBuilder;

pub use transport::host_filter;
pub use transport::load_balancing;
pub use transport::retry_policy;
pub use transport::speculative_execution;
37 changes: 29 additions & 8 deletions scylla/src/transport/cluster.rs
@@ -3,6 +3,7 @@ use crate::frame::response::event::{Event, StatusChangeEvent};
use crate::frame::value::ValueList;
use crate::load_balancing::TokenAwarePolicy;
use crate::routing::Token;
use crate::transport::host_filter::HostFilter;
use crate::transport::{
connection::{Connection, VerifiedKeyspaceName},
connection_pool::PoolConfig,
@@ -110,6 +111,10 @@ struct ClusterWorker {

// Keyspace sent in "USE <keyspace name>" when opening each connection
used_keyspace: Option<VerifiedKeyspaceName>,

// The host filter determines which nodes the driver should open
// connections to
host_filter: Option<Arc<dyn HostFilter>>,
}

#[derive(Debug)]
@@ -129,6 +134,7 @@ impl Cluster {
pool_config: PoolConfig,
fetch_schema_metadata: bool,
address_translator: &Option<Arc<dyn AddressTranslator>>,
host_filter: &Option<Arc<dyn HostFilter>>,
) -> Result<Cluster, QueryError> {
let (refresh_sender, refresh_receiver) = tokio::sync::mpsc::channel(32);
let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32);
@@ -141,10 +147,17 @@
server_events_sender,
fetch_schema_metadata,
address_translator,
host_filter,
);

let metadata = metadata_reader.read_metadata(true).await?;
let cluster_data = ClusterData::new(metadata, &pool_config, &HashMap::new(), &None);
let cluster_data = ClusterData::new(
metadata,
&pool_config,
&HashMap::new(),
&None,
host_filter.as_deref(),
);
cluster_data.wait_until_all_pools_are_initialized().await;
let cluster_data: Arc<ArcSwap<ClusterData>> =
Arc::new(ArcSwap::from(Arc::new(cluster_data)));
@@ -160,6 +173,8 @@

use_keyspace_channel: use_keyspace_receiver,
used_keyspace: None,

host_filter: host_filter.clone(),
};

let (fut, worker_handle) = worker.work().remote_handle();
@@ -273,6 +288,7 @@ impl ClusterData {
pool_config: &PoolConfig,
known_peers: &HashMap<SocketAddr, Arc<Node>>,
used_keyspace: &Option<VerifiedKeyspaceName>,
host_filter: Option<&dyn HostFilter>,
) -> Self {
// Create new updated known_peers and ring
let mut new_known_peers: HashMap<SocketAddr, Arc<Node>> =
@@ -289,13 +305,17 @@
Some(node) if node.datacenter == peer.datacenter && node.rack == peer.rack => {
node.clone()
}
_ => Arc::new(Node::new(
peer.address,
pool_config.clone(),
peer.datacenter,
peer.rack,
used_keyspace.clone(),
)),
_ => {
let is_enabled = host_filter.map_or(true, |f| f.accept(&peer));
Arc::new(Node::new(
peer.address,
pool_config.clone(),
peer.datacenter,
peer.rack,
used_keyspace.clone(),
is_enabled,
))
}
};

new_known_peers.insert(peer.address, node.clone());
@@ -567,6 +587,7 @@ impl ClusterWorker {
&self.pool_config,
&cluster_data.known_peers,
&self.used_keyspace,
self.host_filter.as_deref(),
));

new_cluster_data
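For clarity, the acceptance check used in `ClusterData::new` above boils down to a simple rule: if no host filter is configured, every peer is accepted; otherwise the filter decides. A minimal sketch of that rule, assuming `Peer` and `HostFilter` are publicly reachable at the paths this PR uses (`scylla::transport::topology::Peer`, `scylla::transport::host_filter::HostFilter`):

```rust
use scylla::transport::host_filter::HostFilter;
use scylla::transport::topology::Peer;

// Mirrors `host_filter.map_or(true, |f| f.accept(&peer))` from the diff:
// an absent filter is treated as "accept all peers".
fn is_peer_enabled(host_filter: Option<&dyn HostFilter>, peer: &Peer) -> bool {
    host_filter.map_or(true, |f| f.accept(peer))
}
```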
78 changes: 78 additions & 0 deletions scylla/src/transport/host_filter.rs
@@ -0,0 +1,78 @@
//! Host filters.
//!
//! Host filters are essentially predicates over
//! [`Peer`](crate::transport::topology::Peer)s. Currently, they are used
//! by the [`Session`](crate::transport::session::Session) to determine whether
//! connections should be opened to a given node or not.

use std::collections::HashSet;
use std::io::Error;
use std::net::{SocketAddr, ToSocketAddrs};

use crate::transport::topology::Peer;

/// A predicate over [`Peer`]s which decides whether the driver should open connections to a node.
pub trait HostFilter: Send + Sync {
/// Returns whether a peer should be accepted or not.
fn accept(&self, peer: &Peer) -> bool;
}

/// Unconditionally accepts all nodes.
pub struct AcceptAllHostFilter;

impl HostFilter for AcceptAllHostFilter {
fn accept(&self, _peer: &Peer) -> bool {
true
}
}

/// Accepts nodes whose addresses are present in the allow list provided
/// during the filter's construction.
pub struct AllowListHostFilter {
allowed: HashSet<SocketAddr>,
}

impl AllowListHostFilter {
/// Creates a new `AllowListHostFilter` which only accepts nodes from the
/// list.
pub fn new<I, A>(allowed_iter: I) -> Result<Self, Error>
where
I: IntoIterator<Item = A>,
A: ToSocketAddrs,
{
// A plain loop is used here because `to_socket_addrs()` is fallible,
// which makes iterator combinators awkward.
let mut allowed = HashSet::new();
for item in allowed_iter {
for addr in item.to_socket_addrs()? {
allowed.insert(addr);
}
}

Ok(Self { allowed })
}
}

impl HostFilter for AllowListHostFilter {
fn accept(&self, peer: &Peer) -> bool {
self.allowed.contains(&peer.address)
}
}

/// Accepts nodes from given DC.
pub struct DcHostFilter {
local_dc: String,
}

impl DcHostFilter {
/// Creates a new `DcHostFilter` that accepts nodes only from the
/// `local_dc`.
pub fn new(local_dc: String) -> Self {
Self { local_dc }
}
}

impl HostFilter for DcHostFilter {
fn accept(&self, peer: &Peer) -> bool {
peer.datacenter.as_ref() == Some(&self.local_dc)
}
}
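Custom filters only need to implement the single-method `HostFilter` trait added in this file. As a hypothetical sketch (the `RackHostFilter` type below is illustrative and not part of this PR), a filter restricting connections to a single rack could look like this:

```rust
use scylla::transport::host_filter::HostFilter;
use scylla::transport::topology::Peer;

/// Illustrative filter: accepts only nodes from a given rack.
struct RackHostFilter {
    local_rack: String,
}

impl HostFilter for RackHostFilter {
    fn accept(&self, peer: &Peer) -> bool {
        // `Peer::rack` is an Option<String>, analogous to `datacenter` in DcHostFilter.
        peer.rack.as_ref() == Some(&self.local_rack)
    }
}
```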
2 changes: 1 addition & 1 deletion scylla/src/transport/load_balancing/mod.rs
@@ -150,7 +150,7 @@ mod tests {
keyspaces: HashMap::new(),
};

ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
}

pub const EMPTY_STATEMENT: Statement = Statement {
4 changes: 2 additions & 2 deletions scylla/src/transport/load_balancing/token_aware.rs
@@ -345,7 +345,7 @@ mod tests {
keyspaces,
};

ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
}

// creates ClusterData with info about 8 nodes living in two different datacenters
@@ -444,7 +444,7 @@ mod tests {
keyspaces,
};

ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
}

// Used as child policy for TokenAwarePolicy tests
1 change: 1 addition & 0 deletions scylla/src/transport/mod.rs
@@ -3,6 +3,7 @@ mod cluster;
pub(crate) mod connection;
mod connection_pool;
pub mod downgrading_consistency_retry_policy;
pub mod host_filter;
Reviewer (Contributor): This could also be reexported in lib.rs for convenience.

pub mod iterator;
pub mod load_balancing;
pub(crate) mod metrics;
43 changes: 34 additions & 9 deletions scylla/src/transport/node.rs
@@ -21,7 +21,8 @@ pub struct Node {
pub datacenter: Option<String>,
pub rack: Option<String>,

pool: NodeConnectionPool,
// If the node is filtered out by the host filter, this will be None
pool: Option<NodeConnectionPool>,

down_marker: AtomicBool,
}
@@ -40,9 +41,11 @@ impl Node {
datacenter: Option<String>,
rack: Option<String>,
keyspace_name: Option<VerifiedKeyspaceName>,
enabled: bool,
) -> Self {
let pool =
NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name);
let pool = enabled.then(|| {
NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name)
});

Node {
address,
@@ -54,7 +57,7 @@
}

pub fn sharder(&self) -> Option<Sharder> {
self.pool.sharder()
self.pool.as_ref()?.sharder()
}

/// Get connection which should be used to connect using given token
@@ -63,18 +66,25 @@
&self,
token: Token,
) -> Result<Arc<Connection>, QueryError> {
self.pool.connection_for_token(token)
self.get_pool()?.connection_for_token(token)
}

/// Get random connection
pub(crate) async fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
self.pool.random_connection()
self.get_pool()?.random_connection()
}

pub fn is_down(&self) -> bool {
self.down_marker.load(Ordering::Relaxed)
}

/// Returns a boolean which indicates whether this node is enabled.
/// Only enabled nodes will have connections opened; for disabled nodes,
/// no connections will be opened.
pub fn is_enabled(&self) -> bool {
self.pool.is_some()
}

pub(crate) fn change_down_marker(&self, is_down: bool) {
self.down_marker.store(is_down, Ordering::Relaxed);
}
@@ -83,15 +93,30 @@
&self,
keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
self.pool.use_keyspace(keyspace_name).await
if let Some(pool) = &self.pool {
pool.use_keyspace(keyspace_name).await?;
}
Ok(())
}

pub(crate) fn get_working_connections(&self) -> Result<Vec<Arc<Connection>>, QueryError> {
self.pool.get_working_connections()
self.get_pool()?.get_working_connections()
}

pub(crate) async fn wait_until_pool_initialized(&self) {
self.pool.wait_until_initialized().await
if let Some(pool) = &self.pool {
pool.wait_until_initialized().await;
}
}

fn get_pool(&self) -> Result<&NodeConnectionPool, QueryError> {
self.pool.as_ref().ok_or_else(|| {
QueryError::IoError(Arc::new(std::io::Error::new(
std::io::ErrorKind::Other,
"No connections in the pool: the node has been disabled \
by the host filter",
)))
})
}
}

8 changes: 8 additions & 0 deletions scylla/src/transport/session.rs
@@ -38,6 +38,7 @@ use crate::tracing::{GetTracingConfig, TracingEvent, TracingInfo};
use crate::transport::cluster::{Cluster, ClusterData, ClusterNeatDebug};
use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
use crate::transport::connection_pool::PoolConfig;
use crate::transport::host_filter::HostFilter;
use crate::transport::iterator::{PreparedIteratorConfig, RowIterator};
use crate::transport::load_balancing::{
LoadBalancingPolicy, RoundRobinPolicy, Statement, TokenAwarePolicy,
@@ -205,6 +206,11 @@ pub struct SessionConfig {

pub address_translator: Option<Arc<dyn AddressTranslator>>,

/// The host filter decides whether any connections should be opened
/// to a given node or not. The driver will also avoid filtered-out nodes
/// when re-establishing the control connection.
pub host_filter: Option<Arc<dyn HostFilter>>,

/// If true, full schema metadata is fetched after successfully reaching a schema agreement.
/// It is true by default but can be disabled if successive schema-altering statements should be performed.
pub refresh_metadata_on_auto_schema_agreement: bool,
@@ -252,6 +258,7 @@ impl SessionConfig {
auto_await_schema_agreement_timeout: Some(std::time::Duration::from_secs(60)),
request_timeout: Some(Duration::from_secs(30)),
address_translator: None,
host_filter: None,
refresh_metadata_on_auto_schema_agreement: true,
}
}
@@ -438,6 +445,7 @@ impl Session {
config.get_pool_config(),
config.fetch_schema_metadata,
&config.address_translator,
&config.host_filter,
)
.await?;

33 changes: 33 additions & 0 deletions scylla/src/transport/session_builder.rs
@@ -5,6 +5,7 @@ use super::load_balancing::LoadBalancingPolicy;
use super::session::{AddressTranslator, Session, SessionConfig};
use super::speculative_execution::SpeculativeExecutionPolicy;
use super::Compression;
use crate::transport::host_filter::HostFilter;
use crate::transport::{connection_pool::PoolSize, retry_policy::RetryPolicy};
use std::net::SocketAddr;
use std::sync::Arc;
@@ -622,6 +623,38 @@ impl SessionBuilder {
self
}

/// Sets the host filter. The host filter decides whether any connections
/// should be opened to a given node or not. The driver will also avoid
/// filtered-out nodes when re-establishing the control connection.
///
/// See the [host filter](crate::transport::host_filter) module for a list
/// of pre-defined filters. It is also possible to provide a custom filter
/// by implementing the `HostFilter` trait.
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use scylla::{Session, SessionBuilder};
/// # use scylla::transport::host_filter::DcHostFilter;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // The session will only connect to nodes from "my-local-dc"
/// let session: Session = SessionBuilder::new()
/// .known_node("127.0.0.1:9042")
/// .host_filter(Arc::new(DcHostFilter::new("my-local-dc".to_string())))
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn host_filter(mut self, filter: Arc<dyn HostFilter>) -> Self {
self.config.host_filter = Some(filter);
self
}

/// Set the refresh metadata on schema agreement flag.
/// The default is true.
///
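To complement the `DcHostFilter` doctest above, here is a hedged usage sketch of `AllowListHostFilter` with the builder; the addresses are placeholders, and plain `&str` items work because `&str` implements `ToSocketAddrs`:

```rust
use std::sync::Arc;

use scylla::transport::host_filter::AllowListHostFilter;
use scylla::{Session, SessionBuilder};

async fn example() -> Result<(), Box<dyn std::error::Error>> {
    // Resolve the allow list up front; only these nodes will get connection pools.
    let filter = AllowListHostFilter::new(["127.0.0.1:9042", "127.0.0.2:9042"])?;

    let _session: Session = SessionBuilder::new()
        .known_node("127.0.0.1:9042")
        .host_filter(Arc::new(filter))
        .build()
        .await?;
    Ok(())
}
```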