diff --git a/arrow/src/array/array.rs b/arrow/src/array/array.rs index d29cf839d9e..3d8fdd70b30 100644 --- a/arrow/src/array/array.rs +++ b/arrow/src/array/array.rs @@ -404,6 +404,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { }, DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, DataType::Decimal(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, + DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, dt => panic!("Unexpected data type {:?}", dt), } } @@ -567,7 +568,10 @@ pub fn new_null_array(data_type: &DataType, length: usize) -> ArrayRef { ) }) } - DataType::Decimal(_, _) => new_null_sized_decimal(data_type, length), + DataType::Decimal(_, _) => { + new_null_sized_decimal(data_type, length, std::mem::size_of::()) + } + DataType::Decimal256(_, _) => new_null_sized_decimal(data_type, length, 32), } } @@ -632,7 +636,11 @@ fn new_null_sized_array( } #[inline] -fn new_null_sized_decimal(data_type: &DataType, length: usize) -> ArrayRef { +fn new_null_sized_decimal( + data_type: &DataType, + length: usize, + byte_width: usize, +) -> ArrayRef { make_array(unsafe { ArrayData::new_unchecked( data_type.clone(), @@ -640,10 +648,7 @@ fn new_null_sized_decimal(data_type: &DataType, length: usize) -> ArrayRef { Some(length), Some(MutableBuffer::new_null(length).into()), 0, - vec![Buffer::from(vec![ - 0u8; - length * std::mem::size_of::() - ])], + vec![Buffer::from(vec![0u8; length * byte_width])], vec![], ) }) diff --git a/arrow/src/array/array_decimal.rs b/arrow/src/array/array_decimal.rs index 261d811f9d7..f4219265878 100644 --- a/arrow/src/array/array_decimal.rs +++ b/arrow/src/array/array_decimal.rs @@ -163,10 +163,12 @@ pub trait BasicDecimalArray>: v.value_length(), Self::VALUE_LENGTH, ); - let builder = v - .into_data() - .into_builder() - .data_type(DataType::Decimal(precision, scale)); + let data_type = if Self::VALUE_LENGTH == 16 { + DataType::Decimal(precision, scale) + } else { + DataType::Decimal256(precision, scale) + }; + let builder = v.into_data().into_builder().data_type(data_type); let array_data = unsafe { builder.build_unchecked() }; U::from(array_data) @@ -197,7 +199,12 @@ pub trait BasicDecimalArray>: let list_offset = v.offset(); let child_offset = child_data.offset(); - let builder = ArrayData::builder(DataType::Decimal(precision, scale)) + let data_type = if Self::VALUE_LENGTH == 16 { + DataType::Decimal(precision, scale) + } else { + DataType::Decimal256(precision, scale) + }; + let builder = ArrayData::builder(data_type) .len(v.len()) .add_buffer(child_data.buffers()[0].slice(child_offset)) .null_bit_buffer(v.data_ref().null_buffer().cloned()) @@ -349,8 +356,8 @@ impl From for Decimal256Array { ); let values = data.buffers()[0].as_ptr(); let (precision, scale) = match data.data_type() { - DataType::Decimal(precision, scale) => (*precision, *scale), - _ => panic!("Expected data type to be Decimal"), + DataType::Decimal256(precision, scale) => (*precision, *scale), + _ => panic!("Expected data type to be Decimal256"), }; Self { data, diff --git a/arrow/src/array/builder/decimal_builder.rs b/arrow/src/array/builder/decimal_builder.rs index e5dfa32f029..d78396be49d 100644 --- a/arrow/src/array/builder/decimal_builder.rs +++ b/arrow/src/array/builder/decimal_builder.rs @@ -269,7 +269,7 @@ mod tests { let decimal_array: Decimal256Array = builder.finish(); - assert_eq!(&DataType::Decimal(40, 6), decimal_array.data_type()); + assert_eq!(&DataType::Decimal256(40, 6), decimal_array.data_type()); assert_eq!(4, decimal_array.len()); assert_eq!(1, decimal_array.null_count()); assert_eq!(64, decimal_array.value_offset(2)); diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 5c7bd69d817..a06cea96e31 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -189,7 +189,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff DataType::FixedSizeList(_, _) | DataType::Struct(_) => { [empty_buffer, MutableBuffer::new(0)] } - DataType::Decimal(_, _) => [ + DataType::Decimal(_, _) | DataType::Decimal256(_, _) => [ MutableBuffer::new(capacity * mem::size_of::()), empty_buffer, ], @@ -572,7 +572,8 @@ impl ArrayData { | DataType::LargeBinary | DataType::Interval(_) | DataType::FixedSizeBinary(_) - | DataType::Decimal(_, _) => vec![], + | DataType::Decimal(_, _) + | DataType::Decimal256(_, _) => vec![], DataType::List(field) => { vec![Self::new_empty(field.data_type())] } @@ -1307,6 +1308,10 @@ pub(crate) fn layout(data_type: &DataType) -> DataTypeLayout { // always uses 16 bytes / size of i128 DataTypeLayout::new_fixed_width(size_of::()) } + DataType::Decimal256(_, _) => { + // Decimals are always some fixed width. + DataTypeLayout::new_fixed_width(32) + } DataType::Map(_, _) => { // same as ListType DataTypeLayout::new_fixed_width(size_of::()) diff --git a/arrow/src/array/equal/decimal.rs b/arrow/src/array/equal/decimal.rs index e9879f3f281..7c44037be39 100644 --- a/arrow/src/array/equal/decimal.rs +++ b/arrow/src/array/equal/decimal.rs @@ -30,6 +30,7 @@ pub(super) fn decimal_equal( ) -> bool { let size = match lhs.data_type() { DataType::Decimal(_, _) => 16, + DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index b8a7bc1bc5e..28ec8d7cf98 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -187,7 +187,9 @@ fn equal_values( DataType::FixedSizeBinary(_) => { fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Decimal(_, _) => decimal_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal(_, _) | DataType::Decimal256(_, _) => { + decimal_equal(lhs, rhs, lhs_start, rhs_start, len) + } DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::FixedSizeList(_, _) => { diff --git a/arrow/src/array/transform/fixed_binary.rs b/arrow/src/array/transform/fixed_binary.rs index 36952d46a4d..6d6262ca3c4 100644 --- a/arrow/src/array/transform/fixed_binary.rs +++ b/arrow/src/array/transform/fixed_binary.rs @@ -22,6 +22,7 @@ use super::{Extend, _MutableArrayData}; pub(super) fn build_extend(array: &ArrayData) -> Extend { let size = match array.data_type() { DataType::FixedSizeBinary(i) => *i as usize, + DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; @@ -57,6 +58,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { let size = match mutable.data_type { DataType::FixedSizeBinary(i) => i as usize, + DataType::Decimal256(_, _) => 32, _ => unreachable!(), }; diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index 5c15503a9db..2cd30f1d1b3 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -241,7 +241,9 @@ fn build_extend(array: &ArrayData) -> Extend { DataType::LargeList(_) => list::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), - DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), + DataType::FixedSizeBinary(_) | DataType::Decimal256(_, _) => { + fixed_binary::build_extend(array) + } DataType::Float16 => primitive::build_extend::(array), DataType::FixedSizeList(_, _) => fixed_size_list::build_extend(array), DataType::Union(_, _, mode) => match mode { @@ -292,7 +294,9 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { _ => unreachable!(), }, DataType::Struct(_) => structure::extend_nulls, - DataType::FixedSizeBinary(_) => fixed_binary::extend_nulls, + DataType::FixedSizeBinary(_) | DataType::Decimal256(_, _) => { + fixed_binary::extend_nulls + } DataType::Float16 => primitive::extend_nulls::, DataType::FixedSizeList(_, _) => fixed_size_list::extend_nulls, DataType::Union(_, _, mode) => match mode { @@ -407,6 +411,7 @@ impl<'a> MutableArrayData<'a> { let child_data = match &data_type { DataType::Decimal(_, _) + | DataType::Decimal256(_, _) | DataType::Null | DataType::Boolean | DataType::UInt8 diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/datatype.rs index d65915bd7ad..8f787b97a90 100644 --- a/arrow/src/datatypes/datatype.rs +++ b/arrow/src/datatypes/datatype.rs @@ -195,6 +195,8 @@ pub enum DataType { /// /// For example the number 123.45 has precision 5 and scale 2. Decimal(usize, usize), + /// Exact decimal value with 256 bits width + Decimal256(usize, usize), /// A Map is a logical nested type that is represented as /// /// `List>` @@ -406,15 +408,27 @@ impl DataType { None => Err(ArrowError::ParseError( "Expecting a precision for decimal".to_string(), )), - }; + }?; let scale = match map.get("scale") { Some(s) => Ok(s.as_u64().unwrap() as usize), _ => Err(ArrowError::ParseError( "Expecting a scale for decimal".to_string(), )), + }?; + let bit_width: usize = match map.get("bitWidth") { + Some(b) => b.as_u64().unwrap() as usize, + _ => 128, // Default bit width }; - Ok(DataType::Decimal(precision?, scale?)) + if bit_width == 128 { + Ok(DataType::Decimal(precision, scale)) + } else if bit_width == 256 { + Ok(DataType::Decimal256(precision, scale)) + } else { + Err(ArrowError::ParseError( + "Decimal bit_width invalid".to_string(), + )) + } } Some(s) if s == "floatingpoint" => match map.get("precision") { Some(p) if p == "HALF" => Ok(DataType::Float16), @@ -695,7 +709,10 @@ impl DataType { }}), DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), DataType::Decimal(precision, scale) => { - json!({"name": "decimal", "precision": precision, "scale": scale}) + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128}) + } + DataType::Decimal256(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 256}) } DataType::Map(_, keys_sorted) => { json!({"name": "map", "keysSorted": keys_sorted}) diff --git a/arrow/src/datatypes/field.rs b/arrow/src/datatypes/field.rs index ade48d93dab..42fb8ce1db9 100644 --- a/arrow/src/datatypes/field.rs +++ b/arrow/src/datatypes/field.rs @@ -675,7 +675,8 @@ impl Field { | DataType::FixedSizeBinary(_) | DataType::Utf8 | DataType::LargeUtf8 - | DataType::Decimal(_, _) => { + | DataType::Decimal(_, _) + | DataType::Decimal256(_, _) => { if self.data_type != from.data_type { return Err(ArrowError::SchemaError( "Fail to merge schema Field due to conflicting datatype" diff --git a/arrow/src/ipc/convert.rs b/arrow/src/ipc/convert.rs index c81ea8278c4..dbbb6b961a1 100644 --- a/arrow/src/ipc/convert.rs +++ b/arrow/src/ipc/convert.rs @@ -320,7 +320,14 @@ pub(crate) fn get_data_type(field: ipc::Field, may_be_dictionary: bool) -> DataT } ipc::Type::Decimal => { let fsb = field.type_as_decimal().unwrap(); - DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) + let bit_width = fsb.bitWidth(); + if bit_width == 128 { + DataType::Decimal(fsb.precision() as usize, fsb.scale() as usize) + } else if bit_width == 256 { + DataType::Decimal256(fsb.precision() as usize, fsb.scale() as usize) + } else { + panic!("Unexpected decimal bit width {}", bit_width) + } } ipc::Type::Union => { let union = field.type_as_union().unwrap(); @@ -671,6 +678,17 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } + Decimal256(precision, scale) => { + let mut builder = ipc::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(256); + FBFieldType { + type_type: ipc::Type::Decimal, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } Union(fields, type_ids, mode) => { let mut children = vec![]; for field in fields { diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index e8abd3a6326..a9d28bd67f4 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -506,7 +506,7 @@ fn create_primitive_array( unsafe { builder.build_unchecked() } } - Decimal(_, _) => { + Decimal(_, _) | Decimal256(_, _) => { // read 3 buffers let builder = ArrayData::builder(data_type.clone()) .len(length) diff --git a/integration-testing/Cargo.toml b/integration-testing/Cargo.toml index 897c7cfa5a5..dab45897ff5 100644 --- a/integration-testing/Cargo.toml +++ b/integration-testing/Cargo.toml @@ -44,3 +44,4 @@ serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false } tonic = { version = "0.7", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } +num = { version = "0.4", default-features = false, features = ["std"] } diff --git a/integration-testing/src/lib.rs b/integration-testing/src/lib.rs index 32ea6339e59..0bc92019856 100644 --- a/integration-testing/src/lib.rs +++ b/integration-testing/src/lib.rs @@ -33,6 +33,9 @@ use arrow::{ util::{bit_util, integration_util::*}, }; +use arrow::util::decimal::{BasicDecimal, Decimal256}; +use num::bigint::BigInt; +use num::Signed; use std::collections::HashMap; use std::fs::File; use std::io::BufReader; @@ -611,6 +614,37 @@ fn array_from_json( } Ok(Arc::new(b.finish())) } + DataType::Decimal256(precision, scale) => { + let mut b = Decimal256Builder::new(json_col.count, *precision, *scale); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => { + let str = value.as_str().unwrap(); + let integer = BigInt::parse_bytes(str.as_bytes(), 10).unwrap(); + let integer_bytes = integer.to_signed_bytes_le(); + let mut bytes = if integer.is_positive() { + [0_u8; 32] + } else { + [255_u8; 32] + }; + bytes[0..integer_bytes.len()] + .copy_from_slice(integer_bytes.as_slice()); + let decimal = + Decimal256::try_new_from_bytes(*precision, *scale, &bytes) + .unwrap(); + b.append_value(&decimal) + } + _ => b.append_null(), + }?; + } + Ok(Arc::new(b.finish())) + } DataType::Map(child_field, _) => { let null_buf = create_null_buf(&json_col); let children = json_col.children.clone().unwrap(); diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 70ddf60f4aa..358a079fdeb 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -276,6 +276,7 @@ fn write_leaves( | ArrowDataType::Utf8 | ArrowDataType::LargeUtf8 | ArrowDataType::Decimal(_, _) + | ArrowDataType::Decimal256(_, _) | ArrowDataType::FixedSizeBinary(_) => { let mut col_writer = get_col_writer(row_group_writer)?; for (array, levels) in arrays.iter().zip(levels.iter_mut()) { diff --git a/parquet/src/arrow/schema.rs b/parquet/src/arrow/schema.rs index 97611d0ec30..1ff9c0f03c4 100644 --- a/parquet/src/arrow/schema.rs +++ b/parquet/src/arrow/schema.rs @@ -380,7 +380,7 @@ fn arrow_to_parquet_type(field: &Field) -> Result { .with_length(*length) .build() } - DataType::Decimal(precision, scale) => { + DataType::Decimal(precision, scale) | DataType::Decimal256(precision, scale) => { // Decimal precision determines the Parquet physical type to use. // TODO(ARROW-12018): Enable the below after ARROW-10818 Decimal support //