diff --git a/sqlx-core/src/logger.rs b/sqlx-core/src/logger.rs index 6446d14daa..32583d2064 100644 --- a/sqlx-core/src/logger.rs +++ b/sqlx-core/src/logger.rs @@ -1,4 +1,10 @@ use crate::connection::LogSettings; +#[cfg(feature = "sqlite")] +use std::collections::HashSet; +#[cfg(feature = "sqlite")] +use std::fmt::Debug; +#[cfg(feature = "sqlite")] +use std::hash::Hash; use std::time::Instant; pub(crate) struct QueryLogger<'q> { @@ -78,6 +84,86 @@ impl<'q> Drop for QueryLogger<'q> { } } +#[cfg(feature = "sqlite")] +pub(crate) struct QueryPlanLogger<'q, O: Debug + Hash + Eq, R: Debug + Hash + Eq, P: Debug> { + sql: &'q str, + unknown_operations: HashSet, + results: HashSet, + program: Vec

, + settings: LogSettings, +} + +#[cfg(feature = "sqlite")] +impl<'q, O: Debug + Hash + Eq, R: Debug + Hash + Eq, P: Debug> QueryPlanLogger<'q, O, R, P> { + pub(crate) fn new(sql: &'q str, settings: LogSettings) -> Self { + Self { + sql, + unknown_operations: HashSet::new(), + results: HashSet::new(), + program: Vec::new(), + settings, + } + } + + pub(crate) fn add_result(&mut self, result: R) { + self.results.insert(result); + } + + pub(crate) fn add_program(&mut self, program: Vec

) { + self.program = program; + } + + pub(crate) fn add_unknown_operation(&mut self, operation: O) { + self.unknown_operations.insert(operation); + } + + pub(crate) fn finish(&self) { + let lvl = self.settings.statements_level; + + if let Some(lvl) = lvl + .to_level() + .filter(|lvl| log::log_enabled!(target: "sqlx::explain", *lvl)) + { + let mut summary = parse_query_summary(&self.sql); + + let sql = if summary != self.sql { + summary.push_str(" …"); + format!( + "\n\n{}\n", + sqlformat::format( + &self.sql, + &sqlformat::QueryParams::None, + sqlformat::FormatOptions::default() + ) + ) + } else { + String::new() + }; + + log::logger().log( + &log::Record::builder() + .args(format_args!( + "{}; program:{:?}, unknown_operations:{:?}, results: {:?}{}", + summary, self.program, self.unknown_operations, self.results, sql + )) + .level(lvl) + .module_path_static(Some("sqlx::explain")) + .target("sqlx::explain") + .build(), + ); + } + } +} + +#[cfg(feature = "sqlite")] +impl<'q, O: Debug + Hash + Eq, R: Debug + Hash + Eq, P: Debug> Drop + for QueryPlanLogger<'q, O, R, P> +{ + fn drop(&mut self) { + self.finish(); + } +} + fn parse_query_summary(sql: &str) -> String { // For now, just take the first 4 words sql.split_whitespace() diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 310d803227..101ca56a2b 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -17,19 +17,78 @@ const SQLITE_AFF_REAL: u8 = 0x45; /* 'E' */ // opcodes const OP_INIT: &str = "Init"; const OP_GOTO: &str = "Goto"; +const OP_DECR_JUMP_ZERO: &str = "DecrJumpZero"; +const OP_ELSE_EQ: &str = "ElseEq"; +const OP_EQ: &str = "Eq"; +const OP_END_COROUTINE: &str = "EndCoroutine"; +const OP_FILTER: &str = "Filter"; +const OP_FK_IF_ZERO: &str = "FkIfZero"; +const OP_FOUND: &str = "Found"; +const OP_GE: &str = "Ge"; +const OP_GO_SUB: &str = "Gosub"; +const OP_GT: &str = "Gt"; +const OP_IDX_GE: &str = "IdxGE"; +const OP_IDX_GT: &str = "IdxGT"; +const OP_IDX_LE: &str = "IdxLE"; +const OP_IDX_LT: &str = "IdxLT"; +const OP_IF: &str = "If"; +const OP_IF_NO_HOPE: &str = "IfNoHope"; +const OP_IF_NOT: &str = "IfNot"; +const OP_IF_NOT_OPEN: &str = "IfNotOpen"; +const OP_IF_NOT_ZERO: &str = "IfNotZero"; +const OP_IF_NULL_ROW: &str = "IfNullRow"; +const OP_IF_POS: &str = "IfPos"; +const OP_IF_SMALLER: &str = "IfSmaller"; +const OP_INCR_VACUUM: &str = "IncrVacuum"; +const OP_INIT_COROUTINE: &str = "InitCoroutine"; +const OP_IS_NULL: &str = "IsNull"; +const OP_IS_NULL_OR_TYPE: &str = "IsNullOrType"; +const OP_LAST: &str = "Last"; +const OP_LE: &str = "Le"; +const OP_LT: &str = "Lt"; +const OP_MUST_BE_INT: &str = "MustBeInt"; +const OP_NE: &str = "Ne"; +const OP_NEXT: &str = "Next"; +const OP_NO_CONFLICT: &str = "NoConflict"; +const OP_NOT_EXISTS: &str = "NotExists"; +const OP_NOT_NULL: &str = "NotNull"; +const OP_ONCE: &str = "Once"; +const OP_PREV: &str = "Prev"; +const OP_PROGRAM: &str = "Program"; +const OP_RETURN: &str = "Return"; +const OP_REWIND: &str = "Rewind"; +const OP_ROW_DATA: &str = "RowData"; +const OP_ROW_SET_READ: &str = "RowSetRead"; +const OP_ROW_SET_TEST: &str = "RowSetTest"; +const OP_SEEK_GE: &str = "SeekGE"; +const OP_SEEK_GT: &str = "SeekGT"; +const OP_SEEK_LE: &str = "SeekLE"; +const OP_SEEK_LT: &str = "SeekLT"; +const OP_SEEK_ROW_ID: &str = "SeekRowId"; +const OP_SEEK_SCAN: &str = "SeekScan"; +const OP_SEQUENCE_TEST: &str = "SequenceTest"; +const OP_SORTER_NEXT: &str = "SorterNext"; +const OP_SORTER_SORT: &str = "SorterSort"; +const OP_V_FILTER: &str = "VFilter"; +const OP_V_NEXT: &str = "VNext"; +const OP_YIELD: &str = "Yield"; +const OP_JUMP: &str = "Jump"; const OP_COLUMN: &str = "Column"; const OP_MAKE_RECORD: &str = "MakeRecord"; const OP_INSERT: &str = "Insert"; const OP_IDX_INSERT: &str = "IdxInsert"; +const OP_OPEN_PSEUDO: &str = "OpenPseudo"; const OP_OPEN_READ: &str = "OpenRead"; const OP_OPEN_WRITE: &str = "OpenWrite"; const OP_OPEN_EPHEMERAL: &str = "OpenEphemeral"; const OP_OPEN_AUTOINDEX: &str = "OpenAutoindex"; +const OP_AGG_FINAL: &str = "AggFinal"; const OP_AGG_STEP: &str = "AggStep"; const OP_FUNCTION: &str = "Function"; const OP_MOVE: &str = "Move"; const OP_COPY: &str = "Copy"; const OP_SCOPY: &str = "SCopy"; +const OP_NULL: &str = "Null"; const OP_NULL_ROW: &str = "NullRow"; const OP_INT_COPY: &str = "IntCopy"; const OP_CAST: &str = "Cast"; @@ -56,18 +115,115 @@ const OP_DIVIDE: &str = "Divide"; const OP_REMAINDER: &str = "Remainder"; const OP_CONCAT: &str = "Concat"; const OP_RESULT_ROW: &str = "ResultRow"; +const OP_HALT: &str = "Halt"; + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +struct ColumnType { + pub datatype: DataType, + pub nullable: Option, +} + +impl Default for ColumnType { + fn default() -> Self { + Self { + datatype: DataType::Null, + nullable: None, + } + } +} + +impl ColumnType { + fn null() -> Self { + Self { + datatype: DataType::Null, + nullable: Some(true), + } + } +} #[derive(Debug, Clone, Eq, PartialEq)] enum RegDataType { - Single(DataType), - Record(Vec), + Single(ColumnType), + Record(Vec), + Int(i64), } impl RegDataType { - fn map_to_datatype(self) -> DataType { + fn map_to_datatype(&self) -> DataType { match self { - RegDataType::Single(d) => d, + RegDataType::Single(d) => d.datatype, RegDataType::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context + RegDataType::Int(_) => DataType::Int, + } + } + fn map_to_nullable(&self) -> Option { + match self { + RegDataType::Single(d) => d.nullable, + RegDataType::Record(_) => None, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context + RegDataType::Int(_) => Some(false), + } + } + fn map_to_columntype(&self) -> ColumnType { + match self { + RegDataType::Single(d) => *d, + RegDataType::Record(_) => ColumnType { + datatype: DataType::Null, + nullable: None, + }, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context + RegDataType::Int(_) => ColumnType { + datatype: DataType::Int, + nullable: Some(false), + }, + } + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +enum CursorDataType { + Normal(HashMap), + Pseudo(i64), +} + +impl CursorDataType { + fn from_sparse_record(record: &HashMap) -> Self { + Self::Normal( + record + .iter() + .map(|(&colnum, &datatype)| (colnum, datatype)) + .collect(), + ) + } + + fn from_dense_record(record: &Vec) -> Self { + Self::Normal((0..).zip(record.iter().copied()).collect()) + } + + fn map_to_dense_record(&self, registers: &HashMap) -> Vec { + match self { + Self::Normal(record) => { + let mut rowdata = vec![ColumnType::default(); record.len()]; + for (idx, col) in record.iter() { + rowdata[*idx as usize] = col.clone(); + } + rowdata + } + Self::Pseudo(i) => match registers.get(i) { + Some(RegDataType::Record(r)) => r.clone(), + _ => Vec::new(), + }, + } + } + + fn map_to_sparse_record( + &self, + registers: &HashMap, + ) -> HashMap { + match self { + Self::Normal(c) => c.clone(), + Self::Pseudo(i) => match registers.get(i) { + Some(RegDataType::Record(r)) => (0..).zip(r.iter().copied()).collect(), + _ => HashMap::new(), + }, } } } @@ -99,11 +255,11 @@ fn opcode_to_type(op: &str) -> DataType { fn root_block_columns( conn: &mut ConnectionState, -) -> Result>, Error> { - let table_block_columns: Vec<(i64, i64, String)> = execute::iter( +) -> Result>, Error> { + let table_block_columns: Vec<(i64, i64, String, bool)> = execute::iter( conn, - "SELECT s.rootpage, col.cid as colnum, col.type - FROM sqlite_schema s + "SELECT s.rootpage, col.cid as colnum, col.type, col.\"notnull\" + FROM (select * from sqlite_temp_schema UNION select * from sqlite_schema) s JOIN pragma_table_info(s.name) AS col WHERE s.type = 'table'", None, @@ -113,10 +269,10 @@ fn root_block_columns( .map(|row| FromRow::from_row(&row?)) .collect::, Error>>()?; - let index_block_columns: Vec<(i64, i64, String)> = execute::iter( + let index_block_columns: Vec<(i64, i64, String, bool)> = execute::iter( conn, - "SELECT s.rootpage, idx.seqno as colnum, col.type - FROM sqlite_schema s + "SELECT s.rootpage, idx.seqno as colnum, col.type, col.\"notnull\" + FROM (select * from sqlite_temp_schema UNION select * from sqlite_schema) s JOIN pragma_index_info(s.name) AS idx LEFT JOIN pragma_table_info(s.tbl_name) as col ON col.cid = idx.cid @@ -128,277 +284,533 @@ fn root_block_columns( .map(|row| FromRow::from_row(&row?)) .collect::, Error>>()?; - let mut row_info: HashMap> = HashMap::new(); - for (block, colnum, datatype) in table_block_columns { + let mut row_info: HashMap> = HashMap::new(); + for (block, colnum, datatype, notnull) in table_block_columns { let row_info = row_info.entry(block).or_default(); - row_info.insert(colnum, datatype.parse().unwrap_or(DataType::Null)); + row_info.insert( + colnum, + ColumnType { + datatype: datatype.parse().unwrap_or(DataType::Null), + nullable: Some(!notnull), + }, + ); } - for (block, colnum, datatype) in index_block_columns { + for (block, colnum, datatype, notnull) in index_block_columns { let row_info = row_info.entry(block).or_default(); - row_info.insert(colnum, datatype.parse().unwrap_or(DataType::Null)); + row_info.insert( + colnum, + ColumnType { + datatype: datatype.parse().unwrap_or(DataType::Null), + nullable: Some(!notnull), + }, + ); } return Ok(row_info); } +#[derive(Debug, Clone, PartialEq)] +struct QueryState { + pub visited: Vec, + pub history: Vec, + // Registers + pub r: HashMap, + // Rows that pointers point to + pub p: HashMap, + // Next instruction to execute + pub program_i: usize, + // Results published by the execution + pub result: Option, Option)>>, +} + // Opcode Reference: https://sqlite.org/opcode.html pub(super) fn explain( conn: &mut ConnectionState, query: &str, ) -> Result<(Vec, Vec>), Error> { - // Registers - let mut r = HashMap::::with_capacity(6); - // Map between pointer and register - let mut r_cursor = HashMap::>::with_capacity(6); - // Rows that pointers point to - let mut p = HashMap::>::with_capacity(6); - - // Nullable columns - let mut n = HashMap::::with_capacity(6); + let mut logger = crate::logger::QueryPlanLogger::new(query, conn.log_settings.clone()); let root_block_cols = root_block_columns(conn)?; - let program: Vec<(i64, String, i64, i64, i64, Vec)> = execute::iter(conn, &format!("EXPLAIN {}", query), None, false)? .filter_map(|res| res.map(|either| either.right()).transpose()) .map(|row| FromRow::from_row(&row?)) .collect::, Error>>()?; - - let mut program_i = 0; + logger.add_program(program.clone()); let program_size = program.len(); - let mut visited = vec![false; program_size]; - let mut output = Vec::new(); - let mut nullable = Vec::new(); + let mut states = vec![QueryState { + visited: vec![false; program_size], + history: Vec::new(), + r: HashMap::with_capacity(6), + p: HashMap::with_capacity(6), + program_i: 0, + result: None, + }]; + + let mut result_states = Vec::new(); + + while let Some(mut state) = states.pop() { + while state.program_i < program_size { + if state.visited[state.program_i] { + state.program_i += 1; + //avoid (infinite) loops by breaking if we ever hit the same instruction twice + break; + } + let (_, ref opcode, p1, p2, p3, ref p4) = program[state.program_i]; + state.history.push(state.program_i); + + match &**opcode { + OP_INIT => { + // start at + state.visited[state.program_i] = true; + state.program_i = p2 as usize; + continue; + } - let mut result = None; + OP_GOTO => { + // goto + state.visited[state.program_i] = true; + state.program_i = p2 as usize; + continue; + } - while program_i < program_size { - if visited[program_i] { - program_i += 1; - continue; - } - let (_, ref opcode, p1, p2, p3, ref p4) = program[program_i]; - - match &**opcode { - OP_INIT => { - // start at - visited[program_i] = true; - program_i = p2 as usize; - continue; - } + OP_DECR_JUMP_ZERO | OP_ELSE_EQ | OP_EQ | OP_FILTER | OP_FK_IF_ZERO | OP_FOUND + | OP_GE | OP_GO_SUB | OP_GT | OP_IDX_GE | OP_IDX_GT | OP_IDX_LE | OP_IDX_LT + | OP_IF | OP_IF_NO_HOPE | OP_IF_NOT | OP_IF_NOT_OPEN | OP_IF_NOT_ZERO + | OP_IF_NULL_ROW | OP_IF_POS | OP_IF_SMALLER | OP_INCR_VACUUM | OP_IS_NULL + | OP_IS_NULL_OR_TYPE | OP_LE | OP_LAST | OP_LT | OP_MUST_BE_INT | OP_NE + | OP_NEXT | OP_NO_CONFLICT | OP_NOT_EXISTS | OP_NOT_NULL | OP_ONCE | OP_PREV + | OP_PROGRAM | OP_ROW_SET_READ | OP_ROW_SET_TEST | OP_SEEK_GE | OP_SEEK_GT + | OP_SEEK_LE | OP_SEEK_LT | OP_SEEK_ROW_ID | OP_SEEK_SCAN | OP_SEQUENCE_TEST + | OP_SORTER_NEXT | OP_SORTER_SORT | OP_V_FILTER | OP_V_NEXT | OP_REWIND => { + // goto or next instruction (depending on actual values) + state.visited[state.program_i] = true; + + let mut branch_state = state.clone(); + branch_state.program_i = p2 as usize; + states.push(branch_state); + + state.program_i += 1; + continue; + } - OP_GOTO => { - // goto - visited[program_i] = true; - program_i = p2 as usize; - continue; - } + OP_INIT_COROUTINE => { + // goto or next instruction (depending on actual values) + state.visited[state.program_i] = true; + state.r.insert(p1, RegDataType::Int(p3)); - OP_COLUMN => { - //Get the row stored at p1, or NULL; get the column stored at p2, or NULL - if let Some(record) = p.get(&p1) { - if let Some(col) = record.get(&p2) { - // insert into p3 the datatype of the col - r.insert(p3, RegDataType::Single(*col)); - // map between pointer p1 and register p3 - r_cursor.entry(p1).or_default().push(p3); + if p2 != 0 { + state.program_i = p2 as usize; } else { - r.insert(p3, RegDataType::Single(DataType::Null)); + state.program_i += 1; } - } else { - r.insert(p3, RegDataType::Single(DataType::Null)); + continue; } - } - OP_MAKE_RECORD => { - // p3 = Record([p1 .. p1 + p2]) - let mut record = Vec::with_capacity(p2 as usize); - for reg in p1..p1 + p2 { - record.push( - r.get(®) - .map(|d| d.clone().map_to_datatype()) - .unwrap_or(DataType::Null), - ); + OP_END_COROUTINE => { + // jump to p2 of the yield instruction pointed at by register p1 + state.visited[state.program_i] = true; + if let Some(RegDataType::Int(yield_i)) = state.r.get(&p1) { + if let Some((_, yield_op, _, yield_p2, _, _)) = + program.get(*yield_i as usize) + { + if OP_YIELD == yield_op.as_str() { + state.program_i = (*yield_p2) as usize; + state.r.remove(&p1); + continue; + } else { + break; + } + } else { + break; + } + } else { + break; + } } - r.insert(p3, RegDataType::Record(record)); - } - OP_INSERT | OP_IDX_INSERT => { - if let Some(RegDataType::Record(record)) = r.get(&p2) { - if let Some(row) = p.get_mut(&p1) { - // Insert the record into wherever pointer p1 is - *row = (0..).zip(record.iter().copied()).collect(); + OP_RETURN => { + // jump to the instruction after the instruction pointed at by register p1 + state.visited[state.program_i] = true; + if let Some(RegDataType::Int(return_i)) = state.r.get(&p1) { + state.program_i = (*return_i + 1) as usize; + state.r.remove(&p1); + continue; + } else { + break; } } - //Noop if the register p2 isn't a record, or if pointer p1 does not exist - } - - OP_OPEN_READ | OP_OPEN_WRITE | OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => { - //Create a new pointer which is referenced by p1 - //Create a new pointer which is referenced by p1, take column metadata from db schema if found - if p3 == 0 { - if let Some(columns) = root_block_cols.get(&p2) { - p.insert( - p1, - columns - .iter() - .map(|(&colnum, &datatype)| (colnum, datatype)) - .collect(), - ); + OP_YIELD => { + // jump to p2 of the yield instruction pointed at by register p1, store prior instruction in p1 + state.visited[state.program_i] = true; + if let Some(RegDataType::Int(yield_i)) = state.r.get_mut(&p1) { + let program_i: usize = state.program_i; + + //if yielding to a yield operation, go to the NEXT instruction after that instruction + if program + .get(*yield_i as usize) + .map(|(_, yield_op, _, _, _, _)| yield_op.as_str()) + == Some(OP_YIELD) + { + state.program_i = (*yield_i + 1) as usize; + *yield_i = program_i as i64; + continue; + } else { + state.program_i = *yield_i as usize; + *yield_i = program_i as i64; + continue; + } } else { - p.insert(p1, HashMap::with_capacity(6)); + break; } - } else { - p.insert(p1, HashMap::with_capacity(6)); } - } - OP_VARIABLE => { - // r[p2] = - r.insert(p2, RegDataType::Single(DataType::Null)); - n.insert(p3, true); - } + OP_JUMP => { + // goto one of , , or based on the result of a prior compare + state.visited[state.program_i] = true; + + let mut branch_state = state.clone(); + branch_state.program_i = p1 as usize; + states.push(branch_state); + + let mut branch_state = state.clone(); + branch_state.program_i = p2 as usize; + states.push(branch_state); + + let mut branch_state = state.clone(); + branch_state.program_i = p3 as usize; + states.push(branch_state); + } - OP_FUNCTION => { - // r[p1] = func( _ ) - match from_utf8(p4).map_err(Error::protocol)? { - "last_insert_rowid(0)" => { - // last_insert_rowid() -> INTEGER - r.insert(p3, RegDataType::Single(DataType::Int64)); - n.insert(p3, n.get(&p3).copied().unwrap_or(false)); + OP_COLUMN => { + //Get the row stored at p1, or NULL; get the column stored at p2, or NULL + if let Some(record) = state.p.get(&p1).map(|c| c.map_to_sparse_record(&state.r)) + { + if let Some(col) = record.get(&p2) { + // insert into p3 the datatype of the col + state.r.insert(p3, RegDataType::Single(*col)); + } else { + state + .r + .insert(p3, RegDataType::Single(ColumnType::default())); + } + } else { + state + .r + .insert(p3, RegDataType::Single(ColumnType::default())); } + } - _ => {} + OP_ROW_DATA => { + //Get entire row from cursor p1, store it into register p2 + if let Some(record) = state.p.get(&p1) { + let rowdata = record.map_to_dense_record(&state.r); + state.r.insert(p2, RegDataType::Record(rowdata)); + } else { + state.r.insert(p2, RegDataType::Record(Vec::new())); + } } - } - OP_NULL_ROW => { - // all registers that map to cursor X are potentially nullable - for register in &r_cursor[&p1] { - n.insert(*register, true); + OP_MAKE_RECORD => { + // p3 = Record([p1 .. p1 + p2]) + let mut record = Vec::with_capacity(p2 as usize); + for reg in p1..p1 + p2 { + record.push( + state + .r + .get(®) + .map(|d| d.clone().map_to_columntype()) + .unwrap_or(ColumnType::default()), + ); + } + state.r.insert(p3, RegDataType::Record(record)); } - } - OP_AGG_STEP => { - let p4 = from_utf8(p4).map_err(Error::protocol)?; + OP_INSERT | OP_IDX_INSERT => { + if let Some(RegDataType::Record(record)) = state.r.get(&p2) { + if let Some(CursorDataType::Normal(row)) = state.p.get_mut(&p1) { + // Insert the record into wherever pointer p1 is + *row = (0..).zip(record.iter().copied()).collect(); + } + } + //Noop if the register p2 isn't a record, or if pointer p1 does not exist + } - if p4.starts_with("count(") { - // count(_) -> INTEGER - r.insert(p3, RegDataType::Single(DataType::Int64)); - n.insert(p3, n.get(&p3).copied().unwrap_or(false)); - } else if let Some(v) = r.get(&p2).cloned() { - // r[p3] = AGG ( r[p2] ) - r.insert(p3, v); - let val = n.get(&p2).copied().unwrap_or(true); - n.insert(p3, val); + OP_OPEN_PSEUDO => { + // Create a cursor p1 aliasing the record from register p2 + state.p.insert(p1, CursorDataType::Pseudo(p2)); + } + OP_OPEN_READ | OP_OPEN_WRITE => { + //Create a new pointer which is referenced by p1, take column metadata from db schema if found + if p3 == 0 { + if let Some(columns) = root_block_cols.get(&p2) { + state + .p + .insert(p1, CursorDataType::from_sparse_record(columns)); + } else { + state + .p + .insert(p1, CursorDataType::Normal(HashMap::with_capacity(6))); + } + } else { + state + .p + .insert(p1, CursorDataType::Normal(HashMap::with_capacity(6))); + } } - } - OP_CAST => { - // affinity(r[p1]) - if let Some(v) = r.get_mut(&p1) { - *v = RegDataType::Single(affinity_to_type(p2 as u8)); + OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => { + //Create a new pointer which is referenced by p1 + state.p.insert( + p1, + CursorDataType::from_dense_record(&vec![ColumnType::null(); p2 as usize]), + ); } - } - OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => { - // r[p2] = r[p1] - if let Some(v) = r.get(&p1).cloned() { - r.insert(p2, v); + OP_VARIABLE => { + // r[p2] = + state.r.insert(p2, RegDataType::Single(ColumnType::null())); + } - if let Some(null) = n.get(&p1).copied() { - n.insert(p2, null); + OP_FUNCTION => { + // r[p1] = func( _ ) + match from_utf8(p4).map_err(Error::protocol)? { + "last_insert_rowid(0)" => { + // last_insert_rowid() -> INTEGER + state.r.insert( + p3, + RegDataType::Single(ColumnType { + datatype: DataType::Int64, + nullable: Some(false), + }), + ); + } + + _ => logger.add_unknown_operation(program[state.program_i].clone()), } } - } - - OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID - | OP_NEWROWID => { - // r[p2] = - r.insert(p2, RegDataType::Single(opcode_to_type(&opcode))); - n.insert(p2, n.get(&p2).copied().unwrap_or(false)); - } - OP_NOT => { - // r[p2] = NOT r[p1] - if let Some(a) = r.get(&p1).cloned() { - r.insert(p2, a); - let val = n.get(&p1).copied().unwrap_or(true); - n.insert(p2, val); + OP_NULL_ROW => { + // all columns in cursor X are potentially nullable + if let Some(CursorDataType::Normal(ref mut cursor)) = state.p.get_mut(&p1) { + for ref mut col in cursor.values_mut() { + col.nullable = Some(true); + } + } + //else we don't know about the cursor } - } - OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT - | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => { - // r[p3] = r[p1] + r[p2] - match (r.get(&p1).cloned(), r.get(&p2).cloned()) { - (Some(a), Some(b)) => { - r.insert( + OP_AGG_STEP => { + //assume that AGG_FINAL will be called + let p4 = from_utf8(p4).map_err(Error::protocol)?; + + if p4.starts_with("count(") { + // count(_) -> INTEGER + state.r.insert( p3, - if matches!(a, RegDataType::Single(DataType::Null)) { - b - } else { - a - }, + RegDataType::Single(ColumnType { + datatype: DataType::Int64, + nullable: Some(false), + }), ); + } else if let Some(v) = state.r.get(&p2).cloned() { + // r[p3] = AGG ( r[p2] ) + state.r.insert(p3, v); } + } - (Some(v), None) => { - r.insert(p3, v); + OP_AGG_FINAL => { + let p4 = from_utf8(p4).map_err(Error::protocol)?; + + if p4.starts_with("count(") { + // count(_) -> INTEGER + state.r.insert( + p1, + RegDataType::Single(ColumnType { + datatype: DataType::Int64, + nullable: Some(false), + }), + ); + } else if let Some(v) = state.r.get(&p2).cloned() { + // r[p3] = AGG ( r[p2] ) + state.r.insert(p3, v); + } + } + + OP_CAST => { + // affinity(r[p1]) + if let Some(v) = state.r.get_mut(&p1) { + *v = RegDataType::Single(ColumnType { + datatype: affinity_to_type(p2 as u8), + nullable: v.map_to_nullable(), + }); } + } - (None, Some(v)) => { - r.insert(p3, v); + OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => { + // r[p2] = r[p1] + if let Some(v) = state.r.get(&p1).cloned() { + state.r.insert(p2, v); } + } + + OP_INTEGER => { + // r[p2] = p1 + state.r.insert(p2, RegDataType::Int(p1)); + } - _ => {} + OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_ROWID | OP_NEWROWID => { + // r[p2] = + state.r.insert( + p2, + RegDataType::Single(ColumnType { + datatype: opcode_to_type(&opcode), + nullable: Some(false), + }), + ); } - match (n.get(&p1).copied(), n.get(&p2).copied()) { - (Some(a), Some(b)) => { - n.insert(p3, a || b); + OP_NOT => { + // r[p2] = NOT r[p1] + if let Some(a) = state.r.get(&p1).cloned() { + state.r.insert(p2, a); } + } - _ => {} + OP_NULL => { + // r[p2..p3] = null + let idx_range = if p2 < p3 { p2..=p3 } else { p2..=p2 }; + + for idx in idx_range { + state.r.insert(idx, RegDataType::Single(ColumnType::null())); + } } - } - OP_RESULT_ROW => { - // the second time we hit ResultRow we short-circuit and get out - if result.is_some() { - break; + OP_OR | OP_AND | OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT + | OP_ADD | OP_SUBTRACT | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => { + // r[p3] = r[p1] + r[p2] + match (state.r.get(&p1).cloned(), state.r.get(&p2).cloned()) { + (Some(a), Some(b)) => { + state.r.insert( + p3, + RegDataType::Single(ColumnType { + datatype: if matches!(a.map_to_datatype(), DataType::Null) { + b.map_to_datatype() + } else { + a.map_to_datatype() + }, + nullable: match (a.map_to_nullable(), b.map_to_nullable()) { + (Some(a_n), Some(b_n)) => Some(a_n | b_n), + (Some(a_n), None) => Some(a_n), + (None, Some(b_n)) => Some(b_n), + (None, None) => None, + }, + }), + ); + } + + (Some(v), None) => { + state.r.insert( + p3, + RegDataType::Single(ColumnType { + datatype: v.map_to_datatype(), + nullable: None, + }), + ); + } + + (None, Some(v)) => { + state.r.insert( + p3, + RegDataType::Single(ColumnType { + datatype: v.map_to_datatype(), + nullable: None, + }), + ); + } + + _ => {} + } } - // output = r[p1 .. p1 + p2] - output.reserve(p2 as usize); - nullable.reserve(p2 as usize); + OP_RESULT_ROW => { + // output = r[p1 .. p1 + p2] + state.visited[state.program_i] = true; + state.result = Some( + (p1..p1 + p2) + .map(|i| { + let coltype = state.r.get(&i); + + let sqltype = + coltype.map(|d| d.map_to_datatype()).map(SqliteTypeInfo); + let nullable = + coltype.map(|d| d.map_to_nullable()).unwrap_or_default(); + + (sqltype, nullable) + }) + .collect(), + ); - result = Some(p1..p1 + p2); - } + let program_history: Vec<(i64, String, i64, i64, i64, Vec)> = + state.history.iter().map(|i| program[*i].clone()).collect(); + logger.add_result((program_history, state.result.clone())); + result_states.push(state.clone()); + } + + OP_HALT => { + break; + } - _ => { - // ignore unsupported operations - // if we fail to find an r later, we just give up + _ => { + // ignore unsupported operations + // if we fail to find an r later, we just give up + logger.add_unknown_operation(program[state.program_i].clone()); + } } - } - visited[program_i] = true; - program_i += 1; + state.visited[state.program_i] = true; + state.program_i += 1; + } } - if let Some(result) = result { - for i in result { - output.push(SqliteTypeInfo( - r.remove(&i) - .map(|d| d.map_to_datatype()) - .unwrap_or(DataType::Null), - )); - nullable.push(n.remove(&i)); + let mut output: Vec> = Vec::new(); + let mut nullable: Vec> = Vec::new(); + + while let Some(state) = result_states.pop() { + // find the datatype info from each ResultRow execution + if let Some(result) = state.result { + let mut idx = 0; + for (this_type, this_nullable) in result { + if output.len() == idx { + output.push(this_type); + } else if output[idx].is_none() + || matches!(output[idx], Some(SqliteTypeInfo(DataType::Null))) + { + output[idx] = this_type; + } + + if nullable.len() == idx { + nullable.push(this_nullable); + } else if let Some(ref mut null) = nullable[idx] { + //if any ResultRow's column is nullable, the final result is nullable + if let Some(this_null) = this_nullable { + *null |= this_null; + } + } else { + nullable[idx] = this_nullable; + } + idx += 1; + } } } + let output = output + .into_iter() + .map(|o| o.unwrap_or(SqliteTypeInfo(DataType::Null))) + .collect(); + Ok((output, nullable)) } @@ -438,7 +850,7 @@ fn test_root_block_columns_has_types() { .is_some()); assert!(execute::iter( &mut conn, - r"CREATE TABLE t2(a INTEGER, b_null NUMERIC NULL, b NUMERIC NOT NULL);", + r"CREATE TABLE t2(a INTEGER NOT NULL, b_null NUMERIC NULL, b NUMERIC NOT NULL);", None, false ) @@ -488,39 +900,123 @@ fn test_root_block_columns_has_types() { //prove that each block has the correct information { let blocknum = table_block_nums["t"]; - assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); - assert_eq!((DataType::Text), root_block_cols[&blocknum][&1]); - assert_eq!((DataType::Text), root_block_cols[&blocknum][&2]); + assert_eq!( + ColumnType { + datatype: DataType::Int64, + nullable: Some(true) //sqlite primary key columns are nullable unless declared not null + }, + root_block_cols[&blocknum][&0] + ); + assert_eq!( + ColumnType { + datatype: DataType::Text, + nullable: Some(true) + }, + root_block_cols[&blocknum][&1] + ); + assert_eq!( + ColumnType { + datatype: DataType::Text, + nullable: Some(false) + }, + root_block_cols[&blocknum][&2] + ); } { let blocknum = table_block_nums["i1"]; - assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); - assert_eq!((DataType::Text), root_block_cols[&blocknum][&1]); + assert_eq!( + ColumnType { + datatype: DataType::Int64, + nullable: Some(true) //sqlite primary key columns are nullable unless declared not null + }, + root_block_cols[&blocknum][&0] + ); + assert_eq!( + ColumnType { + datatype: DataType::Text, + nullable: Some(true) + }, + root_block_cols[&blocknum][&1] + ); } { let blocknum = table_block_nums["i2"]; - assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); - assert_eq!((DataType::Text), root_block_cols[&blocknum][&1]); + assert_eq!( + ColumnType { + datatype: DataType::Int64, + nullable: Some(true) //sqlite primary key columns are nullable unless declared not null + }, + root_block_cols[&blocknum][&0] + ); + assert_eq!( + ColumnType { + datatype: DataType::Text, + nullable: Some(true) + }, + root_block_cols[&blocknum][&1] + ); } { let blocknum = table_block_nums["t2"]; - assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); - assert_eq!((DataType::Null), root_block_cols[&blocknum][&1]); - assert_eq!((DataType::Null), root_block_cols[&blocknum][&2]); + assert_eq!( + ColumnType { + datatype: DataType::Int64, + nullable: Some(false) + }, + root_block_cols[&blocknum][&0] + ); + assert_eq!( + ColumnType { + datatype: DataType::Null, + nullable: Some(true) + }, + root_block_cols[&blocknum][&1] + ); + assert_eq!( + ColumnType { + datatype: DataType::Null, + nullable: Some(false) + }, + root_block_cols[&blocknum][&2] + ); } { let blocknum = table_block_nums["t2i1"]; - assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); - assert_eq!((DataType::Null), root_block_cols[&blocknum][&1]); + assert_eq!( + ColumnType { + datatype: DataType::Int64, + nullable: Some(false) + }, + root_block_cols[&blocknum][&0] + ); + assert_eq!( + ColumnType { + datatype: DataType::Null, + nullable: Some(true) + }, + root_block_cols[&blocknum][&1] + ); } { let blocknum = table_block_nums["t2i2"]; - assert_eq!((DataType::Int64), root_block_cols[&blocknum][&0]); - assert_eq!((DataType::Null), root_block_cols[&blocknum][&1]); + assert_eq!( + ColumnType { + datatype: DataType::Int64, + nullable: Some(false) + }, + root_block_cols[&blocknum][&0] + ); + assert_eq!( + ColumnType { + datatype: DataType::Null, + nullable: Some(false) + }, + root_block_cols[&blocknum][&1] + ); } } diff --git a/sqlx-core/src/sqlite/type_info.rs b/sqlx-core/src/sqlite/type_info.rs index 70695ee606..b21bd51f98 100644 --- a/sqlx-core/src/sqlite/type_info.rs +++ b/sqlx-core/src/sqlite/type_info.rs @@ -7,7 +7,7 @@ use libsqlite3_sys::{SQLITE_BLOB, SQLITE_FLOAT, SQLITE_INTEGER, SQLITE_NULL, SQL use crate::error::BoxDynError; use crate::type_info::TypeInfo; -#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub(crate) enum DataType { Null, @@ -29,7 +29,7 @@ pub(crate) enum DataType { } /// Type information for a SQLite type. -#[derive(Debug, Clone, Eq, PartialEq)] +#[derive(Debug, Clone, Eq, PartialEq, Hash)] #[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] pub struct SqliteTypeInfo(pub(crate) DataType); diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index e75d606223..fe42220663 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -44,13 +44,13 @@ async fn it_describes_variables() -> anyhow::Result<()> { let info = conn.describe("SELECT ?1").await?; assert_eq!(info.columns()[0].type_info().name(), "NULL"); - assert_eq!(info.nullable(0), None); // unknown + assert_eq!(info.nullable(0), Some(true)); // nothing prevents the value from being bound to null // context can be provided by using CAST(_ as _) let info = conn.describe("SELECT CAST(?1 AS REAL)").await?; assert_eq!(info.columns()[0].type_info().name(), "REAL"); - assert_eq!(info.nullable(0), None); // unknown + assert_eq!(info.nullable(0), Some(true)); // nothing prevents the value from being bound to null Ok(()) } @@ -60,7 +60,7 @@ async fn it_describes_expression() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef'") + .describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef', null") .await?; let columns = d.columns(); @@ -81,6 +81,10 @@ async fn it_describes_expression() -> anyhow::Result<()> { assert_eq!(columns[3].name(), "x'deadbeef'"); assert_eq!(d.nullable(3), Some(false)); // literal constant + assert_eq!(columns[4].type_info().name(), "NULL"); + assert_eq!(columns[4].name(), "null"); + assert_eq!(d.nullable(4), Some(true)); // literal null + Ok(()) } @@ -99,10 +103,10 @@ async fn it_describes_expression_from_empty_table() -> anyhow::Result<()> { assert_eq!(d.nullable(0), Some(false)); // COUNT(*) assert_eq!(d.columns()[1].type_info().name(), "INTEGER"); - assert_eq!(d.nullable(1), None); // `a + 1` is potentially nullable but we don't know for sure currently + assert_eq!(d.nullable(1), Some(true)); // `a+1` is nullable, because a is nullable assert_eq!(d.columns()[2].type_info().name(), "TEXT"); - assert_eq!(d.nullable(2), Some(false)); // `name` is not nullable + assert_eq!(d.nullable(2), Some(true)); // `name` is not nullable, but the query can be null due to zero rows assert_eq!(d.columns()[3].type_info().name(), "REAL"); assert_eq!(d.nullable(3), Some(false)); // literal constant @@ -256,3 +260,123 @@ async fn it_describes_left_join() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_describes_literal_subquery() -> anyhow::Result<()> { + async fn assert_literal_described( + conn: &mut sqlx::SqliteConnection, + query: &str, + ) -> anyhow::Result<()> { + let info = conn.describe(query).await?; + + assert_eq!(info.column(0).type_info().name(), "TEXT", "{}", query); + assert_eq!(info.nullable(0), Some(false), "{}", query); + assert_eq!(info.column(1).type_info().name(), "NULL", "{}", query); + assert_eq!(info.nullable(1), Some(true), "{}", query); + + Ok(()) + } + + let mut conn = new::().await?; + assert_literal_described(&mut conn, "SELECT 'a', NULL").await?; + assert_literal_described(&mut conn, "SELECT * FROM (SELECT 'a', NULL)").await?; + assert_literal_described( + &mut conn, + "WITH cte AS (SELECT 'a', NULL) SELECT * FROM cte", + ) + .await?; + assert_literal_described( + &mut conn, + "WITH cte AS MATERIALIZED (SELECT 'a', NULL) SELECT * FROM cte", + ) + .await?; + assert_literal_described( + &mut conn, + "WITH RECURSIVE cte(a,b) AS (SELECT 'a', NULL UNION ALL SELECT a||a, NULL FROM cte WHERE length(a)<3) SELECT * FROM cte", + ) + .await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_table_subquery() -> anyhow::Result<()> { + async fn assert_tweet_described( + conn: &mut sqlx::SqliteConnection, + query: &str, + ) -> anyhow::Result<()> { + let info = conn.describe(query).await?; + let columns = info.columns(); + + assert_eq!(columns[0].name(), "id", "{}", query); + assert_eq!(columns[1].name(), "text", "{}", query); + assert_eq!(columns[2].name(), "is_sent", "{}", query); + assert_eq!(columns[3].name(), "owner_id", "{}", query); + + assert_eq!(columns[0].ordinal(), 0, "{}", query); + assert_eq!(columns[1].ordinal(), 1, "{}", query); + assert_eq!(columns[2].ordinal(), 2, "{}", query); + assert_eq!(columns[3].ordinal(), 3, "{}", query); + + assert_eq!(info.nullable(0), Some(false), "{}", query); + assert_eq!(info.nullable(1), Some(false), "{}", query); + assert_eq!(info.nullable(2), Some(false), "{}", query); + assert_eq!(info.nullable(3), Some(true), "{}", query); + + assert_eq!(columns[0].type_info().name(), "INTEGER", "{}", query); + assert_eq!(columns[1].type_info().name(), "TEXT", "{}", query); + assert_eq!(columns[2].type_info().name(), "BOOLEAN", "{}", query); + assert_eq!(columns[3].type_info().name(), "INTEGER", "{}", query); + + Ok(()) + } + + let mut conn = new::().await?; + assert_tweet_described(&mut conn, "SELECT * FROM tweet").await?; + assert_tweet_described(&mut conn, "SELECT * FROM (SELECT * FROM tweet)").await?; + assert_tweet_described( + &mut conn, + "WITH cte AS (SELECT * FROM tweet) SELECT * FROM cte", + ) + .await?; + assert_tweet_described( + &mut conn, + "WITH cte AS MATERIALIZED (SELECT * FROM tweet) SELECT * FROM cte", + ) + .await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_describes_union() -> anyhow::Result<()> { + async fn assert_union_described( + conn: &mut sqlx::SqliteConnection, + query: &str, + ) -> anyhow::Result<()> { + let info = conn.describe(query).await?; + + assert_eq!(info.column(0).type_info().name(), "TEXT", "{}", query); + assert_eq!(info.nullable(0), Some(false), "{}", query); + assert_eq!(info.column(1).type_info().name(), "TEXT", "{}", query); + assert_eq!(info.nullable(1), Some(true), "{}", query); + assert_eq!(info.column(2).type_info().name(), "INTEGER", "{}", query); + assert_eq!(info.nullable(2), Some(true), "{}", query); + //TODO: mixed type columns not handled correctly + //assert_eq!(info.column(3).type_info().name(), "NULL", "{}", query); + //assert_eq!(info.nullable(3), Some(false), "{}", query); + + Ok(()) + } + + let mut conn = new::().await?; + assert_union_described( + &mut conn, + "SELECT 'txt','a',null,'b' UNION ALL SELECT 'int',NULL,1,2 ", + ) + .await?; + //TODO: insert into temp-table not merging datatype/nullable of all operations - currently keeping last-writer + //assert_union_described(&mut conn, "SELECT 'txt','a',null,'b' UNION SELECT 'int',NULL,1,2 ").await?; + + Ok(()) +}