diff --git a/resources/test/fixtures/pyupgrade/UP018.py b/resources/test/fixtures/pyupgrade/UP018.py index c000447c827b9..4f97f3a0bb6d6 100644 --- a/resources/test/fixtures/pyupgrade/UP018.py +++ b/resources/test/fixtures/pyupgrade/UP018.py @@ -7,12 +7,14 @@ str("foo", encoding="UTF-8") str("foo" "bar") +str(b"foo") bytes("foo", encoding="UTF-8") bytes(*a) bytes("foo", *a) bytes("foo", **a) bytes(b"foo" b"bar") +bytes("foo") # These become string or byte literals str() diff --git a/src/pyupgrade/plugins/native_literals.rs b/src/pyupgrade/plugins/native_literals.rs index 88808c87b1e9d..d0366f483dfb0 100644 --- a/src/pyupgrade/plugins/native_literals.rs +++ b/src/pyupgrade/plugins/native_literals.rs @@ -17,21 +17,31 @@ pub fn native_literals( ) { let ExprKind::Name { id, .. } = &func.node else { return; }; - if (id == "str" || id == "bytes") - && keywords.is_empty() - && args.len() <= 1 - && checker.is_builtin(id) - { + if !keywords.is_empty() || args.len() > 1 { + return; + } + + if (id == "str" || id == "bytes") && checker.is_builtin(id) { let Some(arg) = args.get(0) else { - let literal_type = if id == "str" { + let mut check = Check::new(CheckKind::NativeLiterals(if id == "str" { LiteralType::Str } else { LiteralType::Bytes - }; - let mut check = Check::new(CheckKind::NativeLiterals(literal_type), Range::from_located(expr)); + }), Range::from_located(expr)); if checker.patch(&CheckCode::UP018) { check.amend(Fix::replacement( - format!("{}\"\"", if id == "bytes" { "b" } else { "" }), + if id == "bytes" { + let mut content = String::with_capacity(3); + content.push('b'); + content.push(checker.style.quote().into()); + content.push(checker.style.quote().into()); + content + } else { + let mut content = String::with_capacity(2); + content.push(checker.style.quote().into()); + content.push(checker.style.quote().into()); + content + }, expr.location, expr.end_location.unwrap(), )); @@ -40,14 +50,31 @@ pub fn native_literals( return; }; - let ExprKind::Constant { value, ..} = &arg.node else { + // Look for `str("")`. + if id == "str" + && !matches!( + &arg.node, + ExprKind::Constant { + value: Constant::Str(_), + .. + }, + ) + { return; - }; - let literal_type = match value { - Constant::Str { .. } => LiteralType::Str, - Constant::Bytes { .. } => LiteralType::Bytes, - _ => return, - }; + } + + // Look for `bytes(b"")` + if id == "bytes" + && !matches!( + &arg.node, + ExprKind::Constant { + value: Constant::Bytes(_), + .. + }, + ) + { + return; + } // rust-python merges adjacent string/bytes literals into one node, but we can't // safely remove the outer call in this situation. We're following pyupgrade @@ -65,7 +92,11 @@ pub fn native_literals( } let mut check = Check::new( - CheckKind::NativeLiterals(literal_type), + CheckKind::NativeLiterals(if id == "str" { + LiteralType::Str + } else { + LiteralType::Bytes + }), Range::from_located(expr), ); if checker.patch(&CheckCode::UP018) { diff --git a/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP018_UP018.py.snap b/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP018_UP018.py.snap index 2187e42525e55..6c7ff772bcbf3 100644 --- a/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP018_UP018.py.snap +++ b/src/pyupgrade/snapshots/ruff__pyupgrade__tests__UP018_UP018.py.snap @@ -5,103 +5,103 @@ expression: checks - kind: NativeLiterals: Str location: - row: 18 + row: 20 column: 0 end_location: - row: 18 + row: 20 column: 5 fix: content: "\"\"" location: - row: 18 + row: 20 column: 0 end_location: - row: 18 + row: 20 column: 5 parent: ~ - kind: NativeLiterals: Str location: - row: 19 + row: 21 column: 0 end_location: - row: 19 + row: 21 column: 10 fix: content: "\"foo\"" location: - row: 19 + row: 21 column: 0 end_location: - row: 19 + row: 21 column: 10 parent: ~ - kind: NativeLiterals: Str location: - row: 20 + row: 22 column: 0 end_location: - row: 21 + row: 23 column: 7 fix: content: "\"\"\"\nfoo\"\"\"" location: - row: 20 + row: 22 column: 0 end_location: - row: 21 + row: 23 column: 7 parent: ~ - kind: NativeLiterals: Bytes location: - row: 22 + row: 24 column: 0 end_location: - row: 22 + row: 24 column: 7 fix: content: "b\"\"" location: - row: 22 + row: 24 column: 0 end_location: - row: 22 + row: 24 column: 7 parent: ~ - kind: NativeLiterals: Bytes location: - row: 23 + row: 25 column: 0 end_location: - row: 23 + row: 25 column: 13 fix: content: "b\"foo\"" location: - row: 23 + row: 25 column: 0 end_location: - row: 23 + row: 25 column: 13 parent: ~ - kind: NativeLiterals: Bytes location: - row: 24 + row: 26 column: 0 end_location: - row: 25 + row: 27 column: 7 fix: content: "b\"\"\"\nfoo\"\"\"" location: - row: 24 + row: 26 column: 0 end_location: - row: 25 + row: 27 column: 7 parent: ~ diff --git a/src/source_code_style.rs b/src/source_code_style.rs index fe85a9ad980ef..8a83b371cba9e 100644 --- a/src/source_code_style.rs +++ b/src/source_code_style.rs @@ -1,5 +1,6 @@ //! Detect code style from Python source code. +use std::fmt; use std::ops::Deref; use once_cell::unsync::OnceCell; @@ -69,6 +70,24 @@ impl From<&Quote> for vendor::str::Quote { } } +impl fmt::Display for Quote { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Quote::Single => write!(f, "\'"), + Quote::Double => write!(f, "\""), + } + } +} + +impl From<&Quote> for char { + fn from(val: &Quote) -> Self { + match val { + Quote::Single => '\'', + Quote::Double => '"', + } + } +} + /// The indentation style used in Python source code. #[derive(Debug, PartialEq, Eq)] pub struct Indentation(String);