Skip to content

Commit

Permalink
Utf8array casting (#2456)
Browse files Browse the repository at this point in the history
* Implement utf8 to binary in compute::kernels::cast

* Fix clippy lints

* Improve performance

* Fix formatting errors
  • Loading branch information
psvri committed Aug 16, 2022
1 parent 7a74465 commit 6a1b9ee
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 59 deletions.
174 changes: 116 additions & 58 deletions arrow/src/compute/kernels/cast.rs
Expand Up @@ -146,7 +146,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Utf8, LargeUtf8) => true,
(LargeUtf8, Utf8) => true,
(Utf8,
Date32
Binary
| Date32
| Date64
| Time32(TimeUnit::Second)
| Time32(TimeUnit::Millisecond)
Expand All @@ -156,7 +157,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
) => true,
(Utf8, _) => DataType::is_numeric(to_type),
(LargeUtf8,
Date32
LargeBinary
| Date32
| Date64
| Time32(TimeUnit::Second)
| Time32(TimeUnit::Millisecond)
Expand Down Expand Up @@ -693,6 +695,7 @@ pub fn cast_with_options(
Float64 => cast_string_to_numeric::<Float64Type, i32>(array, cast_options),
Date32 => cast_string_to_date32::<i32>(&**array, cast_options),
Date64 => cast_string_to_date64::<i32>(&**array, cast_options),
Binary => cast_string_to_binary(array),
Time32(TimeUnit::Second) => {
cast_string_to_time32second::<i32>(&**array, cast_options)
}
Expand Down Expand Up @@ -839,6 +842,7 @@ pub fn cast_with_options(
Float64 => cast_string_to_numeric::<Float64Type, i64>(array, cast_options),
Date32 => cast_string_to_date32::<i64>(&**array, cast_options),
Date64 => cast_string_to_date64::<i64>(&**array, cast_options),
LargeBinary => cast_string_to_binary(array),
Time32(TimeUnit::Second) => {
cast_string_to_time32second::<i64>(&**array, cast_options)
}
Expand Down Expand Up @@ -1254,6 +1258,41 @@ pub fn cast_with_options(
}
}

/// Cast to string array to binary array
fn cast_string_to_binary(array: &ArrayRef) -> Result<ArrayRef> {
let from_type = array.data_type();
match *from_type {
DataType::Utf8 => {
let data = unsafe {
array
.data()
.clone()
.into_builder()
.data_type(DataType::Binary)
.build_unchecked()
};

Ok(Arc::new(BinaryArray::from(data)) as ArrayRef)
}
DataType::LargeUtf8 => {
let data = unsafe {
array
.data()
.clone()
.into_builder()
.data_type(DataType::LargeBinary)
.build_unchecked()
};

Ok(Arc::new(LargeBinaryArray::from(data)) as ArrayRef)
}
_ => Err(ArrowError::InvalidArgumentError(format!(
"{:?} cannot be converted to binary array",
from_type
))),
}
}

/// Get the time unit as a multiple of a second
const fn time_unit_multiple(unit: &TimeUnit) -> i64 {
match unit {
Expand Down Expand Up @@ -3471,6 +3510,34 @@ mod tests {
}
}

#[test]
fn test_cast_string_to_binary() {
let string_1 = "Hi";
let string_2 = "Hello";

let bytes_1 = string_1.as_bytes();
let bytes_2 = string_2.as_bytes();

let string_data = vec![Some(string_1), Some(string_2), None];
let a1 = Arc::new(StringArray::from(string_data.clone())) as ArrayRef;
let a2 = Arc::new(LargeStringArray::from(string_data)) as ArrayRef;

let mut array_ref = cast(&a1, &DataType::Binary).unwrap();
let down_cast = array_ref.as_any().downcast_ref::<BinaryArray>().unwrap();
assert_eq!(bytes_1, down_cast.value(0));
assert_eq!(bytes_2, down_cast.value(1));
assert!(down_cast.is_null(2));

array_ref = cast(&a2, &DataType::LargeBinary).unwrap();
let down_cast = array_ref
.as_any()
.downcast_ref::<LargeBinaryArray>()
.unwrap();
assert_eq!(bytes_1, down_cast.value(0));
assert_eq!(bytes_2, down_cast.value(1));
assert!(down_cast.is_null(2));
}

#[test]
fn test_cast_date32_to_int32() {
let a = Date32Array::from(vec![10000, 17890]);
Expand Down Expand Up @@ -3688,15 +3755,15 @@ mod tests {
#[test]
fn test_cast_from_f64() {
let f64_values: Vec<f64> = vec![
std::i64::MIN as f64,
std::i32::MIN as f64,
std::i16::MIN as f64,
std::i8::MIN as f64,
i64::MIN as f64,
i32::MIN as f64,
i16::MIN as f64,
i8::MIN as f64,
0_f64,
std::u8::MAX as f64,
std::u16::MAX as f64,
std::u32::MAX as f64,
std::u64::MAX as f64,
u8::MAX as f64,
u16::MAX as f64,
u32::MAX as f64,
u64::MAX as f64,
];
let f64_array: ArrayRef = Arc::new(Float64Array::from(f64_values));

Expand Down Expand Up @@ -3838,15 +3905,15 @@ mod tests {
#[test]
fn test_cast_from_f32() {
let f32_values: Vec<f32> = vec![
std::i32::MIN as f32,
std::i32::MIN as f32,
std::i16::MIN as f32,
std::i8::MIN as f32,
i32::MIN as f32,
i32::MIN as f32,
i16::MIN as f32,
i8::MIN as f32,
0_f32,
std::u8::MAX as f32,
std::u16::MAX as f32,
std::u32::MAX as f32,
std::u32::MAX as f32,
u8::MAX as f32,
u16::MAX as f32,
u32::MAX as f32,
u32::MAX as f32,
];
let f32_array: ArrayRef = Arc::new(Float32Array::from(f32_values));

Expand Down Expand Up @@ -3975,10 +4042,10 @@ mod tests {
fn test_cast_from_uint64() {
let u64_values: Vec<u64> = vec![
0,
std::u8::MAX as u64,
std::u16::MAX as u64,
std::u32::MAX as u64,
std::u64::MAX,
u8::MAX as u64,
u16::MAX as u64,
u32::MAX as u64,
u64::MAX,
];
let u64_array: ArrayRef = Arc::new(UInt64Array::from(u64_values));

Expand Down Expand Up @@ -4054,12 +4121,8 @@ mod tests {

#[test]
fn test_cast_from_uint32() {
let u32_values: Vec<u32> = vec![
0,
std::u8::MAX as u32,
std::u16::MAX as u32,
std::u32::MAX as u32,
];
let u32_values: Vec<u32> =
vec![0, u8::MAX as u32, u16::MAX as u32, u32::MAX as u32];
let u32_array: ArrayRef = Arc::new(UInt32Array::from(u32_values));

let f64_expected = vec!["0.0", "255.0", "65535.0", "4294967295.0"];
Expand Down Expand Up @@ -4125,7 +4188,7 @@ mod tests {

#[test]
fn test_cast_from_uint16() {
let u16_values: Vec<u16> = vec![0, std::u8::MAX as u16, std::u16::MAX as u16];
let u16_values: Vec<u16> = vec![0, u8::MAX as u16, u16::MAX as u16];
let u16_array: ArrayRef = Arc::new(UInt16Array::from(u16_values));

let f64_expected = vec!["0.0", "255.0", "65535.0"];
Expand Down Expand Up @@ -4191,7 +4254,7 @@ mod tests {

#[test]
fn test_cast_from_uint8() {
let u8_values: Vec<u8> = vec![0, std::u8::MAX];
let u8_values: Vec<u8> = vec![0, u8::MAX];
let u8_array: ArrayRef = Arc::new(UInt8Array::from(u8_values));

let f64_expected = vec!["0.0", "255.0"];
Expand Down Expand Up @@ -4258,15 +4321,15 @@ mod tests {
#[test]
fn test_cast_from_int64() {
let i64_values: Vec<i64> = vec![
std::i64::MIN,
std::i32::MIN as i64,
std::i16::MIN as i64,
std::i8::MIN as i64,
i64::MIN,
i32::MIN as i64,
i16::MIN as i64,
i8::MIN as i64,
0,
std::i8::MAX as i64,
std::i16::MAX as i64,
std::i32::MAX as i64,
std::i64::MAX,
i8::MAX as i64,
i16::MAX as i64,
i32::MAX as i64,
i64::MAX,
];
let i64_array: ArrayRef = Arc::new(Int64Array::from(i64_values));

Expand Down Expand Up @@ -4413,13 +4476,13 @@ mod tests {
#[test]
fn test_cast_from_int32() {
let i32_values: Vec<i32> = vec![
std::i32::MIN as i32,
std::i16::MIN as i32,
std::i8::MIN as i32,
i32::MIN as i32,
i16::MIN as i32,
i8::MIN as i32,
0,
std::i8::MAX as i32,
std::i16::MAX as i32,
std::i32::MAX as i32,
i8::MAX as i32,
i16::MAX as i32,
i32::MAX as i32,
];
let i32_array: ArrayRef = Arc::new(Int32Array::from(i32_values));

Expand Down Expand Up @@ -4507,13 +4570,8 @@ mod tests {

#[test]
fn test_cast_from_int16() {
let i16_values: Vec<i16> = vec![
std::i16::MIN,
std::i8::MIN as i16,
0,
std::i8::MAX as i16,
std::i16::MAX,
];
let i16_values: Vec<i16> =
vec![i16::MIN, i8::MIN as i16, 0, i8::MAX as i16, i16::MAX];
let i16_array: ArrayRef = Arc::new(Int16Array::from(i16_values));

let f64_expected = vec!["-32768.0", "-128.0", "0.0", "127.0", "32767.0"];
Expand Down Expand Up @@ -4580,13 +4638,13 @@ mod tests {
#[test]
fn test_cast_from_date32() {
let i32_values: Vec<i32> = vec![
std::i32::MIN as i32,
std::i16::MIN as i32,
std::i8::MIN as i32,
i32::MIN as i32,
i16::MIN as i32,
i8::MIN as i32,
0,
std::i8::MAX as i32,
std::i16::MAX as i32,
std::i32::MAX as i32,
i8::MAX as i32,
i16::MAX as i32,
i32::MAX as i32,
];
let date32_array: ArrayRef = Arc::new(Date32Array::from(i32_values));

Expand All @@ -4607,7 +4665,7 @@ mod tests {

#[test]
fn test_cast_from_int8() {
let i8_values: Vec<i8> = vec![std::i8::MIN, 0, std::i8::MAX];
let i8_values: Vec<i8> = vec![i8::MIN, 0, i8::MAX];
let i8_array: ArrayRef = Arc::new(Int8Array::from(i8_values));

let f64_expected = vec!["-128.0", "0.0", "127.0"];
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/datatypes/datatype.rs
Expand Up @@ -1391,7 +1391,7 @@ impl DataType {
}
}

/// Returns true if this type is numeric: (UInt*, Unit*, or Float*).
/// Returns true if this type is numeric: (UInt*, Int*, or Float*).
pub fn is_numeric(t: &DataType) -> bool {
use DataType::*;
matches!(
Expand Down

0 comments on commit 6a1b9ee

Please sign in to comment.