diff --git a/src/statement/input.rs b/src/statement/input.rs index a0def6b2..60034f5d 100644 --- a/src/statement/input.rs +++ b/src/statement/input.rs @@ -1,6 +1,7 @@ -use {ffi, Return, Result, Raii, Handle, Statement}; use super::types::OdbcType; use odbc_safe::AutocommitMode; +use statement::types::EncodedValue; +use {ffi, Handle, Raii, Result, Return, Statement}; impl<'a, 'b, S, R, AC: AutocommitMode> Statement<'a, 'b, S, R, AC> { /// Binds a parameter to a parameter marker in an SQL statement. @@ -48,9 +49,19 @@ impl<'a, 'b, S, R, AC: AutocommitMode> Statement<'a, 'b, S, R, AC> { let ind_ptr = self.param_ind_buffers.alloc(parameter_index as usize, ind); + //the result of value_ptr is changed per calling. + //binding and saving must have the same value. + let enc_value = value.encoded_value(); + self.raii - .bind_input_parameter(parameter_index, value, ind_ptr) + .bind_input_parameter(parameter_index, value, ind_ptr, &enc_value) .into_result(&self)?; + + // save encoded value to avoid memory reuse. + if enc_value.has_value() { + self.encoded_values.push(enc_value); + } + Ok(self) } @@ -58,17 +69,31 @@ impl<'a, 'b, S, R, AC: AutocommitMode> Statement<'a, 'b, S, R, AC> { /// and returns a new one those lifetime is no longer limited by the buffers bound. pub fn reset_parameters(mut self) -> Result> { self.param_ind_buffers.clear(); + self.encoded_values.clear(); self.raii.reset_parameters().into_result(&mut self)?; Ok(Statement::with_raii(self.raii)) } } impl Raii { - fn bind_input_parameter<'c, T>(&mut self, parameter_index: u16, value: &'c T, str_len_or_ind_ptr: *mut ffi::SQLLEN) -> Return<()> + fn bind_input_parameter<'c, T>( + &mut self, + parameter_index: u16, + value: &'c T, + str_len_or_ind_ptr: *mut ffi::SQLLEN, + enc_value: &EncodedValue, + ) -> Return<()> where T: OdbcType<'c>, T: ?Sized, { + //if encoded value exists, use it. + let (column_size, value_ptr) = if enc_value.has_value() { + (enc_value.column_size(), enc_value.value_ptr()) + } else { + (value.column_size(), value.value_ptr()) + }; + match unsafe { ffi::SQLBindParameter( self.handle(), @@ -76,10 +101,10 @@ impl Raii { ffi::SQL_PARAM_INPUT, T::c_data_type(), T::sql_data_type(), - value.column_size(), + column_size, value.decimal_digits(), - value.value_ptr(), - 0, // buffer length + value_ptr, + 0, // buffer length str_len_or_ind_ptr, // Note that this ptr has to be valid until statement is executed ) } { diff --git a/src/statement/mod.rs b/src/statement/mod.rs index 32b98b34..b2ec314d 100644 --- a/src/statement/mod.rs +++ b/src/statement/mod.rs @@ -9,7 +9,7 @@ use ffi::SQLRETURN::*; use ffi::Nullable; use std::marker::PhantomData; pub use self::types::OdbcType; -pub use self::types::{SqlDate, SqlTime, SqlSsTime2, SqlTimestamp}; +pub use self::types::{SqlDate, SqlTime, SqlSsTime2, SqlTimestamp, EncodedValue}; // Allocate CHUNK_LEN elements at a time const CHUNK_LEN: usize = 64; @@ -76,6 +76,8 @@ pub struct Statement<'con, 'b, S, R, AC: AutocommitMode> { result: PhantomData, parameters: PhantomData<&'b [u8]>, param_ind_buffers: Chunks, + // encoded values are saved to use its pointer. + encoded_values: Vec, } /// Used to retrieve data from the fields of a query result @@ -108,7 +110,8 @@ impl<'a, 'b, S, R, AC: AutocommitMode> Statement<'a, 'b, S, R, AC> { state: PhantomData, result: PhantomData, parameters: PhantomData, - param_ind_buffers: Chunks::new() + param_ind_buffers: Chunks::new(), + encoded_values: Vec::new(), } } } diff --git a/src/statement/types.rs b/src/statement/types.rs index 52d3c2b4..710e22c9 100644 --- a/src/statement/types.rs +++ b/src/statement/types.rs @@ -4,6 +4,36 @@ use std::mem::{size_of, transmute}; use std::ffi::CString; use std::borrow::Cow::{Borrowed, Owned}; +pub struct EncodedValue { + pub buf: Option>, +} + +impl EncodedValue { + pub fn new(buf: Option>) -> Self { + Self { buf } + } + + pub fn has_value(&self) -> bool { + self.buf.is_some() + } + + pub fn column_size(&self) -> ffi::SQLULEN { + if let Some(buf) = &self.buf { + buf.len() as ffi::SQLULEN + } else { + 0 + } + } + + pub fn value_ptr(&self) -> ffi::SQLPOINTER { + if let Some(buf) = &self.buf { + buf.as_ptr() as *const Self as ffi::SQLPOINTER + } else { + 0 as *const Self as ffi::SQLPOINTER + } + } +} + pub unsafe trait OdbcType<'a>: Sized { fn sql_data_type() -> ffi::SqlDataType; fn c_data_type() -> ffi::SqlCDataType; @@ -16,6 +46,7 @@ pub unsafe trait OdbcType<'a>: Sized { fn decimal_digits(&self) -> ffi::SQLSMALLINT { 0 } + fn encoded_value(&self) -> EncodedValue; } unsafe impl<'a> OdbcType<'a> for &'a[u8] { @@ -37,6 +68,10 @@ unsafe impl<'a> OdbcType<'a> for &'a[u8] { fn value_ptr(&self) -> ffi::SQLPOINTER { self.as_ptr() as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for Vec { @@ -58,6 +93,10 @@ unsafe impl<'a> OdbcType<'a> for Vec { fn value_ptr(&self) -> ffi::SQLPOINTER { self.as_ptr() as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for &'a[u16] { @@ -83,6 +122,10 @@ unsafe impl<'a> OdbcType<'a> for &'a[u16] { fn value_ptr(&self) -> ffi::SQLPOINTER { self.as_ptr() as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for Vec { @@ -109,6 +152,10 @@ unsafe impl<'a> OdbcType<'a> for Vec { fn value_ptr(&self) -> ffi::SQLPOINTER { self.as_ptr() as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for CString { @@ -134,6 +181,10 @@ unsafe impl<'a> OdbcType<'a> for CString { fn null_bytes_count() -> usize { 1 } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for String { @@ -159,6 +210,10 @@ unsafe impl<'a> OdbcType<'a> for String { fn null_bytes_count() -> usize { 1 } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(Some(unsafe { ::environment::DB_ENCODING }.encode(&self).0.to_vec())) + } } unsafe impl<'a> OdbcType<'a> for &'a str { @@ -188,6 +243,10 @@ unsafe impl<'a> OdbcType<'a> for &'a str { fn null_bytes_count() -> usize { 1 } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(Some(unsafe { ::environment::DB_ENCODING }.encode(&self).0.to_vec())) + } } unsafe impl<'a> OdbcType<'a> for ::std::borrow::Cow<'a, str> { @@ -213,6 +272,10 @@ unsafe impl<'a> OdbcType<'a> for ::std::borrow::Cow<'a, str> { fn null_bytes_count() -> usize { 1 } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(Some(unsafe { ::environment::DB_ENCODING }.encode(self).0.to_vec())) + } } fn convert_primitive(buf: &[u8]) -> T @@ -240,6 +303,10 @@ unsafe impl<'a> OdbcType<'a> for u8 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for i8 { @@ -260,6 +327,10 @@ unsafe impl<'a> OdbcType<'a> for i8 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for i16 { @@ -280,6 +351,10 @@ unsafe impl<'a> OdbcType<'a> for i16 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for u16 { @@ -300,6 +375,10 @@ unsafe impl<'a> OdbcType<'a> for u16 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for i32 { @@ -320,6 +399,10 @@ unsafe impl<'a> OdbcType<'a> for i32 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for u32 { @@ -340,6 +423,10 @@ unsafe impl<'a> OdbcType<'a> for u32 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for i64 { @@ -360,6 +447,10 @@ unsafe impl<'a> OdbcType<'a> for i64 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for u64 { @@ -380,6 +471,10 @@ unsafe impl<'a> OdbcType<'a> for u64 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for f32 { @@ -400,6 +495,10 @@ unsafe impl<'a> OdbcType<'a> for f32 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for f64 { @@ -420,6 +519,10 @@ unsafe impl<'a> OdbcType<'a> for f64 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a> OdbcType<'a> for bool { @@ -441,6 +544,10 @@ unsafe impl<'a> OdbcType<'a> for bool { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } pub type SqlDate = ffi::SQL_DATE_STRUCT; @@ -467,6 +574,10 @@ unsafe impl<'a> OdbcType<'a> for SqlDate { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } pub type SqlTime = ffi::SQL_TIME_STRUCT; @@ -493,6 +604,10 @@ unsafe impl<'a> OdbcType<'a> for SqlTime { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } pub type SqlTimestamp = ffi::SQL_TIMESTAMP_STRUCT; @@ -519,6 +634,10 @@ unsafe impl<'a> OdbcType<'a> for SqlTimestamp { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } pub type SqlSsTime2 = ffi::SQL_SS_TIME2_STRUCT; @@ -546,6 +665,10 @@ unsafe impl<'a> OdbcType<'a> for SqlSsTime2 { fn value_ptr(&self) -> ffi::SQLPOINTER { self as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } unsafe impl<'a, T> OdbcType<'a> for Option where T: OdbcType<'a> { @@ -579,4 +702,80 @@ unsafe impl<'a, T> OdbcType<'a> for Option where T: OdbcType<'a> { fn null_bytes_count() -> usize { T::null_bytes_count() } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } + + +mod test { + // use environment::create_environment_v3_with_os_db_encoding; + use super::*; + use std::collections::HashSet; + use std::borrow::Cow; + + #[test] + fn encoded_value_test() { + let mut checker = HashSet::new(); + let mut encoded_values = Vec::new(); + + // let _ = create_environment_v3_with_os_db_encoding("utf8", "sjis"); + + //string test + for i in 0..10 { + for h in 0..10 { + let string_value = format!("{}{}", i, h); + // println!("org value => {} address => {:?}", string_value, string_value.value_ptr()); + + let enc = string_value.encoded_value(); + // println!("{} {:?}", enc.column_size(), enc.buf); + if checker.len() == 0 || !checker.contains(&enc.value_ptr()) { + checker.insert(enc.value_ptr()); + encoded_values.push(enc); + } else { + panic!("same address occur!"); + } + } + } + checker.clear(); + encoded_values.clear(); + + //&str test + for i in 0..10 { + for h in 0..10 { + let str_value: &str = &format!("{}{}", i, h); + // println!("org value => {} address => {:?}", str_value, str_value.value_ptr()); + + let enc = str_value.encoded_value(); + if checker.len() == 0 || !checker.contains(&enc.value_ptr()) { + checker.insert(enc.value_ptr()); + encoded_values.push(enc); + } else { + panic!("same address occur!"); + } + } + } + checker.clear(); + encoded_values.clear(); + + //Cow test + for i in 0..10 { + for h in 0..10 { + let cow_value: Cow = Cow::from(format!("{}{}", i, h)); + // println!("org value => {} address => {:?}", cow_value, cow_value.value_ptr()); + + let enc = cow_value.encoded_value(); + if checker.len() == 0 || !checker.contains(&enc.value_ptr()) { + checker.insert(enc.value_ptr()); + encoded_values.push(enc); + } else { + panic!("same address occur!"); + } + } + } + checker.clear(); + encoded_values.clear(); + + } +} \ No newline at end of file diff --git a/tests/gbk.rs b/tests/gbk.rs index cc283a2d..b463183c 100644 --- a/tests/gbk.rs +++ b/tests/gbk.rs @@ -150,6 +150,10 @@ unsafe impl<'a> OdbcType<'a> for CustomOdbcType<'a> { fn value_ptr(&self) -> ffi::SQLPOINTER { self.data.as_ptr() as *const Self as ffi::SQLPOINTER } + + fn encoded_value(&self) -> EncodedValue { + EncodedValue::new(None) + } } #[test]