From df4c2c982ce020fb915c29d4a91d52ba9243bb57 Mon Sep 17 00:00:00 2001 From: zhaokaiyuan Date: Tue, 12 Feb 2019 11:08:02 +0800 Subject: [PATCH 1/5] add encoding_rs --- Cargo.toml | 2 ++ src/diagnostics.rs | 14 ++++++-------- src/environment/list_data_sources.rs | 29 ++++++++++++++-------------- src/environment/mod.rs | 12 ++++++++++++ src/lib.rs | 1 + src/statement/mod.rs | 16 +++++++-------- src/statement/prepare.rs | 6 ++++-- src/statement/types.rs | 11 ++++++----- 8 files changed, 54 insertions(+), 37 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4ce3ab53..508e0f21 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ travis = [] odbc-sys = "0.6.3" odbc-safe = "0.4.1" log = "0.4.1" +encoding_rs = "0.8.14" + [dev-dependencies] chrono = "0.4" diff --git a/src/diagnostics.rs b/src/diagnostics.rs index 3bcdbd87..ed263f0c 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -17,6 +17,7 @@ pub struct DiagnosticRecord { // The numbers of characters in message not nul message_length: ffi::SQLSMALLINT, native_error: ffi::SQLINTEGER, + message_string: String, } impl DiagnosticRecord { @@ -41,6 +42,7 @@ impl DiagnosticRecord { message: [0u8; MAX_DIAGNOSTIC_MESSAGE_SIZE], native_error: -1, message_length: message.len() as ffi::SQLSMALLINT, + message_string: String::from(""), }; rec.message[..message.len()].copy_from_slice(message); rec @@ -51,16 +53,13 @@ impl fmt::Display for DiagnosticRecord { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // Todo: replace unwrap with `?` in Rust 1.17 let state = CStr::from_bytes_with_nul(&self.state).unwrap(); - let message = CStr::from_bytes_with_nul( - &self.message[0..(self.message_length as usize + 1)], - ).unwrap(); write!( f, "State: {}, Native error: {}, Message: {}", state.to_str().unwrap(), self.native_error, - message.to_str().unwrap() + self.message_string, ) } } @@ -73,10 +72,7 @@ impl fmt::Debug for DiagnosticRecord { impl Error for DiagnosticRecord { fn description(&self) -> &str { - CStr::from_bytes_with_nul(&self.message[0..(self.message_length as usize + 1)]) - .unwrap() - .to_str() - .unwrap() + &self.message_string } fn cause(&self) -> Option<&Error> { None @@ -113,6 +109,7 @@ where native_error: result.native_error, message_length, message, + message_string: unsafe { ::environment::ENCODING.decode(&message[0..(message_length as usize + 1)]).0.into_owned()}, }) } NoData(()) => None, @@ -133,6 +130,7 @@ mod test { message: [0u8; MAX_DIAGNOSTIC_MESSAGE_SIZE], native_error: 0, message_length: 0, + message_string: String::from(""), } } } diff --git a/src/environment/list_data_sources.rs b/src/environment/list_data_sources.rs index 33ab648f..3b607d26 100644 --- a/src/environment/list_data_sources.rs +++ b/src/environment/list_data_sources.rs @@ -1,7 +1,6 @@ use super::{safe, try_into_option, Environment, DiagnosticRecord, GetDiagRec, Result, Version3}; use ffi; use std::collections::HashMap; -use std::str::from_utf8; use std::cmp::max; /// Holds name and description of a datasource @@ -73,8 +72,8 @@ impl Environment { )? { driver_list.push(DriverInfo { - description: desc.to_owned(), - attributes: Self::parse_attributes(attr), + description: desc.into_owned(), + attributes: Self::parse_attributes(&attr), }) } } @@ -124,10 +123,10 @@ impl Environment { &mut name_buffer, &mut description_buffer, )? - { - source_list.push(DataSourceInfo { - server_name: name.to_owned(), - driver: desc.to_owned(), + { + source_list.push(DataSourceInfo { + server_name: name.into_owned(), + driver: desc.into_owned(), }) } else { return Ok(source_list); @@ -142,8 +141,8 @@ impl Environment { )? { source_list.push(DataSourceInfo { - server_name: name.to_owned(), - driver: desc.to_owned(), + server_name: name.into_owned(), + driver: desc.into_owned(), }) } } @@ -159,13 +158,15 @@ impl Environment { direction: ffi::FetchOrientation, buf1: &'a mut [u8], buf2: &'b mut [u8], - ) -> Result> { + ) -> Result, ::std::borrow::Cow<'b, str>)>> { let result = f(&mut self.safe, direction, buf1, buf2); match try_into_option(result, self)? { - Some((len1, len2)) => Ok(Some(( - from_utf8(&buf1[0..(len1 as usize)]).unwrap(), - from_utf8(&buf2[0..(len2 as usize)]).unwrap(), - ))), + Some((len1, len2)) => unsafe { + Ok(Some(( + ::environment::ENCODING.decode(&buf1[0..(len1 as usize)]).0, + ::environment::ENCODING.decode(&buf2[0..(len2 as usize)]).0, + ))) + } None => Ok(None), } } diff --git a/src/environment/mod.rs b/src/environment/mod.rs index 5996a510..4dae7701 100644 --- a/src/environment/mod.rs +++ b/src/environment/mod.rs @@ -7,6 +7,8 @@ use std; /// Environment state used to represent that environment has been set to odbc version 3 pub type Version3 = safe::Odbc3; +pub static mut ENCODING : &encoding_rs::Encoding = encoding_rs::UTF_8; + /// Handle to an ODBC Environment /// /// Creating an instance of this type is the first thing you do then using ODBC. The environment @@ -94,3 +96,13 @@ pub fn create_environment_v3() { Environment::new() } + + +pub fn create_environment_v3_with_os_encoding(encoding: String) + -> std::result::Result, Option> +{ + unsafe { + ENCODING = encoding_rs::Encoding::for_label(encoding.as_bytes()).unwrap(); + } + Environment::new() +} diff --git a/src/lib.rs b/src/lib.rs index e6f1eefe..c0ab9f6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,7 @@ #[macro_use] extern crate log; extern crate odbc_safe; +extern crate encoding_rs; pub mod ffi; diff --git a/src/statement/mod.rs b/src/statement/mod.rs index fb7f6ee7..42b8b7a8 100644 --- a/src/statement/mod.rs +++ b/src/statement/mod.rs @@ -289,9 +289,8 @@ impl Raii { &mut nullable as *mut ffi::Nullable, ) { SQL_SUCCESS => Return::Success(ColumnDescriptor { - name: ::std::str::from_utf8(&name_buffer[..(name_length as usize)]) - .unwrap() - .to_owned(), + name: ::environment::ENCODING.decode(&name_buffer[..(name_length as usize)]).0 + .to_string(), data_type: data_type, column_size: if column_size == 0 { None @@ -310,9 +309,8 @@ impl Raii { }, }), SQL_SUCCESS_WITH_INFO => Return::SuccessWithInfo(ColumnDescriptor { - name: ::std::str::from_utf8(&name_buffer[..(name_length as usize)]) - .unwrap() - .to_owned(), + name: ::environment::ENCODING.decode(&name_buffer[..(name_length as usize)]).0 + .to_string(), data_type: data_type, column_size: if column_size == 0 { None @@ -338,14 +336,16 @@ impl Raii { } fn exec_direct(&mut self, statement_text: &str) -> Return { - let length = statement_text.len(); + let bytes = unsafe { crate::environment::ENCODING }.encode(statement_text).0; + + let length = bytes.len(); if length > ffi::SQLINTEGER::max_value() as usize { panic!("Statement text too long"); } match unsafe { ffi::SQLExecDirect( self.handle(), - statement_text.as_ptr(), + bytes.as_ptr(), length as ffi::SQLINTEGER, ) } { diff --git a/src/statement/prepare.rs b/src/statement/prepare.rs index e16345bd..fde58daa 100644 --- a/src/statement/prepare.rs +++ b/src/statement/prepare.rs @@ -99,11 +99,13 @@ impl<'a, 'b> Statement<'a, 'b, Prepared, NoResult> { impl Raii { fn prepare(&mut self, sql_text: &str) -> Return<()> { + + let bytes = unsafe { crate::environment::ENCODING }.encode(sql_text).0; match unsafe { ffi::SQLPrepare( self.handle(), - sql_text.as_bytes().as_ptr(), - sql_text.as_bytes().len() as ffi::SQLINTEGER, + bytes.as_ptr(), + bytes.len() as ffi::SQLINTEGER, ) } { ffi::SQL_SUCCESS => Return::Success(()), diff --git a/src/statement/types.rs b/src/statement/types.rs index b50e4ba5..39f80a0f 100644 --- a/src/statement/types.rs +++ b/src/statement/types.rs @@ -1,5 +1,4 @@ use ffi; -use std::str::from_utf8; use std::slice::from_raw_parts; use std::mem::{size_of, transmute}; use std::ffi::CString; @@ -145,7 +144,8 @@ unsafe impl<'a> OdbcType<'a> for String { } fn convert(buffer: &'a [u8]) -> Self { - from_utf8(buffer).unwrap().to_owned() + unsafe {::environment::ENCODING.decode(buffer).0} + .to_string() } fn column_size(&self) -> ffi::SQLULEN { @@ -161,7 +161,7 @@ unsafe impl<'a> OdbcType<'a> for String { } } -unsafe impl<'a> OdbcType<'a> for &'a str { +unsafe impl<'a> OdbcType<'a> for ::std::borrow::Cow<'a, str> { fn sql_data_type() -> ffi::SqlDataType { ffi::SQL_VARCHAR } @@ -170,7 +170,8 @@ unsafe impl<'a> OdbcType<'a> for &'a str { } fn convert(buffer: &'a [u8]) -> Self { - from_utf8(buffer).unwrap() + let aa = unsafe {::environment::ENCODING.decode(buffer).0}; + aa } fn column_size(&self) -> ffi::SQLULEN { @@ -550,4 +551,4 @@ unsafe impl<'a, T> OdbcType<'a> for Option where T: OdbcType<'a> { fn null_bytes_count() -> usize { T::null_bytes_count() } -} \ No newline at end of file +} From 9b6cdaa845e87cebf9dd486adf762699cc48fe98 Mon Sep 17 00:00:00 2001 From: zhaokaiyuan Date: Tue, 12 Feb 2019 15:52:30 +0800 Subject: [PATCH 2/5] add db_encoding --- src/diagnostics.rs | 2 +- src/environment/list_data_sources.rs | 6 ++++-- src/environment/mod.rs | 8 +++++--- src/statement/mod.rs | 6 +++--- src/statement/prepare.rs | 3 +-- src/statement/types.rs | 14 ++++++-------- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/diagnostics.rs b/src/diagnostics.rs index ed263f0c..f2a60276 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -109,7 +109,7 @@ where native_error: result.native_error, message_length, message, - message_string: unsafe { ::environment::ENCODING.decode(&message[0..(message_length as usize + 1)]).0.into_owned()}, + message_string: unsafe { ::environment::OS_ENCODING.decode(&message[0..(message_length as usize + 1)]).0.into_owned()}, }) } NoData(()) => None, diff --git a/src/environment/list_data_sources.rs b/src/environment/list_data_sources.rs index 3b607d26..e2da09ed 100644 --- a/src/environment/list_data_sources.rs +++ b/src/environment/list_data_sources.rs @@ -163,8 +163,8 @@ impl Environment { match try_into_option(result, self)? { Some((len1, len2)) => unsafe { Ok(Some(( - ::environment::ENCODING.decode(&buf1[0..(len1 as usize)]).0, - ::environment::ENCODING.decode(&buf2[0..(len2 as usize)]).0, + ::environment::DB_ENCODING.decode(&buf1[0..(len1 as usize)]).0, + ::environment::DB_ENCODING.decode(&buf2[0..(len2 as usize)]).0, ))) } None => Ok(None), @@ -191,6 +191,7 @@ impl Environment { ); loop { + match result { safe::ReturnOption::Success((buf1_length_out, buf2_length_out)) | safe::ReturnOption::Info((buf1_length_out, buf2_length_out)) => { @@ -220,6 +221,7 @@ impl Environment { #[cfg(test)] mod test { + use super::*; #[test] diff --git a/src/environment/mod.rs b/src/environment/mod.rs index 4dae7701..04cf5e27 100644 --- a/src/environment/mod.rs +++ b/src/environment/mod.rs @@ -7,7 +7,8 @@ use std; /// Environment state used to represent that environment has been set to odbc version 3 pub type Version3 = safe::Odbc3; -pub static mut ENCODING : &encoding_rs::Encoding = encoding_rs::UTF_8; +pub static mut OS_ENCODING: &encoding_rs::Encoding = encoding_rs::UTF_8; +pub static mut DB_ENCODING: &encoding_rs::Encoding = encoding_rs::UTF_8; /// Handle to an ODBC Environment /// @@ -98,11 +99,12 @@ pub fn create_environment_v3() } -pub fn create_environment_v3_with_os_encoding(encoding: String) +pub fn create_environment_v3_with_os_db_encoding(os_encoding: &str, db_encoding: &str) -> std::result::Result, Option> { unsafe { - ENCODING = encoding_rs::Encoding::for_label(encoding.as_bytes()).unwrap(); + OS_ENCODING = encoding_rs::Encoding::for_label(os_encoding.as_bytes()).unwrap(); + DB_ENCODING = encoding_rs::Encoding::for_label(db_encoding.as_bytes()).unwrap(); } Environment::new() } diff --git a/src/statement/mod.rs b/src/statement/mod.rs index 42b8b7a8..78042ce9 100644 --- a/src/statement/mod.rs +++ b/src/statement/mod.rs @@ -289,7 +289,7 @@ impl Raii { &mut nullable as *mut ffi::Nullable, ) { SQL_SUCCESS => Return::Success(ColumnDescriptor { - name: ::environment::ENCODING.decode(&name_buffer[..(name_length as usize)]).0 + name: ::environment::DB_ENCODING.decode(&name_buffer[..(name_length as usize)]).0 .to_string(), data_type: data_type, column_size: if column_size == 0 { @@ -309,7 +309,7 @@ impl Raii { }, }), SQL_SUCCESS_WITH_INFO => Return::SuccessWithInfo(ColumnDescriptor { - name: ::environment::ENCODING.decode(&name_buffer[..(name_length as usize)]).0 + name: ::environment::DB_ENCODING.decode(&name_buffer[..(name_length as usize)]).0 .to_string(), data_type: data_type, column_size: if column_size == 0 { @@ -336,7 +336,7 @@ impl Raii { } fn exec_direct(&mut self, statement_text: &str) -> Return { - let bytes = unsafe { crate::environment::ENCODING }.encode(statement_text).0; + let bytes = unsafe { crate::environment::DB_ENCODING }.encode(statement_text).0; let length = bytes.len(); if length > ffi::SQLINTEGER::max_value() as usize { diff --git a/src/statement/prepare.rs b/src/statement/prepare.rs index fde58daa..223b726c 100644 --- a/src/statement/prepare.rs +++ b/src/statement/prepare.rs @@ -99,8 +99,7 @@ impl<'a, 'b> Statement<'a, 'b, Prepared, NoResult> { impl Raii { fn prepare(&mut self, sql_text: &str) -> Return<()> { - - let bytes = unsafe { crate::environment::ENCODING }.encode(sql_text).0; + let bytes = unsafe { crate::environment::DB_ENCODING }.encode(sql_text).0; match unsafe { ffi::SQLPrepare( self.handle(), diff --git a/src/statement/types.rs b/src/statement/types.rs index 39f80a0f..3fce3748 100644 --- a/src/statement/types.rs +++ b/src/statement/types.rs @@ -144,16 +144,15 @@ unsafe impl<'a> OdbcType<'a> for String { } fn convert(buffer: &'a [u8]) -> Self { - unsafe {::environment::ENCODING.decode(buffer).0} - .to_string() + unsafe { ::environment::DB_ENCODING }.decode(buffer).0.to_string() } fn column_size(&self) -> ffi::SQLULEN { - self.as_bytes().len() as ffi::SQLULEN + unsafe { ::environment::DB_ENCODING }.encode(&self).0.len() as ffi::SQLULEN } fn value_ptr(&self) -> ffi::SQLPOINTER { - self.as_bytes().as_ptr() as *const Self as ffi::SQLPOINTER + unsafe { ::environment::DB_ENCODING }.encode(&self).0.as_ptr() as *const Self as ffi::SQLPOINTER } fn null_bytes_count() -> usize { @@ -170,16 +169,15 @@ unsafe impl<'a> OdbcType<'a> for ::std::borrow::Cow<'a, str> { } fn convert(buffer: &'a [u8]) -> Self { - let aa = unsafe {::environment::ENCODING.decode(buffer).0}; - aa + unsafe {::environment::DB_ENCODING.decode(buffer).0} } fn column_size(&self) -> ffi::SQLULEN { - self.as_bytes().len() as ffi::SQLULEN + unsafe { ::environment::DB_ENCODING }.encode(self).0.len() as ffi::SQLULEN } fn value_ptr(&self) -> ffi::SQLPOINTER { - self.as_bytes().as_ptr() as *const Self as ffi::SQLPOINTER + unsafe { ::environment::DB_ENCODING }.encode(self).0.as_ptr() as *const Self as ffi::SQLPOINTER } fn null_bytes_count() -> usize { From b70765dc42c3f1e51457564eef2216049e009119 Mon Sep 17 00:00:00 2001 From: Konstantin Salikhov Date: Thu, 27 Jun 2019 11:10:43 +0400 Subject: [PATCH 3/5] Allow to use &str as input bindings --- src/statement/types.rs | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/statement/types.rs b/src/statement/types.rs index 3fce3748..d6762d61 100644 --- a/src/statement/types.rs +++ b/src/statement/types.rs @@ -2,6 +2,7 @@ use ffi; use std::slice::from_raw_parts; use std::mem::{size_of, transmute}; use std::ffi::CString; +use std::borrow::Cow::{Borrowed, Owned}; pub unsafe trait OdbcType<'a>: Sized { fn sql_data_type() -> ffi::SqlDataType; @@ -160,6 +161,35 @@ unsafe impl<'a> OdbcType<'a> for String { } } +unsafe impl<'a> OdbcType<'a> for &'a str { + fn sql_data_type() -> ffi::SqlDataType { + ffi::SQL_VARCHAR + } + fn c_data_type() -> ffi::SqlCDataType { + ffi::SQL_C_CHAR + } + + fn convert(buffer: &'a [u8]) -> Self { + let cow = unsafe { ::environment::DB_ENCODING }.decode(buffer).0; + match cow { + Borrowed(strref) => strref, + Owned(_string) => unimplemented!(), + } + } + + fn column_size(&self) -> ffi::SQLULEN { + unsafe { ::environment::DB_ENCODING }.encode(self).0.len() as ffi::SQLULEN + } + + fn value_ptr(&self) -> ffi::SQLPOINTER { + unsafe { ::environment::DB_ENCODING }.encode(self).0.as_ptr() as *const Self as ffi::SQLPOINTER + } + + fn null_bytes_count() -> usize { + 1 + } +} + unsafe impl<'a> OdbcType<'a> for ::std::borrow::Cow<'a, str> { fn sql_data_type() -> ffi::SqlDataType { ffi::SQL_VARCHAR From 21cd552e707eae9e342dfee51ea78215d1a44407 Mon Sep 17 00:00:00 2001 From: Konstantin Salikhov Date: Thu, 27 Jun 2019 11:41:43 +0400 Subject: [PATCH 4/5] Fix failing test --- src/diagnostics.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diagnostics.rs b/src/diagnostics.rs index f2a60276..5618c622 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -142,6 +142,7 @@ mod test { let message = b"[Microsoft][ODBC Driver Manager] Function sequence error\0"; let mut rec = DiagnosticRecord::new(); rec.state = b"HY010\0".clone(); + rec.message_string = CStr::from_bytes_with_nul(message).unwrap().to_str().unwrap().to_owned(); rec.message_length = 56; for i in 0..(rec.message_length as usize) { rec.message[i] = message[i]; From 2d2253d6d5620f94f3eb6c0a415146dac61b7735 Mon Sep 17 00:00:00 2001 From: Konstantin Salikhov Date: Thu, 27 Jun 2019 12:10:15 +0400 Subject: [PATCH 5/5] Prevent trailing zero byte --- src/diagnostics.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diagnostics.rs b/src/diagnostics.rs index 5618c622..542d68b6 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -109,7 +109,9 @@ where native_error: result.native_error, message_length, message, - message_string: unsafe { ::environment::OS_ENCODING.decode(&message[0..(message_length as usize + 1)]).0.into_owned()}, + message_string: unsafe { + ::environment::OS_ENCODING.decode(&message[0..message_length as usize]).0.into_owned() + }, }) } NoData(()) => None,