Skip to content

Commit

Permalink
Fix cast kernel and take kernel tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Oct 6, 2022
1 parent 21e6dec commit 102f86f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
2 changes: 1 addition & 1 deletion arrow-array/src/array/primitive_array.rs
Expand Up @@ -1355,7 +1355,7 @@ mod tests {

#[test]
#[should_panic(
expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
)]
fn test_from_array_data_validation() {
let foo = PrimitiveArray::<Int32Type>::from_iter([1, 2, 3]);
Expand Down
23 changes: 11 additions & 12 deletions arrow/src/compute/kernels/cast.rs
Expand Up @@ -43,12 +43,12 @@ use std::str;
use std::sync::Arc;

use crate::buffer::MutableBuffer;
use crate::compute::divide_scalar;
use crate::compute::kernels::arithmetic::{divide, multiply};
use crate::compute::kernels::arity::unary;
use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
use crate::compute::kernels::temporal::extract_component_from_array;
use crate::compute::kernels::temporal::return_compute_error_with;
use crate::compute::{divide_scalar, multiply_scalar};
use crate::compute::{try_unary, using_chrono_tz_and_utc_naive_date_time};
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
Expand Down Expand Up @@ -1241,14 +1241,14 @@ pub fn cast_with_options(
}
//(Time32(TimeUnit::Second), Time64(_)) => {},
(Time32(from_unit), Time64(to_unit)) => {
let time_array = Int32Array::from(array.data().clone());
let array = cast_with_options(array, &Int32, cast_options)?;
let time_array = as_primitive_array::<Int32Type>(array.as_ref());
// note: (numeric_cast + SIMD multiply) is faster than (cast & multiply)
let c: Int64Array = numeric_cast(&time_array);
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// from is only smaller than to if 64milli/64second don't exist
let mult = Int64Array::from(vec![to_size / from_size; array.len()]);
let converted = multiply(&c, &mult)?;
let converted = multiply_scalar(&c, to_size / from_size)?;
let array_ref = Arc::new(converted) as ArrayRef;
use TimeUnit::*;
match to_unit {
Expand Down Expand Up @@ -1284,7 +1284,8 @@ pub fn cast_with_options(
Ok(Arc::new(values) as ArrayRef)
}
(Time64(from_unit), Time32(to_unit)) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
let divisor = from_size / to_size;
Expand Down Expand Up @@ -1321,18 +1322,16 @@ pub fn cast_with_options(
}
}
(Timestamp(from_unit, _), Timestamp(to_unit, _)) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit);
let to_size = time_unit_multiple(to_unit);
// we either divide or multiply, depending on size of each unit
// units are never the same when the types are the same
let converted = if from_size >= to_size {
divide_scalar(&time_array, from_size / to_size)?
} else {
multiply(
&time_array,
&Int64Array::from(vec![to_size / from_size; array.len()]),
)?
multiply_scalar(&time_array, to_size / from_size)?
};
let array_ref = Arc::new(converted) as ArrayRef;
use TimeUnit::*;
Expand All @@ -1355,10 +1354,10 @@ pub fn cast_with_options(
}
}
(Timestamp(from_unit, _), Date32) => {
let time_array = Int64Array::from(array.data().clone());
let array = cast_with_options(array, &Int64, cast_options)?;
let time_array = as_primitive_array::<Int64Type>(array.as_ref());
let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY;

// Int32Array::from_iter(tim.iter)
let mut b = Date32Builder::with_capacity(array.len());

for i in 0..array.len() {
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/compute/kernels/take.rs
Expand Up @@ -1398,7 +1398,7 @@ mod tests {
fn test_take_bool_nullable_index() {
// indices where the masked invalid elements would be out of bounds
let index_data = ArrayData::try_new(
DataType::Int32,
DataType::UInt32,
6,
Some(Buffer::from_iter(vec![
false, true, false, true, false, true,
Expand All @@ -1421,7 +1421,7 @@ mod tests {
fn test_take_bool_nullable_index_nonnull_values() {
// indices where the masked invalid elements would be out of bounds
let index_data = ArrayData::try_new(
DataType::Int32,
DataType::UInt32,
6,
Some(Buffer::from_iter(vec![
false, true, false, true, false, true,
Expand Down

0 comments on commit 102f86f

Please sign in to comment.