diff --git a/Cargo.toml b/Cargo.toml index 4ce3ab53..508e0f21 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,8 @@ travis = [] odbc-sys = "0.6.3" odbc-safe = "0.4.1" log = "0.4.1" +encoding_rs = "0.8.14" + [dev-dependencies] chrono = "0.4" diff --git a/src/diagnostics.rs b/src/diagnostics.rs index 3bcdbd87..542d68b6 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -17,6 +17,7 @@ pub struct DiagnosticRecord { // The numbers of characters in message not nul message_length: ffi::SQLSMALLINT, native_error: ffi::SQLINTEGER, + message_string: String, } impl DiagnosticRecord { @@ -41,6 +42,7 @@ impl DiagnosticRecord { message: [0u8; MAX_DIAGNOSTIC_MESSAGE_SIZE], native_error: -1, message_length: message.len() as ffi::SQLSMALLINT, + message_string: String::from(""), }; rec.message[..message.len()].copy_from_slice(message); rec @@ -51,16 +53,13 @@ impl fmt::Display for DiagnosticRecord { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // Todo: replace unwrap with `?` in Rust 1.17 let state = CStr::from_bytes_with_nul(&self.state).unwrap(); - let message = CStr::from_bytes_with_nul( - &self.message[0..(self.message_length as usize + 1)], - ).unwrap(); write!( f, "State: {}, Native error: {}, Message: {}", state.to_str().unwrap(), self.native_error, - message.to_str().unwrap() + self.message_string, ) } } @@ -73,10 +72,7 @@ impl fmt::Debug for DiagnosticRecord { impl Error for DiagnosticRecord { fn description(&self) -> &str { - CStr::from_bytes_with_nul(&self.message[0..(self.message_length as usize + 1)]) - .unwrap() - .to_str() - .unwrap() + &self.message_string } fn cause(&self) -> Option<&Error> { None @@ -113,6 +109,9 @@ where native_error: result.native_error, message_length, message, + message_string: unsafe { + ::environment::OS_ENCODING.decode(&message[0..message_length as usize]).0.into_owned() + }, }) } NoData(()) => None, @@ -133,6 +132,7 @@ mod test { message: [0u8; MAX_DIAGNOSTIC_MESSAGE_SIZE], native_error: 0, message_length: 0, + message_string: String::from(""), } } } @@ -144,6 +144,7 @@ mod test { let message = b"[Microsoft][ODBC Driver Manager] Function sequence error\0"; let mut rec = DiagnosticRecord::new(); rec.state = b"HY010\0".clone(); + rec.message_string = CStr::from_bytes_with_nul(message).unwrap().to_str().unwrap().to_owned(); rec.message_length = 56; for i in 0..(rec.message_length as usize) { rec.message[i] = message[i]; diff --git a/src/environment/list_data_sources.rs b/src/environment/list_data_sources.rs index 33ab648f..e2da09ed 100644 --- a/src/environment/list_data_sources.rs +++ b/src/environment/list_data_sources.rs @@ -1,7 +1,6 @@ use super::{safe, try_into_option, Environment, DiagnosticRecord, GetDiagRec, Result, Version3}; use ffi; use std::collections::HashMap; -use std::str::from_utf8; use std::cmp::max; /// Holds name and description of a datasource @@ -73,8 +72,8 @@ impl Environment { )? { driver_list.push(DriverInfo { - description: desc.to_owned(), - attributes: Self::parse_attributes(attr), + description: desc.into_owned(), + attributes: Self::parse_attributes(&attr), }) } } @@ -124,10 +123,10 @@ impl Environment { &mut name_buffer, &mut description_buffer, )? - { - source_list.push(DataSourceInfo { - server_name: name.to_owned(), - driver: desc.to_owned(), + { + source_list.push(DataSourceInfo { + server_name: name.into_owned(), + driver: desc.into_owned(), }) } else { return Ok(source_list); @@ -142,8 +141,8 @@ impl Environment { )? { source_list.push(DataSourceInfo { - server_name: name.to_owned(), - driver: desc.to_owned(), + server_name: name.into_owned(), + driver: desc.into_owned(), }) } } @@ -159,13 +158,15 @@ impl Environment { direction: ffi::FetchOrientation, buf1: &'a mut [u8], buf2: &'b mut [u8], - ) -> Result> { + ) -> Result, ::std::borrow::Cow<'b, str>)>> { let result = f(&mut self.safe, direction, buf1, buf2); match try_into_option(result, self)? { - Some((len1, len2)) => Ok(Some(( - from_utf8(&buf1[0..(len1 as usize)]).unwrap(), - from_utf8(&buf2[0..(len2 as usize)]).unwrap(), - ))), + Some((len1, len2)) => unsafe { + Ok(Some(( + ::environment::DB_ENCODING.decode(&buf1[0..(len1 as usize)]).0, + ::environment::DB_ENCODING.decode(&buf2[0..(len2 as usize)]).0, + ))) + } None => Ok(None), } } @@ -190,6 +191,7 @@ impl Environment { ); loop { + match result { safe::ReturnOption::Success((buf1_length_out, buf2_length_out)) | safe::ReturnOption::Info((buf1_length_out, buf2_length_out)) => { @@ -219,6 +221,7 @@ impl Environment { #[cfg(test)] mod test { + use super::*; #[test] diff --git a/src/environment/mod.rs b/src/environment/mod.rs index 5996a510..04cf5e27 100644 --- a/src/environment/mod.rs +++ b/src/environment/mod.rs @@ -7,6 +7,9 @@ use std; /// Environment state used to represent that environment has been set to odbc version 3 pub type Version3 = safe::Odbc3; +pub static mut OS_ENCODING: &encoding_rs::Encoding = encoding_rs::UTF_8; +pub static mut DB_ENCODING: &encoding_rs::Encoding = encoding_rs::UTF_8; + /// Handle to an ODBC Environment /// /// Creating an instance of this type is the first thing you do then using ODBC. The environment @@ -94,3 +97,14 @@ pub fn create_environment_v3() { Environment::new() } + + +pub fn create_environment_v3_with_os_db_encoding(os_encoding: &str, db_encoding: &str) + -> std::result::Result, Option> +{ + unsafe { + OS_ENCODING = encoding_rs::Encoding::for_label(os_encoding.as_bytes()).unwrap(); + DB_ENCODING = encoding_rs::Encoding::for_label(db_encoding.as_bytes()).unwrap(); + } + Environment::new() +} diff --git a/src/lib.rs b/src/lib.rs index e6f1eefe..c0ab9f6c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,6 +27,7 @@ #[macro_use] extern crate log; extern crate odbc_safe; +extern crate encoding_rs; pub mod ffi; diff --git a/src/statement/mod.rs b/src/statement/mod.rs index fb7f6ee7..78042ce9 100644 --- a/src/statement/mod.rs +++ b/src/statement/mod.rs @@ -289,9 +289,8 @@ impl Raii { &mut nullable as *mut ffi::Nullable, ) { SQL_SUCCESS => Return::Success(ColumnDescriptor { - name: ::std::str::from_utf8(&name_buffer[..(name_length as usize)]) - .unwrap() - .to_owned(), + name: ::environment::DB_ENCODING.decode(&name_buffer[..(name_length as usize)]).0 + .to_string(), data_type: data_type, column_size: if column_size == 0 { None @@ -310,9 +309,8 @@ impl Raii { }, }), SQL_SUCCESS_WITH_INFO => Return::SuccessWithInfo(ColumnDescriptor { - name: ::std::str::from_utf8(&name_buffer[..(name_length as usize)]) - .unwrap() - .to_owned(), + name: ::environment::DB_ENCODING.decode(&name_buffer[..(name_length as usize)]).0 + .to_string(), data_type: data_type, column_size: if column_size == 0 { None @@ -338,14 +336,16 @@ impl Raii { } fn exec_direct(&mut self, statement_text: &str) -> Return { - let length = statement_text.len(); + let bytes = unsafe { crate::environment::DB_ENCODING }.encode(statement_text).0; + + let length = bytes.len(); if length > ffi::SQLINTEGER::max_value() as usize { panic!("Statement text too long"); } match unsafe { ffi::SQLExecDirect( self.handle(), - statement_text.as_ptr(), + bytes.as_ptr(), length as ffi::SQLINTEGER, ) } { diff --git a/src/statement/prepare.rs b/src/statement/prepare.rs index e16345bd..223b726c 100644 --- a/src/statement/prepare.rs +++ b/src/statement/prepare.rs @@ -99,11 +99,12 @@ impl<'a, 'b> Statement<'a, 'b, Prepared, NoResult> { impl Raii { fn prepare(&mut self, sql_text: &str) -> Return<()> { + let bytes = unsafe { crate::environment::DB_ENCODING }.encode(sql_text).0; match unsafe { ffi::SQLPrepare( self.handle(), - sql_text.as_bytes().as_ptr(), - sql_text.as_bytes().len() as ffi::SQLINTEGER, + bytes.as_ptr(), + bytes.len() as ffi::SQLINTEGER, ) } { ffi::SQL_SUCCESS => Return::Success(()), diff --git a/src/statement/types.rs b/src/statement/types.rs index b50e4ba5..d6762d61 100644 --- a/src/statement/types.rs +++ b/src/statement/types.rs @@ -1,8 +1,8 @@ use ffi; -use std::str::from_utf8; use std::slice::from_raw_parts; use std::mem::{size_of, transmute}; use std::ffi::CString; +use std::borrow::Cow::{Borrowed, Owned}; pub unsafe trait OdbcType<'a>: Sized { fn sql_data_type() -> ffi::SqlDataType; @@ -145,15 +145,15 @@ unsafe impl<'a> OdbcType<'a> for String { } fn convert(buffer: &'a [u8]) -> Self { - from_utf8(buffer).unwrap().to_owned() + unsafe { ::environment::DB_ENCODING }.decode(buffer).0.to_string() } fn column_size(&self) -> ffi::SQLULEN { - self.as_bytes().len() as ffi::SQLULEN + unsafe { ::environment::DB_ENCODING }.encode(&self).0.len() as ffi::SQLULEN } fn value_ptr(&self) -> ffi::SQLPOINTER { - self.as_bytes().as_ptr() as *const Self as ffi::SQLPOINTER + unsafe { ::environment::DB_ENCODING }.encode(&self).0.as_ptr() as *const Self as ffi::SQLPOINTER } fn null_bytes_count() -> usize { @@ -170,15 +170,44 @@ unsafe impl<'a> OdbcType<'a> for &'a str { } fn convert(buffer: &'a [u8]) -> Self { - from_utf8(buffer).unwrap() + let cow = unsafe { ::environment::DB_ENCODING }.decode(buffer).0; + match cow { + Borrowed(strref) => strref, + Owned(_string) => unimplemented!(), + } } fn column_size(&self) -> ffi::SQLULEN { - self.as_bytes().len() as ffi::SQLULEN + unsafe { ::environment::DB_ENCODING }.encode(self).0.len() as ffi::SQLULEN } fn value_ptr(&self) -> ffi::SQLPOINTER { - self.as_bytes().as_ptr() as *const Self as ffi::SQLPOINTER + unsafe { ::environment::DB_ENCODING }.encode(self).0.as_ptr() as *const Self as ffi::SQLPOINTER + } + + fn null_bytes_count() -> usize { + 1 + } +} + +unsafe impl<'a> OdbcType<'a> for ::std::borrow::Cow<'a, str> { + fn sql_data_type() -> ffi::SqlDataType { + ffi::SQL_VARCHAR + } + fn c_data_type() -> ffi::SqlCDataType { + ffi::SQL_C_CHAR + } + + fn convert(buffer: &'a [u8]) -> Self { + unsafe {::environment::DB_ENCODING.decode(buffer).0} + } + + fn column_size(&self) -> ffi::SQLULEN { + unsafe { ::environment::DB_ENCODING }.encode(self).0.len() as ffi::SQLULEN + } + + fn value_ptr(&self) -> ffi::SQLPOINTER { + unsafe { ::environment::DB_ENCODING }.encode(self).0.as_ptr() as *const Self as ffi::SQLPOINTER } fn null_bytes_count() -> usize { @@ -550,4 +579,4 @@ unsafe impl<'a, T> OdbcType<'a> for Option where T: OdbcType<'a> { fn null_bytes_count() -> usize { T::null_bytes_count() } -} \ No newline at end of file +}