Skip to content

Commit

Permalink
fix(d1): allow custom max_bind_values
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky committed Apr 29, 2024
1 parent 264f24c commit 824d6e5
Show file tree
Hide file tree
Showing 13 changed files with 85 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Expand Up @@ -71,7 +71,7 @@ tracing = { version = "0.1" }
tsify = { version = "0.4.5" }
wasm-bindgen = { version = "0.2.89" }
wasm-bindgen-futures = { version = "0.4" }
wasm-rs-dbg = { version = "0.1.2" }
wasm-rs-dbg = { version = "0.1.2", default-features = false, features = ["console-error"] }
wasm-bindgen-test = { version = "0.3.0" }
url = { version = "2.5.0" }

Expand Down
15 changes: 14 additions & 1 deletion quaint/src/connector/connection_info.rs
Expand Up @@ -190,6 +190,19 @@ impl ConnectionInfo {
}
}

pub fn max_insert_rows(&self) -> Option<usize> {
self.sql_family().max_insert_rows()
}

pub fn max_bind_values(&self) -> usize {
match self {
#[cfg(not(target_arch = "wasm32"))]
ConnectionInfo::Native(_) => self.sql_family().max_bind_values(),
// Wasm connectors can override the default max bind values.
ConnectionInfo::External(info) => info.max_bind_values.unwrap_or(self.sql_family().max_bind_values()),
}
}

/// The family of databases connected.
pub fn sql_family(&self) -> SqlFamily {
match self {
Expand Down Expand Up @@ -316,7 +329,7 @@ impl SqlFamily {
}

/// Get the default max rows for a batch insert.
pub fn max_insert_rows(&self) -> Option<usize> {
pub(crate) fn max_insert_rows(&self) -> Option<usize> {
match self {
#[cfg(feature = "postgresql")]
SqlFamily::Postgres => None,
Expand Down
4 changes: 3 additions & 1 deletion quaint/src/connector/external.rs
Expand Up @@ -6,13 +6,15 @@ use super::{SqlFamily, TransactionCapable};
pub struct ExternalConnectionInfo {
pub sql_family: SqlFamily,
pub schema_name: String,
pub max_bind_values: Option<usize>,
}

impl ExternalConnectionInfo {
pub fn new(sql_family: SqlFamily, schema_name: String) -> Self {
pub fn new(sql_family: SqlFamily, schema_name: String, max_bind_values: Option<usize>) -> Self {
ExternalConnectionInfo {
sql_family,
schema_name,
max_bind_values,
}
}
}
Expand Down
Expand Up @@ -73,7 +73,16 @@ impl Actor {
Some("READ COMMITTED"),
);

let mut runner = Runner::load(datamodel, &[], version, tag, setup_metrics(), log_capture).await?;
let mut runner = Runner::load(
datamodel,
&[],
version,
tag,
CONFIG.max_bind_values(),
setup_metrics(),
log_capture,
)
.await?;

tokio::spawn(async move {
while let Some(message) = query_receiver.recv().await {
Expand Down
@@ -1,6 +1,6 @@
use query_engine_tests::*;

#[test_suite(schema(autoinc_id), capabilities(CreateMany, AutoIncrement), exclude(CockroachDb))]
#[test_suite(schema(autoinc_id), capabilities(AutoIncrement), exclude(CockroachDb))]
mod not_in_chunking {
use query_engine_tests::Runner;

Expand Down
Expand Up @@ -69,6 +69,9 @@ pub struct TestConfigFromSerde {
/// This is the URL to the mobile emulator which will execute the queries against
/// the instances of the engine running on the device.
pub(crate) mobile_emulator_url: Option<String>,

/// The maximum number of bind values to use in a query for a driver adapter test runner.
pub(crate) driver_adapter_max_bind_values: Option<usize>,
}

impl TestConfigFromSerde {
Expand Down Expand Up @@ -156,6 +159,9 @@ pub(crate) struct WithDriverAdapter {
/// The driver adapter configuration to forward as a stringified JSON object to the external
/// test executor by setting the `DRIVER_ADAPTER_CONFIG` env var when spawning the executor.
pub(crate) config: Option<DriverAdapterConfig>,

/// The maximum number of bind values to use in a query for a driver adapter test runner.
pub(crate) max_bind_values: Option<usize>,
}

impl WithDriverAdapter {
Expand All @@ -181,6 +187,7 @@ impl From<TestConfigFromSerde> for TestConfig {
adapter,
test_executor: config.external_test_executor.unwrap(),
config: config.driver_adapter_config,
max_bind_values: config.driver_adapter_max_bind_values,
}),
None => None,
};
Expand Down Expand Up @@ -295,6 +302,9 @@ impl TestConfig {
let driver_adapter_config = std::env::var("DRIVER_ADAPTER_CONFIG")
.map(|config| serde_json::from_str::<DriverAdapterConfig>(config.as_str()).ok())
.unwrap_or_default();
let driver_adapter_max_bind_values = std::env::var("DRIVER_ADAPTER_MAX_BIND_VALUES")
.ok()
.map(|v| v.parse::<usize>().unwrap());

let mobile_emulator_url = std::env::var("MOBILE_EMULATOR_URL").ok();

Expand All @@ -310,6 +320,7 @@ impl TestConfig {
driver_adapter,
driver_adapter_config,
mobile_emulator_url,
driver_adapter_max_bind_values,
})
.map(Self::from)
}
Expand Down Expand Up @@ -387,7 +398,7 @@ impl TestConfig {
}

pub fn test_connector(&self) -> TestResult<(ConnectorTag, ConnectorVersion)> {
let version = ConnectorVersion::try_from((self.connector(), self.connector_version()))?;
let version = self.parse_connector_version()?;
let tag = match version {
ConnectorVersion::SqlServer(_) => &SqlServerConnectorTag as ConnectorTag,
ConnectorVersion::Postgres(_) => &PostgresConnectorTag,
Expand All @@ -401,6 +412,17 @@ impl TestConfig {
Ok((tag, version))
}

pub fn max_bind_values(&self) -> Option<usize> {
let version = self.parse_connector_version().unwrap();
let local_mbv = self.with_driver_adapter().and_then(|config| config.max_bind_values);

local_mbv.or_else(|| version.max_bind_values())
}

fn parse_connector_version(&self) -> TestResult<ConnectorVersion> {
ConnectorVersion::try_from((self.connector(), self.connector_version()))
}

#[rustfmt::skip]
pub fn for_external_executor(&self) -> Vec<(String, String)> {
let with_driver_adapter = self.with_driver_adapter().unwrap();
Expand Down
Expand Up @@ -290,19 +290,15 @@ impl ConnectorVersion {
/// From the PoV of the test binary, the target architecture is that of where the test runs,
/// generally x86_64, or aarch64, etc.
///
/// As a consequence there is an mismatch between the the max_bind_values as seen by the test
/// As a consequence there is a mismatch between the max_bind_values as seen by the test
/// binary (overriden by the QUERY_BATCH_SIZE env var) and the max_bind_values as seen by the
/// WASM engine being exercised in those tests, through the RunnerExecutor::External test runner.
///
/// What we do in here, is returning the number of max_bind_values hat the connector under test
/// What we do in here, is returning the number of max_bind_values that the connector under test
/// will use. i.e. if it's a WASM connector, the default, not overridable one. Otherwise the one
/// as seen by the test binary (which will be the same as the engine exercised)
pub fn max_bind_values(&self) -> Option<usize> {
if self.is_wasm() {
self.sql_family().map(|f| f.default_max_bind_values())
} else {
self.sql_family().map(|f| f.max_bind_values())
}
self.sql_family().map(|f| f.max_bind_values())
}

/// SQL family for the connector
Expand All @@ -317,17 +313,6 @@ impl ConnectorVersion {
_ => None,
}
}

/// Determines if the connector uses a driver adapter implemented in Wasm
fn is_wasm(&self) -> bool {
matches!(
self,
Self::Postgres(Some(PostgresVersion::PgJsWasm))
| Self::Postgres(Some(PostgresVersion::NeonJsWasm))
| Self::Vitess(Some(VitessVersion::PlanetscaleJsWasm))
| Self::Sqlite(Some(SqliteVersion::LibsqlJsWasm))
)
}
}

impl fmt::Display for ConnectorVersion {
Expand Down
Expand Up @@ -168,7 +168,7 @@ fn run_relation_link_test_impl(
run_with_tokio(
async move {
println!("Used datamodel:\n {}", datamodel.yellow());
let runner = Runner::load(datamodel.clone(), &[], version, connector_tag, metrics, log_capture)
let runner = Runner::load(datamodel.clone(), &[], version, connector_tag, CONFIG.max_bind_values(), metrics, log_capture)
.await
.unwrap();

Expand Down Expand Up @@ -286,6 +286,7 @@ fn run_connector_test_impl(
db_schemas,
version,
connector_tag,
CONFIG.max_bind_values(),
metrics,
log_capture,
)
Expand Down
Expand Up @@ -206,6 +206,9 @@ pub struct Runner {
metrics: MetricRegistry,
protocol: EngineProtocol,
log_capture: TestLogCapture,
// This is a local override for the max bind values that can be used in a query.
// It is set in the test config files, specifically for D1 for now, which has a lower limit than the SQLite native connector.
local_max_bind_values: Option<usize>,
}

impl Runner {
Expand All @@ -221,14 +224,16 @@ impl Runner {
}

pub fn max_bind_values(&self) -> Option<usize> {
self.connector_version().max_bind_values()
self.local_max_bind_values
.or_else(|| self.connector_version().max_bind_values())
}

pub async fn load(
datamodel: String,
db_schemas: &[&str],
connector_version: ConnectorVersion,
connector_tag: ConnectorTag,
local_max_bind_values: Option<usize>,
metrics: MetricRegistry,
log_capture: TestLogCapture,
) -> TestResult<Self> {
Expand All @@ -240,12 +245,13 @@ impl Runner {
let (executor, db_version) = match crate::CONFIG.with_driver_adapter() {
Some(with_driver_adapter) => {
let external_executor = ExternalExecutor::new();
let external_initializer: ExternalExecutorInitializer<'_> =
external_executor.init(&datamodel, url.as_str());
let executor = RunnerExecutor::External(external_executor);

let external_initializer = external_executor.init(&datamodel, url.as_str());

qe_setup::setup_external(with_driver_adapter.adapter, external_initializer, db_schemas).await?;

let executor = RunnerExecutor::External(external_executor);

let database_version = None;
(executor, database_version)
}
Expand All @@ -263,8 +269,8 @@ impl Runner {
let connector = query_executor.primary_connector();
let conn = connector.get_connection().await.unwrap();
let database_version = conn.version().await;

let executor = RunnerExecutor::Builtin(query_executor);

(executor, database_version)
}
};
Expand All @@ -283,6 +289,7 @@ impl Runner {
metrics,
protocol,
log_capture,
local_max_bind_values,
})
}

Expand Down
Expand Up @@ -2,5 +2,6 @@
"connector": "sqlite",
"version": "cfd1",
"driver_adapter": "d1",
"external_test_executor": "Wasm"
"external_test_executor": "Wasm",
"driver_adapter_max_bind_value": 100
}
5 changes: 2 additions & 3 deletions query-engine/connectors/sql-query-connector/src/context.rs
Expand Up @@ -13,9 +13,8 @@ pub(super) struct Context<'a> {

impl<'a> Context<'a> {
pub(crate) fn new(connection_info: &'a ConnectionInfo, trace_id: Option<&'a str>) -> Self {
let sql_family = connection_info.sql_family();
let max_insert_rows = sql_family.max_insert_rows();
let max_bind_values = sql_family.max_bind_values();
let max_insert_rows = connection_info.max_insert_rows();
let max_bind_values = connection_info.max_bind_values();

Context {
connection_info,
Expand Down
1 change: 1 addition & 0 deletions query-engine/driver-adapters/src/queryable.rs
Expand Up @@ -236,6 +236,7 @@ impl std::fmt::Debug for JsQueryable {
impl ExternalConnector for JsQueryable {
async fn get_connection_info(&self) -> quaint::Result<ExternalConnectionInfo> {
let conn_info = self.driver_proxy.get_connection_info().await?;

Ok(conn_info.into_external_connection_info(&self.inner.provider))
}
}
Expand Down
15 changes: 11 additions & 4 deletions query-engine/driver-adapters/src/types.rs
Expand Up @@ -40,8 +40,8 @@ impl FromStr for AdapterFlavour {
}
}

impl From<AdapterFlavour> for SqlFamily {
fn from(value: AdapterFlavour) -> Self {
impl From<&AdapterFlavour> for SqlFamily {
fn from(value: &AdapterFlavour) -> Self {
match value {
#[cfg(feature = "mysql")]
AdapterFlavour::Mysql => SqlFamily::Mysql,
Expand All @@ -59,14 +59,21 @@ impl From<AdapterFlavour> for SqlFamily {
#[derive(Default)]
pub(crate) struct JsConnectionInfo {
pub schema_name: Option<String>,
pub max_bind_values: Option<u32>,
}

impl JsConnectionInfo {
pub fn into_external_connection_info(self, provider: &AdapterFlavour) -> ExternalConnectionInfo {
let schema_name = self.get_schema_name(provider);
let sql_family = provider.to_owned().into();
ExternalConnectionInfo::new(sql_family, schema_name.to_owned())
let sql_family = SqlFamily::from(provider);

ExternalConnectionInfo::new(
sql_family,
schema_name.to_owned(),
self.max_bind_values.map(|v| v as usize),
)
}

fn get_schema_name(&self, provider: &AdapterFlavour) -> &str {
match self.schema_name.as_ref() {
Some(name) => name,
Expand Down

0 comments on commit 824d6e5

Please sign in to comment.