Skip to content

Commit

Permalink
Pass decimal256 integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jul 18, 2022
1 parent c585544 commit 884e717
Show file tree
Hide file tree
Showing 15 changed files with 118 additions and 21 deletions.
17 changes: 11 additions & 6 deletions arrow/src/array/array.rs
Expand Up @@ -404,6 +404,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef {
},
DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef,
DataType::Decimal(_, _) => Arc::new(DecimalArray::from(data)) as ArrayRef,
DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef,
dt => panic!("Unexpected data type {:?}", dt),
}
}
Expand Down Expand Up @@ -567,7 +568,10 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef {
)
})
}
DataType::Decimal(_, _) => new_null_sized_decimal(data_type, length),
DataType::Decimal(_, _) => {
new_null_sized_decimal(data_type, length, std::mem::size_of::<i128>())
}
DataType::Decimal256(_, _) => new_null_sized_decimal(data_type, length, 32),
}
}

Expand Down Expand Up @@ -632,18 +636,19 @@ fn new_null_sized_array<T: ArrowPrimitiveType>(
}

#[inline]
fn new_null_sized_decimal(data_type: &DataType, length: usize) -> ArrayRef {
fn new_null_sized_decimal(
data_type: &DataType,
length: usize,
byte_width: usize,
) -> ArrayRef {
make_array(unsafe {
ArrayData::new_unchecked(
data_type.clone(),
length,
Some(length),
Some(MutableBuffer::new_null(length).into()),
0,
vec![Buffer::from(vec![
0u8;
length * std::mem::size_of::<i128>()
])],
vec![Buffer::from(vec![0u8; length * byte_width])],
vec![],
)
})
Expand Down
11 changes: 8 additions & 3 deletions arrow/src/array/array_decimal.rs
Expand Up @@ -192,7 +192,12 @@ pub trait BasicDecimalArray<T: BasicDecimal, U: From<ArrayData>>:

let list_offset = v.offset();
let child_offset = child_data.offset();
let builder = ArrayData::builder(DataType::Decimal(precision, scale))
let data_type = if Self::VALUE_LENGTH == 16 {
DataType::Decimal(precision, scale)
} else {
DataType::Decimal256(precision, scale)
};
let builder = ArrayData::builder(data_type)
.len(v.len())
.add_buffer(child_data.buffers()[0].slice(child_offset))
.null_bit_buffer(v.data_ref().null_buffer().cloned())
Expand Down Expand Up @@ -344,8 +349,8 @@ impl From<ArrayData> for Decimal256Array {
);
let values = data.buffers()[0].as_ptr();
let (precision, scale) = match data.data_type() {
DataType::Decimal(precision, scale) => (*precision, *scale),
_ => panic!("Expected data type to be Decimal"),
DataType::Decimal256(precision, scale) => (*precision, *scale),
_ => panic!("Expected data type to be Decimal256"),
};
Self {
data,
Expand Down
9 changes: 7 additions & 2 deletions arrow/src/array/data.rs
Expand Up @@ -189,7 +189,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff
DataType::FixedSizeList(_, _) | DataType::Struct(_) => {
[empty_buffer, MutableBuffer::new(0)]
}
DataType::Decimal(_, _) => [
DataType::Decimal(_, _) | DataType::Decimal256(_, _) => [
MutableBuffer::new(capacity * mem::size_of::<u8>()),
empty_buffer,
],
Expand Down Expand Up @@ -572,7 +572,8 @@ impl ArrayData {
| DataType::LargeBinary
| DataType::Interval(_)
| DataType::FixedSizeBinary(_)
| DataType::Decimal(_, _) => vec![],
| DataType::Decimal(_, _)
| DataType::Decimal256(_, _) => vec![],
DataType::List(field) => {
vec![Self::new_empty(field.data_type())]
}
Expand Down Expand Up @@ -1307,6 +1308,10 @@ pub(crate) fn layout(data_type: &DataType) -> DataTypeLayout {
// always uses 16 bytes / size of i128
DataTypeLayout::new_fixed_width(size_of::<i128>())
}
DataType::Decimal256(_, _) => {
// Decimal256 values always use a fixed width of 32 bytes (256 bits).
DataTypeLayout::new_fixed_width(32)
}
DataType::Map(_, _) => {
// same as ListType
DataTypeLayout::new_fixed_width(size_of::<i32>())
Expand Down
1 change: 1 addition & 0 deletions arrow/src/array/equal/decimal.rs
Expand Up @@ -30,6 +30,7 @@ pub(super) fn decimal_equal(
) -> bool {
let size = match lhs.data_type() {
DataType::Decimal(_, _) => 16,
DataType::Decimal256(_, _) => 32,
_ => unreachable!(),
};

Expand Down
4 changes: 3 additions & 1 deletion arrow/src/array/equal/mod.rs
Expand Up @@ -186,7 +186,9 @@ fn equal_values(
DataType::FixedSizeBinary(_) => {
fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::Decimal(_, _) => decimal_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Decimal(_, _) | DataType::Decimal256(_, _) => {
decimal_equal(lhs, rhs, lhs_start, rhs_start, len)
}
DataType::List(_) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
DataType::LargeList(_) => list_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
DataType::FixedSizeList(_, _) => {
Expand Down
2 changes: 2 additions & 0 deletions arrow/src/array/transform/fixed_binary.rs
Expand Up @@ -22,6 +22,7 @@ use super::{Extend, _MutableArrayData};
pub(super) fn build_extend(array: &ArrayData) -> Extend {
let size = match array.data_type() {
DataType::FixedSizeBinary(i) => *i as usize,
DataType::Decimal256(_, _) => 32,
_ => unreachable!(),
};

Expand Down Expand Up @@ -57,6 +58,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend {
pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) {
let size = match mutable.data_type {
DataType::FixedSizeBinary(i) => i as usize,
DataType::Decimal256(_, _) => 32,
_ => unreachable!(),
};

Expand Down
9 changes: 7 additions & 2 deletions arrow/src/array/transform/mod.rs
Expand Up @@ -241,7 +241,9 @@ fn build_extend(array: &ArrayData) -> Extend {
DataType::LargeList(_) => list::build_extend::<i64>(array),
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::FixedSizeBinary(_) | DataType::Decimal256(_, _) => {
fixed_binary::build_extend(array)
}
DataType::Float16 => primitive::build_extend::<f16>(array),
DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array),
DataType::Union(_, _, mode) => match mode {
Expand Down Expand Up @@ -292,7 +294,9 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls {
_ => unreachable!(),
},
DataType::Struct(_) => structure::extend_nulls,
DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls,
DataType::FixedSizeBinary(_) | DataType::Decimal256(_, _) => {
fixed_binary::extend_nulls
}
DataType::Float16 => primitive::extend_nulls::<f16>,
DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls,
DataType::Union(_, _, mode) => match mode {
Expand Down Expand Up @@ -407,6 +411,7 @@ impl<'a> MutableArrayData<'a> {

let child_data = match &data_type {
DataType::Decimal(_, _)
| DataType::Decimal256(_, _)
| DataType::Null
| DataType::Boolean
| DataType::UInt8
Expand Down
23 changes: 20 additions & 3 deletions arrow/src/datatypes/datatype.rs
Expand Up @@ -195,6 +195,8 @@ pub enum DataType {
///
/// For example the number 123.45 has precision 5 and scale 2.
Decimal(usize, usize),
/// Exact decimal value with a 256-bit width.
Decimal256(usize, usize),
/// A Map is a logical nested type that is represented as
///
/// `List<entries: Struct<key: K, value: V>>`
Expand Down Expand Up @@ -406,15 +408,27 @@ impl DataType {
None => Err(ArrowError::ParseError(
"Expecting a precision for decimal".to_string(),
)),
};
}?;
let scale = match map.get("scale") {
Some(s) => Ok(s.as_u64().unwrap() as usize),
_ => Err(ArrowError::ParseError(
"Expecting a scale for decimal".to_string(),
)),
}?;
let bit_width: usize = match map.get("bitWidth") {
Some(b) => b.as_u64().unwrap() as usize,
_ => 128, // no "bitWidth" key present: default to a 128-bit decimal
};

Ok(DataType::Decimal(precision?, scale?))
if bit_width == 128 {
Ok(DataType::Decimal(precision, scale))
} else if bit_width == 256 {
Ok(DataType::Decimal256(precision, scale))
} else {
Err(ArrowError::ParseError(
"Decimal bit_width invalid".to_string(),
))
}
}
Some(s) if s == "floatingpoint" => match map.get("precision") {
Some(p) if p == "HALF" => Ok(DataType::Float16),
Expand Down Expand Up @@ -695,7 +709,10 @@ impl DataType {
}}),
DataType::Dictionary(_, _) => json!({ "name": "dictionary"}),
DataType::Decimal(precision, scale) => {
json!({"name": "decimal", "precision": precision, "scale": scale})
json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128})
}
DataType::Decimal256(precision, scale) => {
json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256})
}
DataType::Map(_, keys_sorted) => {
json!({"name": "map", "keysSorted": keys_sorted})
Expand Down
3 changes: 2 additions & 1 deletion arrow/src/datatypes/field.rs
Expand Up @@ -675,7 +675,8 @@ impl Field {
| DataType::FixedSizeBinary(_)
| DataType::Utf8
| DataType::LargeUtf8
| DataType::Decimal(_, _) => {
| DataType::Decimal(_, _)
| DataType::Decimal256(_, _) => {
if self.data_type != from.data_type {
return Err(ArrowError::SchemaError(
"Fail to merge schema Field due to conflicting datatype"
Expand Down
20 changes: 19 additions & 1 deletion arrow/src/ipc/convert.rs
Expand Up @@ -320,7 +320,14 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT
}
ipc::Type::Decimal => {
let fsb = field.type_as_decimal().unwrap();
DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize)
let bit_width = fsb.bitWidth();
if bit_width == 128 {
DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize)
} else if bit_width == 256 {
DataType::Decimal256(fsb.precision() as usize, fsb.scale() as usize)
} else {
panic!("Unexpected decimal bit width {}", bit_width)
}
}
ipc::Type::Union => {
let union = field.type_as_union().unwrap();
Expand Down Expand Up @@ -671,6 +678,17 @@ pub(crate) fn get_fb_field_type<'a>(
children: Some(fbb.create_vector(&empty_fields[..])),
}
}
Decimal256(precision, scale) => {
let mut builder = ipc::DecimalBuilder::new(fbb);
builder.add_precision(*precision as i32);
builder.add_scale(*scale as i32);
builder.add_bitWidth(256);
FBFieldType {
type_type: ipc::Type::Decimal,
type_: builder.finish().as_union_value(),
children: Some(fbb.create_vector(&empty_fields[..])),
}
}
Union(fields, type_ids, mode) => {
let mut children = vec![];
for field in fields {
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/ipc/reader.rs
Expand Up @@ -506,7 +506,7 @@ fn create_primitive_array(

unsafe { builder.build_unchecked() }
}
Decimal(_, _) => {
Decimal(_, _) | Decimal256(_, _) => {
// read 3 buffers
let builder = ArrayData::builder(data_type.clone())
.len(length)
Expand Down
1 change: 1 addition & 0 deletions integration-testing/Cargo.toml
Expand Up @@ -44,3 +44,4 @@ serde_json = { version = "1.0", default-features = false, features = ["preserve_
tokio = { version = "1.0", default-features = false }
tonic = { version = "0.7", default-features = false }
tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true }
num = { version = "0.4", default-features = false, features = ["std"] }
34 changes: 34 additions & 0 deletions integration-testing/src/lib.rs
Expand Up @@ -33,6 +33,9 @@ use arrow::{
util::{bit_util, integration_util::*},
};

use arrow::util::decimal::{BasicDecimal, Decimal256};
use num::bigint::BigInt;
use num::Signed;
use std::collections::HashMap;
use std::fs::File;
use std::io::BufReader;
Expand Down Expand Up @@ -611,6 +614,37 @@ fn array_from_json(
}
Ok(Arc::new(b.finish()))
}
DataType::Decimal256(precision, scale) => {
let mut b = Decimal256Builder::new(json_col.count, *precision, *scale);
for (is_valid, value) in json_col
.validity
.as_ref()
.unwrap()
.iter()
.zip(json_col.data.unwrap())
{
match is_valid {
1 => {
let str = value.as_str().unwrap();
let integer = BigInt::parse_bytes(str.as_bytes(), 10).unwrap();
let integer_bytes = integer.to_signed_bytes_le();
let mut bytes = if integer.is_positive() {
[0_u8; 32]
} else {
[255_u8; 32]
};
bytes[0..integer_bytes.len()]
.copy_from_slice(integer_bytes.as_slice());
let decimal =
Decimal256::try_new_from_bytes(*precision, *scale, &bytes)
.unwrap();
b.append_value(&decimal)
}
_ => b.append_null(),
}?;
}
Ok(Arc::new(b.finish()))
}
DataType::Map(child_field, _) => {
let null_buf = create_null_buf(&json_col);
let children = json_col.children.clone().unwrap();
Expand Down
1 change: 1 addition & 0 deletions parquet/src/arrow/arrow_writer/mod.rs
Expand Up @@ -276,6 +276,7 @@ fn write_leaves<W: Write>(
| ArrowDataType::Utf8
| ArrowDataType::LargeUtf8
| ArrowDataType::Decimal(_, _)
| ArrowDataType::Decimal256(_, _)
| ArrowDataType::FixedSizeBinary(_) => {
let mut col_writer = get_col_writer(row_group_writer)?;
for (array, levels) in arrays.iter().zip(levels.iter_mut()) {
Expand Down
2 changes: 1 addition & 1 deletion parquet/src/arrow/schema.rs
Expand Up @@ -380,7 +380,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result<Type> {
.with_length(*length)
.build()
}
DataType::Decimal(precision, scale) => {
DataType::Decimal(precision, scale) | DataType::Decimal256(precision, scale) => {
// Decimal precision determines the Parquet physical type to use.
// TODO(ARROW-12018): Enable the below after ARROW-10818 Decimal support
//
Expand Down

0 comments on commit 884e717

Please sign in to comment.