diff --git a/scylla/src/lib.rs b/scylla/src/lib.rs
index c9b3b33cc..a6e222455 100644
--- a/scylla/src/lib.rs
+++ b/scylla/src/lib.rs
@@ -116,6 +116,7 @@
 pub use transport::query_result::QueryResult;
 pub use transport::session::{IntoTypedRows, Session, SessionConfig};
 pub use transport::session_builder::SessionBuilder;
+pub use transport::host_filter;
 pub use transport::load_balancing;
 pub use transport::retry_policy;
 pub use transport::speculative_execution;
diff --git a/scylla/src/transport/cluster.rs b/scylla/src/transport/cluster.rs
index 6e300205a..befd3fea1 100644
--- a/scylla/src/transport/cluster.rs
+++ b/scylla/src/transport/cluster.rs
@@ -3,6 +3,7 @@ use crate::frame::response::event::{Event, StatusChangeEvent};
 use crate::frame::value::ValueList;
 use crate::load_balancing::TokenAwarePolicy;
 use crate::routing::Token;
+use crate::transport::host_filter::HostFilter;
 use crate::transport::{
     connection::{Connection, VerifiedKeyspaceName},
     connection_pool::PoolConfig,
@@ -110,6 +111,10 @@ struct ClusterWorker {
 
     // Keyspace send in "USE <keyspace name>" when opening each connection
     used_keyspace: Option<VerifiedKeyspaceName>,
+
+    // The host filter determines towards which nodes we should open
+    // connections
+    host_filter: Option<Arc<dyn HostFilter>>,
 }
 
 #[derive(Debug)]
@@ -129,6 +134,7 @@ impl Cluster {
         pool_config: PoolConfig,
         fetch_schema_metadata: bool,
         address_translator: &Option<Arc<dyn AddressTranslator>>,
+        host_filter: &Option<Arc<dyn HostFilter>>,
     ) -> Result<Cluster, QueryError> {
         let (refresh_sender, refresh_receiver) = tokio::sync::mpsc::channel(32);
         let (use_keyspace_sender, use_keyspace_receiver) = tokio::sync::mpsc::channel(32);
@@ -141,10 +147,17 @@
             server_events_sender,
             fetch_schema_metadata,
             address_translator,
+            host_filter,
         );
 
         let metadata = metadata_reader.read_metadata(true).await?;
-        let cluster_data = ClusterData::new(metadata, &pool_config, &HashMap::new(), &None);
+        let cluster_data = ClusterData::new(
+            metadata,
+            &pool_config,
+            &HashMap::new(),
+            &None,
+            host_filter.as_deref(),
+        );
         cluster_data.wait_until_all_pools_are_initialized().await;
         let cluster_data: Arc<ArcSwap<ClusterData>> =
             Arc::new(ArcSwap::from(Arc::new(cluster_data)));
@@ -160,6 +173,8 @@
             use_keyspace_channel: use_keyspace_receiver,
             used_keyspace: None,
+
+            host_filter: host_filter.clone(),
         };
 
         let (fut, worker_handle) = worker.work().remote_handle();
@@ -273,6 +288,7 @@ impl ClusterData {
         pool_config: &PoolConfig,
         known_peers: &HashMap<SocketAddr, Arc<Node>>,
         used_keyspace: &Option<VerifiedKeyspaceName>,
+        host_filter: Option<&dyn HostFilter>,
     ) -> Self {
         // Create new updated known_peers and ring
         let mut new_known_peers: HashMap<SocketAddr, Arc<Node>> =
@@ -289,13 +305,17 @@
                 Some(node) if node.datacenter == peer.datacenter && node.rack == peer.rack => {
                     node.clone()
                 }
-                _ => Arc::new(Node::new(
-                    peer.address,
-                    pool_config.clone(),
-                    peer.datacenter,
-                    peer.rack,
-                    used_keyspace.clone(),
-                )),
+                _ => {
+                    let is_enabled = host_filter.map_or(true, |f| f.accept(&peer));
+                    Arc::new(Node::new(
+                        peer.address,
+                        pool_config.clone(),
+                        peer.datacenter,
+                        peer.rack,
+                        used_keyspace.clone(),
+                        is_enabled,
+                    ))
+                }
             };
 
             new_known_peers.insert(peer.address, node.clone());
@@ -567,6 +587,7 @@ impl ClusterWorker {
                 &self.pool_config,
                 &cluster_data.known_peers,
                 &self.used_keyspace,
+                self.host_filter.as_deref(),
             ));
 
         new_cluster_data
diff --git a/scylla/src/transport/host_filter.rs b/scylla/src/transport/host_filter.rs
new file mode 100644
index 000000000..48a547ea8
--- /dev/null
+++ b/scylla/src/transport/host_filter.rs
@@ -0,0 +1,78 @@
+//! Host filters.
+//!
+//! Host filters are essentially just a predicate over
+//! [`Peer`](crate::transport::topology::Peer)s. Currently, they are used
+//! by the [`Session`](crate::transport::session::Session) to determine whether
+//! connections should be opened to a given node or not.
+
+use std::collections::HashSet;
+use std::io::Error;
+use std::net::{SocketAddr, ToSocketAddrs};
+
+use crate::transport::topology::Peer;
+
+/// The `HostFilter` trait.
+pub trait HostFilter: Send + Sync {
+    /// Returns whether a peer should be accepted or not.
+    fn accept(&self, peer: &Peer) -> bool;
+}
+
+/// Unconditionally accepts all nodes.
+pub struct AcceptAllHostFilter;
+
+impl HostFilter for AcceptAllHostFilter {
+    fn accept(&self, _peer: &Peer) -> bool {
+        true
+    }
+}
+
+/// Accepts nodes whose addresses are present in the allow list provided
+/// during the filter's construction.
+pub struct AllowListHostFilter {
+    allowed: HashSet<SocketAddr>,
+}
+
+impl AllowListHostFilter {
+    /// Creates a new `AllowListHostFilter` which only accepts nodes from the
+    /// list.
+    pub fn new<I, A>(allowed_iter: I) -> Result<Self, Error>
+    where
+        I: IntoIterator<Item = A>,
+        A: ToSocketAddrs,
+    {
+        // I couldn't get the iterator combinators to work
+        let mut allowed = HashSet::new();
+        for item in allowed_iter {
+            for addr in item.to_socket_addrs()? {
+                allowed.insert(addr);
+            }
+        }
+
+        Ok(Self { allowed })
+    }
+}
+
+impl HostFilter for AllowListHostFilter {
+    fn accept(&self, peer: &Peer) -> bool {
+        self.allowed.contains(&peer.address)
+    }
+}
+
+/// Accepts nodes from the given DC.
+pub struct DcHostFilter {
+    local_dc: String,
+}
+
+impl DcHostFilter {
+    /// Creates a new `DcHostFilter` that accepts nodes only from the
+    /// `local_dc`.
+    pub fn new(local_dc: String) -> Self {
+        Self { local_dc }
+    }
+}
+
+impl HostFilter for DcHostFilter {
+    fn accept(&self, peer: &Peer) -> bool {
+        peer.datacenter.as_ref() == Some(&self.local_dc)
+    }
+}
diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs
index d8edb73df..f893933ed 100644
--- a/scylla/src/transport/load_balancing/mod.rs
+++ b/scylla/src/transport/load_balancing/mod.rs
@@ -150,7 +150,7 @@ mod tests {
             keyspaces: HashMap::new(),
         };
 
-        ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
+        ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
     }
 
     pub const EMPTY_STATEMENT: Statement = Statement {
diff --git a/scylla/src/transport/load_balancing/token_aware.rs b/scylla/src/transport/load_balancing/token_aware.rs
index 373959e7e..140c01c93 100644
--- a/scylla/src/transport/load_balancing/token_aware.rs
+++ b/scylla/src/transport/load_balancing/token_aware.rs
@@ -345,7 +345,7 @@ mod tests {
             keyspaces,
         };
 
-        ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
+        ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
     }
 
     // creates ClusterData with info about 8 nodes living in two different datacenters
@@ -444,7 +444,7 @@
             keyspaces,
         };
 
-        ClusterData::new(info, &Default::default(), &HashMap::new(), &None)
+        ClusterData::new(info, &Default::default(), &HashMap::new(), &None, None)
     }
 
     // Used as child policy for TokenAwarePolicy tests
diff --git a/scylla/src/transport/mod.rs b/scylla/src/transport/mod.rs
index c8dc0d9b3..69cede77c 100644
--- a/scylla/src/transport/mod.rs
+++ b/scylla/src/transport/mod.rs
@@ -3,6 +3,7 @@ mod cluster;
 pub(crate) mod connection;
 mod connection_pool;
 pub mod downgrading_consistency_retry_policy;
+pub mod host_filter;
 pub mod iterator;
 pub mod load_balancing;
 pub(crate) mod metrics;
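Any type implementing the `HostFilter` trait above can be plugged into the driver alongside the predefined filters. As a minimal illustrative sketch (not part of this patch), a rack-local filter can be modeled on `DcHostFilter`; this assumes `Peer` is publicly reachable under `scylla::transport::topology`, as the module documentation implies:

```rust
use scylla::transport::host_filter::HostFilter;
use scylla::transport::topology::Peer;

// Illustrative only: accepts peers from a single rack of a single DC.
// Modeled on DcHostFilter; this type is not included in the patch.
pub struct RackHostFilter {
    local_dc: String,
    local_rack: String,
}

impl HostFilter for RackHostFilter {
    fn accept(&self, peer: &Peer) -> bool {
        peer.datacenter.as_deref() == Some(self.local_dc.as_str())
            && peer.rack.as_deref() == Some(self.local_rack.as_str())
    }
}
```

Such a filter is passed to the driver the same way as the built-in ones, via the `SessionBuilder::host_filter` method added below or the new `SessionConfig::host_filter` field.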
diff --git a/scylla/src/transport/node.rs b/scylla/src/transport/node.rs
index 88b587b0f..683a01662 100644
--- a/scylla/src/transport/node.rs
+++ b/scylla/src/transport/node.rs
@@ -21,7 +21,8 @@ pub struct Node {
     pub datacenter: Option<String>,
     pub rack: Option<String>,
 
-    pool: NodeConnectionPool,
+    // If the node is filtered out by the host filter, this will be None
+    pool: Option<NodeConnectionPool>,
 
     down_marker: AtomicBool,
 }
@@ -40,9 +41,11 @@ impl Node {
         datacenter: Option<String>,
         rack: Option<String>,
         keyspace_name: Option<VerifiedKeyspaceName>,
+        enabled: bool,
     ) -> Self {
-        let pool =
-            NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name);
+        let pool = enabled.then(|| {
+            NodeConnectionPool::new(address.ip(), address.port(), pool_config, keyspace_name)
+        });
 
         Node {
             address,
@@ -54,7 +57,7 @@
     }
 
     pub fn sharder(&self) -> Option<Sharder> {
-        self.pool.sharder()
+        self.pool.as_ref()?.sharder()
     }
 
     /// Get connection which should be used to connect using given token
@@ -63,18 +66,25 @@
         &self,
         token: Token,
     ) -> Result<Arc<Connection>, QueryError> {
-        self.pool.connection_for_token(token)
+        self.get_pool()?.connection_for_token(token)
     }
 
     /// Get random connection
     pub(crate) async fn random_connection(&self) -> Result<Arc<Connection>, QueryError> {
-        self.pool.random_connection()
+        self.get_pool()?.random_connection()
     }
 
     pub fn is_down(&self) -> bool {
         self.down_marker.load(Ordering::Relaxed)
    }
 
+    /// Returns a boolean which indicates whether this node is enabled.
+    /// Only enabled nodes will have connections open. For disabled nodes,
+    /// no connections will be opened.
+    pub fn is_enabled(&self) -> bool {
+        self.pool.is_some()
+    }
+
     pub(crate) fn change_down_marker(&self, is_down: bool) {
         self.down_marker.store(is_down, Ordering::Relaxed);
     }
@@ -83,15 +93,30 @@
         &self,
         keyspace_name: VerifiedKeyspaceName,
     ) -> Result<(), QueryError> {
-        self.pool.use_keyspace(keyspace_name).await
+        if let Some(pool) = &self.pool {
+            pool.use_keyspace(keyspace_name).await?;
+        }
+        Ok(())
     }
 
     pub(crate) fn get_working_connections(&self) -> Result<Vec<Arc<Connection>>, QueryError> {
-        self.pool.get_working_connections()
+        self.get_pool()?.get_working_connections()
     }
 
     pub(crate) async fn wait_until_pool_initialized(&self) {
-        self.pool.wait_until_initialized().await
+        if let Some(pool) = &self.pool {
+            pool.wait_until_initialized().await;
+        }
+    }
+
+    fn get_pool(&self) -> Result<&NodeConnectionPool, QueryError> {
+        self.pool.as_ref().ok_or_else(|| {
+            QueryError::IoError(Arc::new(std::io::Error::new(
+                std::io::ErrorKind::Other,
+                "No connections in the pool: the node has been disabled \
+                by the host filter",
+            )))
+        })
+    }
 }
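Because a filtered-out node keeps `pool == None`, it stays in the cluster metadata but reports `is_enabled() == false`, and its connection-returning methods yield the IO error built in `get_pool`. A small illustrative sketch of inspecting this from application code, assuming the pre-existing `Session::get_cluster_data()` and `ClusterData::get_nodes_info()` accessors, which are not touched by this patch:

```rust
use scylla::Session;

// Illustrative only: print which nodes the host filter left enabled.
// Assumes get_cluster_data()/get_nodes_info() from the existing driver API.
fn print_enabled_nodes(session: &Session) {
    let cluster_data = session.get_cluster_data();
    for node in cluster_data.get_nodes_info() {
        println!(
            "{} (dc: {:?}, rack: {:?}) enabled: {}",
            node.address,
            node.datacenter,
            node.rack,
            node.is_enabled()
        );
    }
}
```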
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index 002e4a7eb..4fef63fb1 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -38,6 +38,7 @@ use crate::tracing::{GetTracingConfig, TracingEvent, TracingInfo};
 use crate::transport::cluster::{Cluster, ClusterData, ClusterNeatDebug};
 use crate::transport::connection::{Connection, ConnectionConfig, VerifiedKeyspaceName};
 use crate::transport::connection_pool::PoolConfig;
+use crate::transport::host_filter::HostFilter;
 use crate::transport::iterator::{PreparedIteratorConfig, RowIterator};
 use crate::transport::load_balancing::{
     LoadBalancingPolicy, RoundRobinPolicy, Statement, TokenAwarePolicy,
@@ -205,6 +206,11 @@
 
     pub address_translator: Option<Arc<dyn AddressTranslator>>,
 
+    /// The host filter decides whether any connections should be opened
+    /// to the node or not. The driver will also avoid filtered out nodes when
+    /// re-establishing the control connection.
+    pub host_filter: Option<Arc<dyn HostFilter>>,
+
     /// If true, full schema metadata is fetched after successfully reaching a schema agreement.
     /// It is true by default but can be disabled if successive schema-altering statements should be performed.
     pub refresh_metadata_on_auto_schema_agreement: bool,
@@ -252,6 +258,7 @@ impl SessionConfig {
             auto_await_schema_agreement_timeout: Some(std::time::Duration::from_secs(60)),
             request_timeout: Some(Duration::from_secs(30)),
             address_translator: None,
+            host_filter: None,
             refresh_metadata_on_auto_schema_agreement: true,
         }
     }
@@ -438,6 +445,7 @@
             config.get_pool_config(),
             config.fetch_schema_metadata,
             &config.address_translator,
+            &config.host_filter,
         )
         .await?;
diff --git a/scylla/src/transport/session_builder.rs b/scylla/src/transport/session_builder.rs
index aaa59ba7a..09d5e0391 100644
--- a/scylla/src/transport/session_builder.rs
+++ b/scylla/src/transport/session_builder.rs
@@ -5,6 +5,7 @@ use super::load_balancing::LoadBalancingPolicy;
 use super::session::{AddressTranslator, Session, SessionConfig};
 use super::speculative_execution::SpeculativeExecutionPolicy;
 use super::Compression;
+use crate::transport::host_filter::HostFilter;
 use crate::transport::{connection_pool::PoolSize, retry_policy::RetryPolicy};
 use std::net::SocketAddr;
 use std::sync::Arc;
@@ -622,6 +623,38 @@
         self
     }
 
+    /// Sets the host filter. The host filter decides whether any connections
+    /// should be opened to the node or not. The driver will also avoid
+    /// those nodes when re-establishing the control connection.
+    ///
+    /// See the [host filter](crate::transport::host_filter) module for a list
+    /// of pre-defined filters. It is also possible to provide a custom filter
+    /// by implementing the HostFilter trait.
+    ///
+    /// # Example
+    /// ```
+    /// # use async_trait::async_trait;
+    /// # use std::net::SocketAddr;
+    /// # use std::sync::Arc;
+    /// # use scylla::{Session, SessionBuilder};
+    /// # use scylla::transport::session::{AddressTranslator, TranslationError};
+    /// # use scylla::transport::host_filter::DcHostFilter;
+    ///
+    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+    /// // The session will only connect to nodes from "my-local-dc"
+    /// let session: Session = SessionBuilder::new()
+    ///     .known_node("127.0.0.1:9042")
+    ///     .host_filter(Arc::new(DcHostFilter::new("my-local-dc".to_string())))
+    ///     .build()
+    ///     .await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn host_filter(mut self, filter: Arc<dyn HostFilter>) -> Self {
+        self.config.host_filter = Some(filter);
+        self
+    }
+
     /// Set the refresh metadata on schema agreement flag.
     /// The default is true.
     ///
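The new builder method accepts any `Arc<dyn HostFilter>`, so the other predefined filters are used the same way as the `DcHostFilter` shown in the doc example above. An illustrative sketch (not part of the patch) using `AllowListHostFilter`, with placeholder addresses:

```rust
use std::sync::Arc;

use scylla::transport::host_filter::AllowListHostFilter;
use scylla::{Session, SessionBuilder};

// Illustrative only: restrict connections to an explicit set of addresses.
async fn connect_to_allowed_nodes() -> Result<Session, Box<dyn std::error::Error>> {
    let filter = AllowListHostFilter::new(["10.0.0.1:9042", "10.0.0.2:9042"])?;
    let session: Session = SessionBuilder::new()
        .known_node("10.0.0.1:9042")
        .host_filter(Arc::new(filter))
        .build()
        .await?;
    Ok(session)
}
```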
diff --git a/scylla/src/transport/topology.rs b/scylla/src/transport/topology.rs
index d3008a5b2..949514ad3 100644
--- a/scylla/src/transport/topology.rs
+++ b/scylla/src/transport/topology.rs
@@ -4,6 +4,7 @@
 use crate::statement::query::Query;
 use crate::transport::connection::{Connection, ConnectionConfig};
 use crate::transport::connection_pool::{NodeConnectionPool, PoolConfig, PoolSize};
 use crate::transport::errors::{DbError, QueryError};
+use crate::transport::host_filter::HostFilter;
 use crate::transport::session::{AddressTranslator, IntoTypedRows};
 use crate::utils::parse::{ParseErrorCause, ParseResult, ParserState};
@@ -36,6 +37,7 @@ pub(crate) struct MetadataReader {
     fetch_schema: bool,
     address_translator: Option<Arc<dyn AddressTranslator>>,
+    host_filter: Option<Arc<dyn HostFilter>>,
 }
 
 /// Describes all metadata retrieved from the cluster
@@ -219,6 +221,7 @@ impl MetadataReader {
         server_event_sender: mpsc::Sender<Event>,
         fetch_schema: bool,
         address_translator: &Option<Arc<dyn AddressTranslator>>,
+        host_filter: &Option<Arc<dyn HostFilter>>,
     ) -> Self {
         let control_connection_address = *known_peers
             .choose(&mut thread_rng())
@@ -243,6 +246,7 @@
             known_peers: known_peers.into(),
             fetch_schema,
             address_translator: address_translator.clone(),
+            host_filter: host_filter.clone(),
         }
     }
 
@@ -251,6 +255,9 @@
         let mut result = self.fetch_metadata(initial).await;
         if let Ok(metadata) = result {
             self.update_known_peers(&metadata);
+            if initial {
+                self.handle_unaccepted_host_in_control_connection(&metadata);
+            }
             return Ok(metadata);
         }
 
@@ -302,6 +309,7 @@
         match &result {
             Ok(metadata) => {
                 self.update_known_peers(metadata);
+                self.handle_unaccepted_host_in_control_connection(metadata);
                 debug!("Fetched new metadata");
             }
             Err(error) => error!(
@@ -343,7 +351,67 @@
     }
 
     fn update_known_peers(&mut self, metadata: &Metadata) {
-        self.known_peers = metadata.peers.iter().map(|peer| peer.address).collect();
+        let host_filter = self.host_filter.as_ref();
+        self.known_peers = metadata
+            .peers
+            .iter()
+            .filter(|peer| host_filter.map_or(true, |f| f.accept(peer)))
+            .map(|peer| peer.address)
+            .collect();
+
+        // Check if the host filter isn't accidentally too restrictive,
+        // and print an error message about this fact
+        if !metadata.peers.is_empty() && self.known_peers.is_empty() {
+            error!(
+                node_ips = ?metadata
+                    .peers
+                    .iter()
+                    .map(|peer| peer.address)
+                    .collect::<Vec<_>>(),
+                "The host filter rejected all nodes in the cluster, \
+                no connections that can serve user queries have been \
+                established. The session cannot serve any queries!"
+            )
+        }
+    }
+
+    fn handle_unaccepted_host_in_control_connection(&mut self, metadata: &Metadata) {
+        let control_connection_peer = metadata
+            .peers
+            .iter()
+            .find(|peer| peer.address == self.control_connection_address);
+        if let Some(peer) = control_connection_peer {
+            if !self.host_filter.as_ref().map_or(true, |f| f.accept(peer)) {
+                warn!(
+                    filtered_node_ips = ?metadata
+                        .peers
+                        .iter()
+                        .filter(|peer| self.host_filter.as_ref().map_or(true, |p| p.accept(peer)))
+                        .map(|peer| peer.address)
+                        .collect::<Vec<_>>(),
+                    control_connection_address = ?self.control_connection_address,
+                    "The node that the control connection is established to \
+                    is not accepted by the host filter. Please verify that \
+                    the nodes in your initial peers list are accepted by the \
+                    host filter. The driver will try to re-establish the \
+                    control connection to a different node."
+                );
+
+                // Assuming here that known_peers are up-to-date
+                if !self.known_peers.is_empty() {
+                    self.control_connection_address = *self
+                        .known_peers
+                        .choose(&mut thread_rng())
+                        .expect("known_peers is empty - should be impossible");
+
+                    self.control_connection = Self::make_control_connection_pool(
+                        self.control_connection_address,
+                        self.connection_config.clone(),
+                        self.keepalive_interval,
+                    );
+                }
+            }
+        }
+    }
 
     fn make_control_connection_pool(