diff --git a/CHANGELOG.md b/CHANGELOG.md index 01b0494920a8..30cd2d7ca136 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -182,6 +182,8 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * The sqlite backend now uses a single batch insert statement if there are now default values present in the values clause +* The MySQL connection is using the CLIENT_FOUND_ROWS from now on. This means that updating rows without changing any values will return the number of matched rows (like most other SQL servers do), as opposed to the number of changed rows. + ### Fixed * Many types were incorrectly considered non-aggregate when they should not diff --git a/diesel/src/mysql/connection/mod.rs b/diesel/src/mysql/connection/mod.rs index 8ba6eb950a32..8656c06e08ac 100644 --- a/diesel/src/mysql/connection/mod.rs +++ b/diesel/src/mysql/connection/mod.rs @@ -166,4 +166,23 @@ mod tests { assert!(connection.execute("SELECT 1").is_ok()); assert!(connection.execute("SELECT 1").is_ok()); } + + #[test] + fn check_client_found_rows_flag() { + let conn = &mut crate::test_helpers::connection(); + conn.execute("DROP TABLE IF EXISTS update_test CASCADE") + .unwrap(); + + conn.execute("CREATE TABLE update_test(id INTEGER PRIMARY KEY, num INTEGER NOT NULL)") + .unwrap(); + + conn.execute("INSERT INTO update_test(id, num) VALUES (1, 5)") + .unwrap(); + + let output = conn + .execute("UPDATE update_test SET num = 5 WHERE id = 1") + .unwrap(); + + assert_eq!(output, 1); + } } diff --git a/diesel/src/mysql/connection/raw.rs b/diesel/src/mysql/connection/raw.rs index 0ac4041aefab..4c29c930a872 100644 --- a/diesel/src/mysql/connection/raw.rs +++ b/diesel/src/mysql/connection/raw.rs @@ -46,6 +46,7 @@ impl RawConnection { let database = connection_options.database(); let port = connection_options.port(); let unix_socket = connection_options.unix_socket(); + let client_flags = connection_options.client_flags(); unsafe { // Make sure you don't use the fake one! @@ -63,7 +64,7 @@ impl RawConnection { unix_socket .map(CStr::as_ptr) .unwrap_or_else(|| ptr::null_mut()), - 0, + client_flags.bits().into(), ) }; diff --git a/diesel/src/mysql/connection/url.rs b/diesel/src/mysql/connection/url.rs index 7e3566777f3b..897cc920494a 100644 --- a/diesel/src/mysql/connection/url.rs +++ b/diesel/src/mysql/connection/url.rs @@ -8,6 +8,36 @@ use std::ffi::{CStr, CString}; use crate::result::{ConnectionError, ConnectionResult}; +bitflags::bitflags! { + pub struct CapabilityFlags: u32 { + const CLIENT_LONG_PASSWORD = 0x00000001; + const CLIENT_FOUND_ROWS = 0x00000002; + const CLIENT_LONG_FLAG = 0x00000004; + const CLIENT_CONNECT_WITH_DB = 0x00000008; + const CLIENT_NO_SCHEMA = 0x00000010; + const CLIENT_COMPRESS = 0x00000020; + const CLIENT_ODBC = 0x00000040; + const CLIENT_LOCAL_FILES = 0x00000080; + const CLIENT_IGNORE_SPACE = 0x00000100; + const CLIENT_PROTOCOL_41 = 0x00000200; + const CLIENT_INTERACTIVE = 0x00000400; + const CLIENT_SSL = 0x00000800; + const CLIENT_IGNORE_SIGPIPE = 0x00001000; + const CLIENT_TRANSACTIONS = 0x00002000; + const CLIENT_RESERVED = 0x00004000; + const CLIENT_SECURE_CONNECTION = 0x00008000; + const CLIENT_MULTI_STATEMENTS = 0x00010000; + const CLIENT_MULTI_RESULTS = 0x00020000; + const CLIENT_PS_MULTI_RESULTS = 0x00040000; + const CLIENT_PLUGIN_AUTH = 0x00080000; + const CLIENT_CONNECT_ATTRS = 0x00100000; + const CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000; + const CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS = 0x00400000; + const CLIENT_SESSION_TRACK = 0x00800000; + const CLIENT_DEPRECATE_EOF = 0x01000000; + } +} + pub struct ConnectionOptions { host: Option, user: CString, @@ -15,6 +45,7 @@ pub struct ConnectionOptions { database: Option, port: Option, unix_socket: Option, + client_flags: CapabilityFlags, } impl ConnectionOptions { @@ -59,6 +90,9 @@ impl ConnectionOptions { Some(segment) => Some(CString::new(segment.as_bytes())?), }; + // this is not present in the database_url, using a default value + let client_flags = CapabilityFlags::CLIENT_FOUND_ROWS; + Ok(ConnectionOptions { host: host, user: user, @@ -66,6 +100,7 @@ impl ConnectionOptions { database: database, port: url.port(), unix_socket: unix_socket, + client_flags: client_flags, }) } @@ -92,6 +127,10 @@ impl ConnectionOptions { pub fn unix_socket(&self) -> Option<&CStr> { self.unix_socket.as_deref() } + + pub fn client_flags(&self) -> CapabilityFlags { + self.client_flags + } } fn decode_into_cstring(s: &str) -> ConnectionResult {