Merge pull request #567 from piodul/host-filter
Host filter
cvybhu committed Oct 4, 2022
2 parents 51fa01d + c2944af commit 90f8cdc
Showing 10 changed files with 256 additions and 21 deletions.
1 change: 1 addition & 0 deletions scylla/src/lib.rs
@@ -116,6 +116,7 @@ pub use transport::query_result::QueryResult;
pub use transport::session::{IntoTypedRows, Session, SessionConfig};
pub use transport::session_builder::SessionBuilder;

pub use transport::host_filter;
pub use transport::load_balancing;
pub use transport::retry_policy;
pub use transport::speculative_execution;
37 changes: 29 additions & 8 deletions scylla/src/transport/cluster.rs
@@ -3,6 +3,7 @@ use crate::frame::response::event::{Event, StatusChangeEvent};
use crate::frame::value::ValueList;
use crate::load_balancing::TokenAwarePolicy;
use crate::routing::Token;
use crate::transport::host_filter::HostFilter;
use crate::transport::{
connection::{Connection, VerifiedKeyspaceName},
connection_pool::PoolConfig,
@@ -110,6 +111,10 @@ struct ClusterWorker

// Keyspace sent in "USE <keyspace name>" when opening each connection
used_keyspace: Option<VerifiedKeyspaceName>,

// The host filter determines which nodes the driver should open
// connections to
host_filter: Option<Arc<dyn HostFilter>>,
}

#[derive(Debug)]
@@ -129,6 +134,7 @@ impl Cluster {
pool_config: PoolConfig,
fetch_schema_metadata: bool,
address_translator: &Option<Arc<dyn AddressTranslator>>,
host_filter: &Option<Arc<dyn HostFilter>>,
) -> Result<Cluster, QueryError> {
let (refresh_sender, refresh_receiver) = tokio::sync::mpsc::channel(32);
let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32);
@@ -141,10 +147,17 @@
server_events_sender,
fetch_schema_metadata,
address_translator,
host_filter,
);

let metadata = metadata_reader.read_metadata(true).await?;
let cluster_data = ClusterData::new(metadata, &pool_config, &HashMap::new(), &None);
let cluster_data = ClusterData::new(
metadata,
&pool_config,
&HashMap::new(),
&None,
host_filter.as_deref(),
);
cluster_data.wait_until_all_pools_are_initialized().await;
let cluster_data: Arc<ArcSwap<ClusterData>> =
Arc::new(ArcSwap::from(Arc::new(cluster_data)));
@@ -160,6 +173,8 @@

use_keyspace_channel: use_keyspace_receiver,
used_keyspace: None,

host_filter: host_filter.clone(),
};

let (fut, worker_handle) = worker.work().remote_handle();
@@ -273,6 +288,7 @@ impl ClusterData {
pool_config: &PoolConfig,
known_peers: &HashMap<SocketAddr, Arc<Node>>,
used_keyspace: &Option<VerifiedKeyspaceName>,
host_filter: Option<&dyn HostFilter>,
) -> Self {
// Create new updated known_peers and ring
let mut new_known_peers: HashMap<SocketAddr, Arc<Node>> =
@@ -289,13 +305,17 @@
Some(node) if node.datacenter == peer.datacenter && node.rack == peer.rack => {
node.clone()
}
_ => Arc::new(Node::new(
peer.address,
pool_config.clone(),
peer.datacenter,
peer.rack,
used_keyspace.clone(),
)),
_ => {
let is_enabled = host_filter.map_or(true, |f| f.accept(&peer));
Arc::new(Node::new(
peer.address,
pool_config.clone(),
peer.datacenter,
peer.rack,
used_keyspace.clone(),
is_enabled,
))
}
};

new_known_peers.insert(peer.address, node.clone());
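(For clarity: the `map_or(true, …)` call above means that having no host filter configured accepts every peer. A desugared sketch of the same check:)

```rust
// Equivalent to `host_filter.map_or(true, |f| f.accept(&peer))`:
let is_enabled = match host_filter {
    Some(filter) => filter.accept(&peer),
    None => true, // no filter configured: accept all peers
};
```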
@@ -567,6 +587,7 @@ impl ClusterWorker {
&self.pool_config,
&cluster_data.known_peers,
&self.used_keyspace,
self.host_filter.as_deref(),
));

new_cluster_data
78 changes: 78 additions & 0 deletions scylla/src/transport/host_filter.rs
@@ -0,0 +1,78 @@
//! Host filters.
//!
//! Host filters are essentially just a predicate over
//! [`Peer`](crate::transport::topology::Peer)s. Currently, they are used
//! by the [`Session`](crate::transport::session::Session) to determine whether
//! connections should be opened to a given node or not.

use std::collections::HashSet;
use std::io::Error;
use std::net::{SocketAddr, ToSocketAddrs};

use crate::transport::topology::Peer;

/// The `HostFilter` trait.
pub trait HostFilter: Send + Sync {
/// Returns whether a peer should be accepted or not.
fn accept(&self, peer: &Peer) -> bool;
}

/// Unconditionally accepts all nodes.
pub struct AcceptAllHostFilter;

impl HostFilter for AcceptAllHostFilter {
fn accept(&self, _peer: &Peer) -> bool {
true
}
}

/// Accepts nodes whose addresses are present in the allow list provided
/// during the filter's construction.
pub struct AllowListHostFilter {
allowed: HashSet<SocketAddr>,
}

impl AllowListHostFilter {
/// Creates a new `AllowListHostFilter` which only accepts nodes from the
/// list.
pub fn new<I, A>(allowed_iter: I) -> Result<Self, Error>
where
I: IntoIterator<Item = A>,
A: ToSocketAddrs,
{
// A plain loop is used here because `to_socket_addrs` is fallible,
// which makes iterator combinators awkward
let mut allowed = HashSet::new();
for item in allowed_iter {
for addr in item.to_socket_addrs()? {
allowed.insert(addr);
}
}

Ok(Self { allowed })
}
}

impl HostFilter for AllowListHostFilter {
fn accept(&self, peer: &Peer) -> bool {
self.allowed.contains(&peer.address)
}
}
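A minimal construction sketch (hostnames are illustrative): `new` resolves every entry eagerly via `ToSocketAddrs`, so resolution failures surface as a `std::io::Error` at construction time rather than during filtering.

```rust
use scylla::transport::host_filter::AllowListHostFilter;

fn build_filter() -> Result<AllowListHostFilter, std::io::Error> {
    // Each entry may resolve to several addresses; all of them are allowed.
    AllowListHostFilter::new(vec!["node1.example.com:9042", "10.0.0.2:9042"])
}
```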

/// Accepts nodes from the given DC.
pub struct DcHostFilter {
local_dc: String,
}

impl DcHostFilter {
/// Creates a new `DcHostFilter` that accepts nodes only from the
/// `local_dc`.
pub fn new(local_dc: String) -> Self {
Self { local_dc }
}
}

impl HostFilter for DcHostFilter {
fn accept(&self, peer: &Peer) -> bool {
peer.datacenter.as_ref() == Some(&self.local_dc)
}
}
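Since `HostFilter` is a single-method trait, custom filters are straightforward. Below is a sketch of a hypothetical rack-based filter (not part of this diff); it assumes the `rack` field on `Peer`, which the driver already consults in `ClusterData::new` above.

```rust
use scylla::transport::host_filter::HostFilter;
use scylla::transport::topology::Peer;

/// Hypothetical filter: accepts only nodes from the given rack.
struct RackHostFilter {
    local_rack: String,
}

impl HostFilter for RackHostFilter {
    fn accept(&self, peer: &Peer) -> bool {
        peer.rack.as_ref() == Some(&self.local_rack)
    }
}
```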
2 changes: 1 addition & 1 deletion scylla/src/transport/load_balancing/mod.rs
@@ -150,7 +150,7 @@ mod tests {
keyspaces: HashMap::new(),
};

ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
}

pub const EMPTY_STATEMENT: Statement = Statement {
4 changes: 2 additions & 2 deletions scylla/src/transport/load_balancing/token_aware.rs
@@ -345,7 +345,7 @@ mod tests {
keyspaces,
};

ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
}

// creates ClusterData with info about 8 nodes living in two different datacenters
@@ -444,7 +444,7 @@ mod tests {
keyspaces,
};

ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
}

// Used as child policy for TokenAwarePolicy tests
1 change: 1 addition & 0 deletions scylla/src/transport/mod.rs
@@ -3,6 +3,7 @@ mod cluster;
pub(crate) mod connection;
mod connection_pool;
pub mod downgrading_consistency_retry_policy;
pub mod host_filter;
pub mod iterator;
pub mod load_balancing;
pub(crate) mod metrics;
43 changes: 34 additions & 9 deletions scylla/src/transport/node.rs
@@ -21,7 +21,8 @@ pub struct Node {
pub datacenter: Option<String>,
pub rack: Option<String>,

pool: NodeConnectionPool,
// If the node is filtered out by the host filter, this will be None
pool: Option<NodeConnectionPool>,

down_marker: AtomicBool,
}
@@ -40,9 +41,11 @@ impl Node {
datacenter: Option<String>,
rack: Option<String>,
keyspace_name: Option<VerifiedKeyspaceName>,
enabled: bool,
) -> Self {
let pool =
NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name);
let pool = enabled.then(|| {
NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name)
});

Node {
address,
@@ -54,7 +57,7 @@
}

pub fn sharder(&self) -> Option<Sharder> {
self.pool.sharder()
self.pool.as_ref()?.sharder()
}

/// Get connection which should be used to connect using given token
@@ -63,18 +66,25 @@
&self,
token: Token,
) -> Result<Arc<Connection>, QueryError> {
self.pool.connection_for_token(token)
self.get_pool()?.connection_for_token(token)
}

/// Get random connection
pub(crate) async fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
self.pool.random_connection()
self.get_pool()?.random_connection()
}

pub fn is_down(&self) -> bool {
self.down_marker.load(Ordering::Relaxed)
}

/// Returns a boolean which indicates whether this node is enabled.
/// Only enabled nodes will have connections open. For disabled nodes,
/// no connections will be opened.
pub fn is_enabled(&self) -> bool {
self.pool.is_some()
}

pub(crate) fn change_down_marker(&self, is_down: bool) {
self.down_marker.store(is_down, Ordering::Relaxed);
}
@@ -83,15 +93,30 @@
&self,
keyspace_name: VerifiedKeyspaceName,
) -> Result<(), QueryError> {
self.pool.use_keyspace(keyspace_name).await
if let Some(pool) = &self.pool {
pool.use_keyspace(keyspace_name).await?;
}
Ok(())
}

pub(crate) fn get_working_connections(&self) -> Result<Vec<Arc<Connection>>, QueryError> {
self.pool.get_working_connections()
self.get_pool()?.get_working_connections()
}

pub(crate) async fn wait_until_pool_initialized(&self) {
self.pool.wait_until_initialized().await
if let Some(pool) = &self.pool {
pool.wait_until_initialized().await;
}
}

fn get_pool(&self) -> Result<&NodeConnectionPool, QueryError> {
self.pool.as_ref().ok_or_else(|| {
QueryError::IoError(Arc::new(std::io::Error::new(
std::io::ErrorKind::Other,
"No connections in the pool: the node has been disabled \
by the host filter",
)))
})
}
}

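With the pool now optional, driver-internal code can either check `is_enabled()` up front or let the connection getters return an error; a rough sketch of the calling pattern (crate-internal, since these getters are `pub(crate)`, and the helper name is hypothetical):

```rust
use std::sync::Arc;

use crate::transport::connection::Connection;
use crate::transport::errors::QueryError;
use crate::transport::node::Node;

// Hypothetical crate-internal helper: skip nodes that the host filter
// disabled instead of surfacing the "no connections in the pool" error.
async fn connection_if_enabled(node: &Node) -> Option<Result<Arc<Connection>, QueryError>> {
    if !node.is_enabled() {
        return None; // filtered out: this node has no connection pool
    }
    Some(node.random_connection().await)
}
```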
8 changes: 8 additions & 0 deletions scylla/src/transport/session.rs
@@ -36,6 +36,7 @@ use crate::tracing::{GetTracingConfig, TracingEvent, TracingInfo};
use crate::transport::cluster::{Cluster, ClusterData, ClusterNeatDebug};
use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
use crate::transport::connection_pool::PoolConfig;
use crate::transport::host_filter::HostFilter;
use crate::transport::iterator::{PreparedIteratorConfig, RowIterator};
use crate::transport::load_balancing::{
LoadBalancingPolicy, RoundRobinPolicy, Statement, TokenAwarePolicy,
@@ -203,6 +204,11 @@ pub struct SessionConfig {

pub address_translator: Option<Arc<dyn AddressTranslator>>,

/// The host filter decides whether any connections should be opened
/// to a given node. The driver will also avoid filtered-out nodes when
/// re-establishing the control connection.
pub host_filter: Option<Arc<dyn HostFilter>>,

/// If true, full schema metadata is fetched after successfully reaching a schema agreement.
/// It is true by default but can be disabled if successive schema-altering statements should be performed.
pub refresh_metadata_on_auto_schema_agreement: bool,
@@ -250,6 +256,7 @@ impl SessionConfig {
auto_await_schema_agreement_timeout: Some(std::time::Duration::from_secs(60)),
request_timeout: Some(Duration::from_secs(30)),
address_translator: None,
host_filter: None,
refresh_metadata_on_auto_schema_agreement: true,
}
}
@@ -436,6 +443,7 @@ impl Session {
config.get_pool_config(),
config.fetch_schema_metadata,
&config.address_translator,
&config.host_filter,
)
.await?;

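Because `host_filter` is a public field, it can also be set directly on a `SessionConfig`; a short sketch (the DC name is illustrative):

```rust
use std::sync::Arc;

use scylla::transport::host_filter::DcHostFilter;
use scylla::SessionConfig;

fn configure() -> SessionConfig {
    let mut config = SessionConfig::new();
    // Only nodes from "my-local-dc" will get connection pools.
    config.host_filter = Some(Arc::new(DcHostFilter::new("my-local-dc".to_string())));
    config
}
```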
33 changes: 33 additions & 0 deletions scylla/src/transport/session_builder.rs
@@ -5,6 +5,7 @@ use super::load_balancing::LoadBalancingPolicy;
use super::session::{AddressTranslator, Session, SessionConfig};
use super::speculative_execution::SpeculativeExecutionPolicy;
use super::Compression;
use crate::transport::host_filter::HostFilter;
use crate::transport::{connection_pool::PoolSize, retry_policy::RetryPolicy};
use std::net::SocketAddr;
use std::sync::Arc;
@@ -622,6 +623,38 @@ impl SessionBuilder {
self
}

/// Sets the host filter. The host filter decides whether any connections
/// should be opened to a given node. The driver will also avoid
/// filtered-out nodes when re-establishing the control connection.
///
/// See the [host filter](crate::transport::host_filter) module for a list
/// of pre-defined filters. It is also possible to provide a custom filter
/// by implementing the [`HostFilter`] trait.
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use scylla::{Session, SessionBuilder};
/// # use scylla::transport::host_filter::DcHostFilter;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // The session will only connect to nodes from "my-local-dc"
/// let session: Session = SessionBuilder::new()
/// .known_node("127.0.0.1:9042")
/// .host_filter(Arc::new(DcHostFilter::new("my-local-dc".to_string())))
/// .build()
/// .await?;
/// # Ok(())
/// # }
/// ```
pub fn host_filter(mut self, filter: Arc<dyn HostFilter>) -> Self {
self.config.host_filter = Some(filter);
self
}
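For comparison, a sketch using the `AllowListHostFilter` added in this PR (the addresses are illustrative):

```rust
use std::sync::Arc;

use scylla::transport::host_filter::AllowListHostFilter;
use scylla::{Session, SessionBuilder};

async fn example() -> Result<(), Box<dyn std::error::Error>> {
    // Pools will only be opened towards the two listed nodes.
    let filter = AllowListHostFilter::new(vec!["127.0.0.1:9042", "127.0.0.2:9042"])?;
    let session: Session = SessionBuilder::new()
        .known_node("127.0.0.1:9042")
        .host_filter(Arc::new(filter))
        .build()
        .await?;
    Ok(())
}
```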

/// Set the refresh metadata on schema agreement flag.
/// The default is true.
///
