diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml
index d81a551a3b4..2b7ab2039a1 100644
--- a/.github/workflows/arrow.yml
+++ b/.github/workflows/arrow.yml
@@ -63,6 +63,8 @@ jobs:
         cargo run --example read_csv_infer_schema
       - name: Run non-archery based integration-tests
         run: cargo test -p arrow-integration-testing
+      - name: Test arrow-schema with all features
+        run: cargo test -p arrow-schema --all-features
 
   # test compilation features
   linux-features:
diff --git a/Cargo.toml b/Cargo.toml
index d0233ccb376..355c65a8b80 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,6 +18,7 @@
 [workspace]
 members = [
     "arrow",
+    "arrow-schema",
    "arrow-buffer",
    "arrow-flight",
    "parquet",
diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 086b2183465..f9e70eb8d77 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -28,9 +28,13 @@ use arrow::compute::kernels;
 use arrow::datatypes::{DataType, Field, Schema};
 use arrow::error::ArrowError;
 use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::PyArrowConvert;
+use arrow::pyarrow::{PyArrowConvert, PyArrowException, PyArrowType};
 use arrow::record_batch::RecordBatch;
 
+fn to_py_err(err: ArrowError) -> PyErr {
+    PyArrowException::new_err(err.to_string())
+}
+
 /// Returns `array + array` of an int64 array.
 #[pyfunction]
 fn double(array: &PyAny, py: Python) -> PyResult<PyObject> {
@@ -41,8 +45,10 @@ fn double(array: &PyAny, py: Python) -> PyResult<PyObject> {
     let array = array
         .as_any()
         .downcast_ref::<Int64Array>()
-        .ok_or(ArrowError::ParseError("Expects an int64".to_string()))?;
-    let array = kernels::arithmetic::add(array, array)?;
+        .ok_or_else(|| ArrowError::ParseError("Expects an int64".to_string()))
+        .map_err(to_py_err)?;
+
+    let array = kernels::arithmetic::add(array, array).map_err(to_py_err)?;
 
     // export
     array.to_pyarrow(py)
@@ -66,56 +72,61 @@ fn double_py(lambda: &PyAny, py: Python) -> PyResult<bool> {
 
 /// Returns the substring
 #[pyfunction]
-fn substring(array: ArrayData, start: i64) -> PyResult<ArrayData> {
+fn substring(
+    array: PyArrowType<ArrayData>,
+    start: i64,
+) -> PyResult<PyArrowType<ArrayData>> {
     // import
-    let array = ArrayRef::from(array);
+    let array = ArrayRef::from(array.0);
 
     // substring
-    let array = kernels::substring::substring(array.as_ref(), start, None)?;
+    let array = kernels::substring::substring(array.as_ref(), start, None).map_err(to_py_err)?;
 
-    Ok(array.data().to_owned())
+    Ok(array.data().to_owned().into())
 }
 
 /// Returns the concatenate
 #[pyfunction]
-fn concatenate(array: ArrayData, py: Python) -> PyResult<PyObject> {
-    let array = ArrayRef::from(array);
+fn concatenate(array: PyArrowType<ArrayData>, py: Python) -> PyResult<PyObject> {
+    let array = ArrayRef::from(array.0);
 
     // concat
-    let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()])?;
+    let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).map_err(to_py_err)?;
 
     array.to_pyarrow(py)
 }
 
 #[pyfunction]
-fn round_trip_type(obj: DataType) -> PyResult<DataType> {
+fn round_trip_type(obj: PyArrowType<DataType>) -> PyResult<PyArrowType<DataType>> {
     Ok(obj)
 }
 
 #[pyfunction]
-fn round_trip_field(obj: Field) -> PyResult<Field> {
+fn round_trip_field(obj: PyArrowType<Field>) -> PyResult<PyArrowType<Field>> {
     Ok(obj)
 }
 
 #[pyfunction]
-fn round_trip_schema(obj: Schema) -> PyResult<Schema> {
+fn round_trip_schema(obj: PyArrowType<Schema>) -> PyResult<PyArrowType<Schema>> {
     Ok(obj)
 }
 
 #[pyfunction]
-fn round_trip_array(obj: ArrayData) -> PyResult<ArrayData> {
+fn round_trip_array(obj: PyArrowType<ArrayData>) -> PyResult<PyArrowType<ArrayData>> {
     Ok(obj)
 }
 
 #[pyfunction]
-fn round_trip_record_batch(obj: RecordBatch) -> PyResult<RecordBatch> {
+fn round_trip_record_batch(
+    obj: PyArrowType<RecordBatch>,
+) -> PyResult<PyArrowType<RecordBatch>> {
     Ok(obj)
 }
 
 #[pyfunction]
 fn round_trip_record_batch_reader(
-    obj: ArrowArrayStreamReader,
-) -> PyResult<ArrowArrayStreamReader> {
+    obj: PyArrowType<ArrowArrayStreamReader>,
+) -> PyResult<PyArrowType<ArrowArrayStreamReader>> {
     Ok(obj)
 }
diff --git a/arrow-schema/Cargo.toml b/arrow-schema/Cargo.toml
new file mode 100644
index 00000000000..d35a99a6d15
--- /dev/null
+++ b/arrow-schema/Cargo.toml
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "arrow-schema"
+version = "23.0.0"
+description = "Defines the logical types for arrow arrays"
+homepage = "https://github.com/apache/arrow-rs"
+repository = "https://github.com/apache/arrow-rs"
+authors = ["Apache Arrow <dev@arrow.apache.org>"]
+license = "Apache-2.0"
+keywords = ["arrow"]
+include = [
+    "benches/*.rs",
+    "src/**/*.rs",
+    "Cargo.toml",
+]
+edition = "2021"
+rust-version = "1.62"
+
+[lib]
+name = "arrow_schema"
+path = "src/lib.rs"
+bench = false
+
+[dependencies]
+serde = { version = "1.0", default-features = false, features = ["derive", "std"], optional = true }
+
+[features]
+default = []
+
+[dev-dependencies]
+serde_json = "1.0"
diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
new file mode 100644
index 00000000000..9037f7c9a53
--- /dev/null
+++ b/arrow-schema/src/datatype.rs
@@ -0,0 +1,492 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::fmt;
+
+use crate::field::Field;
+
+/// The set of datatypes that are supported by this implementation of Apache Arrow.
+///
+/// The Arrow specification on data types includes some more types.
+/// See also [`Schema.fbs`](https://github.com/apache/arrow/blob/master/format/Schema.fbs)
+/// for Arrow's specification.
+///
+/// The variants of this enum include primitive fixed size types as well as parametric or
+/// nested types.
+/// Currently the Rust implementation supports the following nested types:
+///  - `List<T>`
+///  - `Struct<T, U, V, ...>`
+///
+/// Nested types can themselves be nested within other arrays.
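+///
+/// For example, a nullable list of 32-bit integers can be expressed as follows
+/// (a minimal sketch using this crate's `Field` type):
+///
+/// ```
+/// # use arrow_schema::{DataType, Field};
+/// // each list element is described by a child `Field`
+/// let ty = DataType::List(Box::new(Field::new("item", DataType::Int32, true)));
+/// assert!(DataType::is_nested(&ty));
+/// ```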
+/// For more information on these types please see
+/// [the physical memory layout of Apache Arrow](https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout).
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub enum DataType {
+    /// Null type
+    Null,
+    /// A boolean datatype representing the values `true` and `false`.
+    Boolean,
+    /// A signed 8-bit integer.
+    Int8,
+    /// A signed 16-bit integer.
+    Int16,
+    /// A signed 32-bit integer.
+    Int32,
+    /// A signed 64-bit integer.
+    Int64,
+    /// An unsigned 8-bit integer.
+    UInt8,
+    /// An unsigned 16-bit integer.
+    UInt16,
+    /// An unsigned 32-bit integer.
+    UInt32,
+    /// An unsigned 64-bit integer.
+    UInt64,
+    /// A 16-bit floating point number.
+    Float16,
+    /// A 32-bit floating point number.
+    Float32,
+    /// A 64-bit floating point number.
+    Float64,
+    /// A timestamp with an optional timezone.
+    ///
+    /// Time is measured as a Unix epoch, counting the seconds from
+    /// 00:00:00.000 on 1 January 1970, excluding leap seconds,
+    /// as a 64-bit integer.
+    ///
+    /// The time zone is a string indicating the name of a time zone, one of:
+    ///
+    /// * As used in the Olson time zone database (the "tz database" or
+    ///   "tzdata"), such as "America/New_York"
+    /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30
+    ///
+    /// Timestamps with a non-empty timezone
+    /// ------------------------------------
+    ///
+    /// If a Timestamp column has a non-empty timezone value, its epoch is
+    /// 1970-01-01 00:00:00 (January 1st 1970, midnight) in the *UTC* timezone
+    /// (the Unix epoch), regardless of the Timestamp's own timezone.
+    ///
+    /// Therefore, timestamp values with a non-empty timezone correspond to
+    /// physical points in time together with some additional information about
+    /// how the data was obtained and/or how to display it (the timezone).
+    ///
+    /// For example, the timestamp value 0 with the timezone string "Europe/Paris"
+    /// corresponds to "January 1st 1970, 00h00" in the UTC timezone, but the
+    /// application may prefer to display it as "January 1st 1970, 01h00" in
+    /// the Europe/Paris timezone (which is the same physical point in time).
+    ///
+    /// One consequence is that timestamp values with a non-empty timezone
+    /// can be compared and ordered directly, since they all share the same
+    /// well-known point of reference (the Unix epoch).
+    ///
+    /// Timestamps with an unset / empty timezone
+    /// -----------------------------------------
+    ///
+    /// If a Timestamp column has no timezone value, its epoch is
+    /// 1970-01-01 00:00:00 (January 1st 1970, midnight) in an *unknown* timezone.
+    ///
+    /// Therefore, timestamp values without a timezone cannot be meaningfully
+    /// interpreted as physical points in time, but only as calendar / clock
+    /// indications ("wall clock time") in an unspecified timezone.
+    ///
+    /// For example, the timestamp value 0 with an empty timezone string
+    /// corresponds to "January 1st 1970, 00h00" in an unknown timezone: there
+    /// is not enough information to interpret it as a well-defined physical
+    /// point in time.
+    ///
+    /// One consequence is that timestamp values without a timezone cannot
+    /// be reliably compared or ordered, since they may have different points of
+    /// reference. In particular, it is *not* possible to interpret an unset
+    /// or empty timezone as the same as "UTC".
+    ///
+    /// Conversion between timezones
+    /// ----------------------------
+    ///
+    /// If a Timestamp column has a non-empty timezone, changing the timezone
+    /// to a different non-empty value is a metadata-only operation:
+    /// the timestamp values need not change as their point of reference remains
+    /// the same (the Unix epoch).
+    ///
+    /// However, if a Timestamp column has no timezone value, changing it to a
+    /// non-empty value requires thinking about the desired semantics.
+    /// One possibility is to assume that the original timestamp values are
+    /// relative to the epoch of the timezone being set; timestamp values should
+    /// then be adjusted to the Unix epoch (for example, changing the timezone from
+    /// empty to "Europe/Paris" would require converting the timestamp values
+    /// from "Europe/Paris" to "UTC", which seems counter-intuitive but is
+    /// nevertheless correct).
+    Timestamp(TimeUnit, Option<String>),
+    /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01)
+    /// in days (32 bits).
+    Date32,
+    /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01)
+    /// in milliseconds (64 bits). Values are evenly divisible by 86400000.
+    Date64,
+    /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`.
+    Time32(TimeUnit),
+    /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`.
+    Time64(TimeUnit),
+    /// Measure of elapsed time in either seconds, milliseconds, microseconds or nanoseconds.
+    Duration(TimeUnit),
+    /// A "calendar" interval which models types that don't necessarily
+    /// have a precise duration without the context of a base timestamp (e.g.
+    /// days can differ in length during daylight saving time transitions).
+    Interval(IntervalUnit),
+    /// Opaque binary data of variable length.
+    Binary,
+    /// Opaque binary data of fixed size.
+    /// Enum parameter specifies the number of bytes per value.
+    FixedSizeBinary(i32),
+    /// Opaque binary data of variable length and 64-bit offsets.
+    LargeBinary,
+    /// A variable-length string in Unicode with UTF-8 encoding.
+    Utf8,
+    /// A variable-length string in Unicode with UTF-8 encoding and 64-bit offsets.
+    LargeUtf8,
+    /// A list of some logical data type with variable length.
+    List(Box<Field>),
+    /// A list of some logical data type with fixed length.
+    FixedSizeList(Box<Field>, i32),
+    /// A list of some logical data type with variable length and 64-bit offsets.
+    LargeList(Box<Field>),
+    /// A nested datatype that contains a number of sub-fields.
+    Struct(Vec<Field>),
+    /// A nested datatype that can represent slots of differing types. Components:
+    ///
+    /// 1. [`Field`] for each possible child type the Union can hold
+    /// 2. The corresponding `type_id` used to identify which Field
+    /// 3. The type of union (Sparse or Dense)
+    Union(Vec<Field>, Vec<i8>, UnionMode),
+    /// A dictionary encoded array (`key_type`, `value_type`), where
+    /// each array element is an index of `key_type` into an
+    /// associated dictionary of `value_type`.
+    ///
+    /// Dictionary arrays are used to store columns of `value_type`
+    /// that contain many repeated values using less memory, but with
+    /// a higher CPU overhead for some operations.
+    ///
+    /// This type is mostly used to represent low cardinality string
+    /// arrays or a limited set of primitive types as integers.
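+    ///
+    /// For illustration, a dictionary of UTF-8 values keyed by `Int32` indices
+    /// can be constructed as below (a minimal sketch; key types are restricted
+    /// to integers, see `is_dictionary_key_type`):
+    ///
+    /// ```
+    /// # use arrow_schema::DataType;
+    /// // `Int32` keys index into a dictionary of `Utf8` values
+    /// let ty = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8));
+    /// assert!(DataType::is_dictionary_key_type(&DataType::Int32));
+    /// ```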
+    Dictionary(Box<DataType>, Box<DataType>),
+    /// Exact 128-bit width decimal value with precision and scale
+    ///
+    /// * precision is the total number of digits
+    /// * scale is the number of digits past the decimal
+    ///
+    /// For example the number 123.45 has precision 5 and scale 2.
+    Decimal128(u8, u8),
+    /// Exact 256-bit width decimal value with precision and scale
+    ///
+    /// * precision is the total number of digits
+    /// * scale is the number of digits past the decimal
+    ///
+    /// For example the number 123.45 has precision 5 and scale 2.
+    Decimal256(u8, u8),
+    /// A Map is a logical nested type that is represented as
+    ///
+    /// `List<entries: Struct<key: K, value: V>>`
+    ///
+    /// The keys and values are each respectively contiguous.
+    /// The key and value types are not constrained, but keys should be
+    /// hashable and unique.
+    /// Whether the keys are sorted can be set in the `bool` after the `Field`.
+    ///
+    /// In a field with Map type, the field has a child Struct field, which then
+    /// has two children: the first the key type and the second the value type. The names of the
+    /// child fields may be respectively "entries", "key", and "value", but this is
+    /// not enforced.
+    Map(Box<Field>, bool),
+}
+
+/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub enum TimeUnit {
+    /// Time in seconds.
+    Second,
+    /// Time in milliseconds.
+    Millisecond,
+    /// Time in microseconds.
+    Microsecond,
+    /// Time in nanoseconds.
+    Nanosecond,
+}
+
+/// YEAR_MONTH, DAY_TIME, MONTH_DAY_NANO interval in SQL style.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub enum IntervalUnit {
+    /// Indicates the number of elapsed whole months, stored as 4-byte integers.
+    YearMonth,
+    /// Indicates the number of elapsed days and milliseconds,
+    /// stored as 2 contiguous 32-bit integers (days, milliseconds) (8-bytes in total).
+    DayTime,
+    /// A triple of the number of elapsed months, days, and nanoseconds.
+    /// The values are stored contiguously in 16 byte blocks. Months and
+    /// days are encoded as 32 bit integers and nanoseconds is encoded as a
+    /// 64 bit integer. All integers are signed. Each field is independent
+    /// (e.g. there is no constraint that nanoseconds have the same sign
+    /// as days or that the quantity of nanoseconds represents less
+    /// than a day's worth of time).
+    MonthDayNano,
+}
+
+/// Sparse or Dense union layouts
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub enum UnionMode {
+    Sparse,
+    Dense,
+}
+
+impl fmt::Display for DataType {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{:?}", self)
+    }
+}
+
+impl DataType {
+    /// Returns true if the type is primitive: (numeric, temporal).
+    pub fn is_primitive(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            Int8 | Int16
+                | Int32
+                | Int64
+                | UInt8
+                | UInt16
+                | UInt32
+                | UInt64
+                | Float32
+                | Float64
+                | Date32
+                | Date64
+                | Time32(_)
+                | Time64(_)
+                | Timestamp(_, _)
+                | Interval(_)
+                | Duration(_)
+        )
+    }
+
+    /// Returns true if this type is numeric: (UInt*, Int*, or Float*).
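+    ///
+    /// A minimal doc-test sketch of the check (note that `Float16` is not
+    /// matched by this predicate):
+    ///
+    /// ```
+    /// # use arrow_schema::DataType;
+    /// assert!(DataType::is_numeric(&DataType::Int32));
+    /// assert!(!DataType::is_numeric(&DataType::Utf8));
+    /// ```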
+    pub fn is_numeric(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            UInt8
+                | UInt16
+                | UInt32
+                | UInt64
+                | Int8
+                | Int16
+                | Int32
+                | Int64
+                | Float32
+                | Float64
+        )
+    }
+
+    /// Returns true if this type is temporal: (Date*, Time*, Timestamp, Duration, or Interval).
+    pub fn is_temporal(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            Date32
+                | Date64
+                | Timestamp(_, _)
+                | Time32(_)
+                | Time64(_)
+                | Duration(_)
+                | Interval(_)
+        )
+    }
+
+    /// Returns true if this type is valid as a dictionary key
+    pub fn is_dictionary_key_type(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64
+        )
+    }
+
+    /// Returns true if this type is nested (List, FixedSizeList, LargeList, Struct, Union, or Map)
+    pub fn is_nested(t: &DataType) -> bool {
+        use DataType::*;
+        matches!(
+            t,
+            List(_)
+                | FixedSizeList(_, _)
+                | LargeList(_)
+                | Struct(_)
+                | Union(_, _, _)
+                | Map(_, _)
+        )
+    }
+
+    /// Compares the datatype with another, ignoring nested field names
+    /// and metadata.
+    pub fn equals_datatype(&self, other: &DataType) -> bool {
+        match (&self, other) {
+            (DataType::List(a), DataType::List(b))
+            | (DataType::LargeList(a), DataType::LargeList(b)) => {
+                a.is_nullable() == b.is_nullable()
+                    && a.data_type().equals_datatype(b.data_type())
+            }
+            (DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => {
+                a_size == b_size
+                    && a.is_nullable() == b.is_nullable()
+                    && a.data_type().equals_datatype(b.data_type())
+            }
+            (DataType::Struct(a), DataType::Struct(b)) => {
+                a.len() == b.len()
+                    && a.iter().zip(b).all(|(a, b)| {
+                        a.is_nullable() == b.is_nullable()
+                            && a.data_type().equals_datatype(b.data_type())
+                    })
+            }
+            (
+                DataType::Map(a_field, a_is_sorted),
+                DataType::Map(b_field, b_is_sorted),
+            ) => a_field == b_field && a_is_sorted == b_is_sorted,
+            _ => self == other,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    #[cfg(feature = "serde")]
+    fn serde_struct_type() {
+        use std::collections::BTreeMap;
+
+        let kv_array = [("k".to_string(), "v".to_string())];
+        let field_metadata: BTreeMap<String, String> = kv_array.iter().cloned().collect();
+
+        // Non-empty map: should be converted as JSON obj { ... }
+        let first_name = Field::new("first_name", DataType::Utf8, false)
+            .with_metadata(Some(field_metadata));
+
+        // Empty map: should be omitted.
+        let last_name = Field::new("last_name", DataType::Utf8, false)
+            .with_metadata(Some(BTreeMap::default()));
+
+        let person = DataType::Struct(vec![
+            first_name,
+            last_name,
+            Field::new(
+                "address",
+                DataType::Struct(vec![
+                    Field::new("street", DataType::Utf8, false),
+                    Field::new("zip", DataType::UInt16, false),
+                ]),
+                false,
+            ),
+        ]);
+
+        let serialized = serde_json::to_string(&person).unwrap();
+
+        // NOTE that this is testing the default (derived) serialization format, not the
+        // JSON format specified in metadata.md
+
+        assert_eq!(
+            "{\"Struct\":[\
+            {\"name\":\"first_name\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{\"k\":\"v\"}},\
+            {\"name\":\"last_name\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false},\
+            {\"name\":\"address\",\"data_type\":{\"Struct\":\
+            [{\"name\":\"street\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false},\
+            {\"name\":\"zip\",\"data_type\":\"UInt16\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false}\
+            ]},\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false}]}",
+            serialized
+        );
+
+        let deserialized = serde_json::from_str(&serialized).unwrap();
+
+        assert_eq!(person, deserialized);
+    }
+
+    #[test]
+    fn test_list_datatype_equality() {
+        // tests that list type equality is checked while ignoring list names
+        let list_a = DataType::List(Box::new(Field::new("item", DataType::Int32, true)));
+        let list_b = DataType::List(Box::new(Field::new("array", DataType::Int32, true)));
+        let list_c = DataType::List(Box::new(Field::new("item", DataType::Int32, false)));
+        let list_d = DataType::List(Box::new(Field::new("item", DataType::UInt32, true)));
+        assert!(list_a.equals_datatype(&list_b));
+        assert!(!list_a.equals_datatype(&list_c));
+        assert!(!list_b.equals_datatype(&list_c));
+        assert!(!list_a.equals_datatype(&list_d));
+
+        let list_e =
+            DataType::FixedSizeList(Box::new(Field::new("item", list_a, false)), 3);
+        let list_f =
+            DataType::FixedSizeList(Box::new(Field::new("array", list_b, false)), 3);
+        let list_g = DataType::FixedSizeList(
+            Box::new(Field::new("item", DataType::FixedSizeBinary(3), true)),
+            3,
+        );
+        assert!(list_e.equals_datatype(&list_f));
+        assert!(!list_e.equals_datatype(&list_g));
+        assert!(!list_f.equals_datatype(&list_g));
+
+        let list_h = DataType::Struct(vec![Field::new("f1", list_e, true)]);
+        let list_i = DataType::Struct(vec![Field::new("f1", list_f.clone(), true)]);
+        let list_j = DataType::Struct(vec![Field::new("f1", list_f.clone(), false)]);
+        let list_k = DataType::Struct(vec![
+            Field::new("f1", list_f.clone(), false),
+            Field::new("f2", list_g.clone(), false),
+            Field::new("f3", DataType::Utf8, true),
+        ]);
+        let list_l = DataType::Struct(vec![
+            Field::new("ff1", list_f.clone(), false),
+            Field::new("ff2", list_g.clone(), false),
+            Field::new("ff3", DataType::LargeUtf8, true),
+        ]);
+        let list_m = DataType::Struct(vec![
+            Field::new("ff1", list_f, false),
+            Field::new("ff2", list_g, false),
+            Field::new("ff3", DataType::Utf8, true),
+        ]);
+        assert!(list_h.equals_datatype(&list_i));
+        assert!(!list_h.equals_datatype(&list_j));
+        assert!(!list_k.equals_datatype(&list_l));
+        assert!(list_k.equals_datatype(&list_m));
+    }
+
+    #[test]
+    fn create_struct_type() {
+        let _person = DataType::Struct(vec![
+            Field::new("first_name", DataType::Utf8, false),
+            Field::new("last_name", DataType::Utf8, false),
+            Field::new(
+                "address",
+                DataType::Struct(vec![
+                    Field::new("street", DataType::Utf8, false),
+                    Field::new("zip", DataType::UInt16, false),
+                ]),
+                false,
+            ),
+        ]);
+    }
+}
diff --git a/arrow-schema/src/error.rs b/arrow-schema/src/error.rs
new file mode 100644
index 00000000000..105d4d5e21f
--- /dev/null
+++ b/arrow-schema/src/error.rs
@@ -0,0 +1,103 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines `ArrowError` for representing failures in various Arrow operations.
+use std::fmt::{Debug, Display, Formatter};
+use std::io::Write;
+
+use std::error::Error;
+
+/// Many different operations in the `arrow` crate return this error type.
+#[derive(Debug)]
+pub enum ArrowError {
+    /// Returned when functionality is not yet available.
+    NotYetImplemented(String),
+    ExternalError(Box<dyn Error + Send + Sync>),
+    CastError(String),
+    MemoryError(String),
+    ParseError(String),
+    SchemaError(String),
+    ComputeError(String),
+    DivideByZero,
+    CsvError(String),
+    JsonError(String),
+    IoError(String),
+    InvalidArgumentError(String),
+    ParquetError(String),
+    /// Error during import or export to/from the C Data Interface
+    CDataInterface(String),
+    DictionaryKeyOverflowError,
+}
+
+impl ArrowError {
+    /// Wraps an external error in an `ArrowError`.
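+    ///
+    /// An illustrative sketch (the `std::io::Error` here is just a stand-in for
+    /// any boxed external error):
+    ///
+    /// ```
+    /// # use arrow_schema::ArrowError;
+    /// // any error that is `Error + Send + Sync` can be boxed and wrapped
+    /// let io_err = std::io::Error::new(std::io::ErrorKind::Other, "disk full");
+    /// let err = ArrowError::from_external_error(Box::new(io_err));
+    /// assert!(matches!(err, ArrowError::ExternalError(_)));
+    /// ```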
+    pub fn from_external_error(error: Box<dyn Error + Send + Sync>) -> Self {
+        Self::ExternalError(error)
+    }
+}
+
+impl From<::std::io::Error> for ArrowError {
+    fn from(error: std::io::Error) -> Self {
+        ArrowError::IoError(error.to_string())
+    }
+}
+
+impl From<::std::string::FromUtf8Error> for ArrowError {
+    fn from(error: std::string::FromUtf8Error) -> Self {
+        ArrowError::ParseError(error.to_string())
+    }
+}
+
+impl<W: Write> From<::std::io::IntoInnerError<W>> for ArrowError {
+    fn from(error: std::io::IntoInnerError<W>) -> Self {
+        ArrowError::IoError(error.to_string())
+    }
+}
+
+impl Display for ArrowError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            ArrowError::NotYetImplemented(source) => {
+                write!(f, "Not yet implemented: {}", &source)
+            }
+            ArrowError::ExternalError(source) => write!(f, "External error: {}", &source),
+            ArrowError::CastError(desc) => write!(f, "Cast error: {}", desc),
+            ArrowError::MemoryError(desc) => write!(f, "Memory error: {}", desc),
+            ArrowError::ParseError(desc) => write!(f, "Parser error: {}", desc),
+            ArrowError::SchemaError(desc) => write!(f, "Schema error: {}", desc),
+            ArrowError::ComputeError(desc) => write!(f, "Compute error: {}", desc),
+            ArrowError::DivideByZero => write!(f, "Divide by zero error"),
+            ArrowError::CsvError(desc) => write!(f, "Csv error: {}", desc),
+            ArrowError::JsonError(desc) => write!(f, "Json error: {}", desc),
+            ArrowError::IoError(desc) => write!(f, "Io error: {}", desc),
+            ArrowError::InvalidArgumentError(desc) => {
+                write!(f, "Invalid argument error: {}", desc)
+            }
+            ArrowError::ParquetError(desc) => {
+                write!(f, "Parquet argument error: {}", desc)
+            }
+            ArrowError::CDataInterface(desc) => {
+                write!(f, "C Data interface error: {}", desc)
+            }
+            ArrowError::DictionaryKeyOverflowError => {
+                write!(f, "Dictionary key bigger than the key type")
+            }
+        }
+    }
+}
+
+impl Error for ArrowError {}
diff --git a/arrow/src/datatypes/field.rs b/arrow-schema/src/field.rs
similarity index 98%
rename from arrow/src/datatypes/field.rs
rename to arrow-schema/src/field.rs
index 03d07807743..adafbfa9b72 100644
--- a/arrow/src/datatypes/field.rs
+++ b/arrow-schema/src/field.rs
@@ -15,12 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::error::{ArrowError, Result};
+use crate::error::ArrowError;
 use std::cmp::Ordering;
 use std::collections::BTreeMap;
 use std::hash::{Hash, Hasher};
 
-use super::DataType;
+use crate::datatype::DataType;
 
 /// Describes a single column in a [`Schema`](super::Schema).
 ///
@@ -145,7 +145,7 @@ impl Field {
     /// Set the name of the [`Field`] and returns self.
     ///
     /// ```
-    /// # use arrow::datatypes::*;
+    /// # use arrow_schema::*;
     /// let field = Field::new("c1", DataType::Int64, false)
     ///     .with_name("c2");
     ///
@@ -165,7 +165,7 @@ impl Field {
     /// Set [`DataType`] of the [`Field`] and returns self.
     ///
     /// ```
-    /// # use arrow::datatypes::*;
+    /// # use arrow_schema::*;
     /// let field = Field::new("c1", DataType::Int64, false)
     ///     .with_data_type(DataType::Utf8);
     ///
@@ -185,7 +185,7 @@ impl Field {
     /// Set `nullable` of the [`Field`] and returns self.
     ///
     /// ```
-    /// # use arrow::datatypes::*;
+    /// # use arrow_schema::*;
     /// let field = Field::new("c1", DataType::Int64, false)
     ///     .with_nullable(true);
     ///
@@ -259,12 +259,12 @@ impl Field {
     /// Example:
     ///
     /// ```
-    /// # use arrow::datatypes::*;
+    /// # use arrow_schema::*;
     /// let mut field = Field::new("c1", DataType::Int64, false);
     /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok());
     /// assert!(field.is_nullable());
     /// ```
-    pub fn try_merge(&mut self, from: &Field) -> Result<()> {
+    pub fn try_merge(&mut self, from: &Field) -> Result<(), ArrowError> {
         if from.dict_id != self.dict_id {
             return Err(ArrowError::SchemaError(
                 "Fail to merge schema Field due to conflicting dict_id".to_string(),
diff --git a/arrow-schema/src/lib.rs b/arrow-schema/src/lib.rs
new file mode 100644
index 00000000000..34030f2d356
--- /dev/null
+++ b/arrow-schema/src/lib.rs
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Arrow logical types
+
+mod datatype;
+pub use datatype::*;
+mod error;
+pub use error::*;
+mod field;
+pub use field::*;
+mod schema;
+pub use schema::*;
diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs
new file mode 100644
index 00000000000..9605cdda720
--- /dev/null
+++ b/arrow-schema/src/schema.rs
@@ -0,0 +1,782 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::fmt;
+use std::hash::Hash;
+
+use crate::error::ArrowError;
+use crate::field::Field;
+
+/// Describes the meta-data of an ordered sequence of relative types.
+///
+/// Note that this information is only part of the meta-data and not part of the physical
+/// memory layout.
+#[derive(Debug, Clone, PartialEq, Eq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct Schema {
+    pub fields: Vec<Field>,
+    /// A map of key-value pairs containing additional metadata.
+    #[cfg_attr(
+        feature = "serde",
+        serde(skip_serializing_if = "HashMap::is_empty", default)
+    )]
+    pub metadata: HashMap<String, String>,
+}
+
+impl Schema {
+    /// Creates an empty `Schema`
+    pub fn empty() -> Self {
+        Self {
+            fields: vec![],
+            metadata: HashMap::new(),
+        }
+    }
+
+    /// Creates a new [`Schema`] from a sequence of [`Field`] values.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use arrow_schema::*;
+    /// let field_a = Field::new("a", DataType::Int64, false);
+    /// let field_b = Field::new("b", DataType::Boolean, false);
+    ///
+    /// let schema = Schema::new(vec![field_a, field_b]);
+    /// ```
+    pub fn new(fields: Vec<Field>) -> Self {
+        Self::new_with_metadata(fields, HashMap::new())
+    }
+
+    /// Creates a new [`Schema`] from a sequence of [`Field`] values
+    /// and adds additional metadata in form of key value pairs.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use arrow_schema::*;
+    /// # use std::collections::HashMap;
+    ///
+    /// let field_a = Field::new("a", DataType::Int64, false);
+    /// let field_b = Field::new("b", DataType::Boolean, false);
+    ///
+    /// let mut metadata: HashMap<String, String> = HashMap::new();
+    /// metadata.insert("row_count".to_string(), "100".to_string());
+    ///
+    /// let schema = Schema::new_with_metadata(vec![field_a, field_b], metadata);
+    /// ```
+    #[inline]
+    pub const fn new_with_metadata(
+        fields: Vec<Field>,
+        metadata: HashMap<String, String>,
+    ) -> Self {
+        Self { fields, metadata }
+    }
+
+    /// Sets the metadata of this `Schema` to be `metadata` and returns self
+    pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
+        self.metadata = metadata;
+        self
+    }
+
+    /// Returns a new schema with only the specified columns.
+    /// This carries metadata from the parent schema over as well.
+    pub fn project(&self, indices: &[usize]) -> Result<Schema, ArrowError> {
+        let new_fields = indices
+            .iter()
+            .map(|i| {
+                self.fields.get(*i).cloned().ok_or_else(|| {
+                    ArrowError::SchemaError(format!(
+                        "project index {} out of bounds, max field {}",
+                        i,
+                        self.fields().len()
+                    ))
+                })
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(Self::new_with_metadata(new_fields, self.metadata.clone()))
+    }
+
+    /// Merge schema into self if it is compatible. Struct fields will be merged recursively.
+    ///
+    /// Example:
+    ///
+    /// ```
+    /// # use arrow_schema::*;
+    ///
+    /// let merged = Schema::try_merge(vec![
+    ///     Schema::new(vec![
+    ///         Field::new("c1", DataType::Int64, false),
+    ///         Field::new("c2", DataType::Utf8, false),
+    ///     ]),
+    ///     Schema::new(vec![
+    ///         Field::new("c1", DataType::Int64, true),
+    ///         Field::new("c2", DataType::Utf8, false),
+    ///         Field::new("c3", DataType::Utf8, false),
+    ///     ]),
+    /// ]).unwrap();
+    ///
+    /// assert_eq!(
+    ///     merged,
+    ///     Schema::new(vec![
+    ///         Field::new("c1", DataType::Int64, true),
+    ///         Field::new("c2", DataType::Utf8, false),
+    ///         Field::new("c3", DataType::Utf8, false),
+    ///     ]),
+    /// );
+    /// ```
+    pub fn try_merge(
+        schemas: impl IntoIterator<Item = Self>,
+    ) -> Result<Self, ArrowError> {
+        schemas
+            .into_iter()
+            .try_fold(Self::empty(), |mut merged, schema| {
+                let Schema { metadata, fields } = schema;
+                for (key, value) in metadata.into_iter() {
+                    // merge metadata
+                    if let Some(old_val) = merged.metadata.get(&key) {
+                        if old_val != &value {
+                            return Err(ArrowError::SchemaError(format!(
+                                "Fail to merge schema due to conflicting metadata. \
+                                 Key '{}' has different values '{}' and '{}'",
+                                key, old_val, value
+                            )));
+                        }
+                    }
+                    merged.metadata.insert(key, value);
+                }
+                // merge fields
+                for field in fields.into_iter() {
+                    let merged_field =
+                        merged.fields.iter_mut().find(|f| f.name() == field.name());
+                    match merged_field {
+                        Some(merged_field) => merged_field.try_merge(&field)?,
+                        // found a new field, add to field list
+                        None => merged.fields.push(field),
+                    }
+                }
+                Ok(merged)
+            })
+    }
+
+    /// Returns an immutable reference of the vector of `Field` instances.
+    #[inline]
+    pub const fn fields(&self) -> &Vec<Field> {
+        &self.fields
+    }
+
+    /// Returns a vector with references to all fields (including nested fields)
+    #[inline]
+    pub fn all_fields(&self) -> Vec<&Field> {
+        self.fields.iter().flat_map(|f| f.fields()).collect()
+    }
+
+    /// Returns an immutable reference of a specific [`Field`] instance selected using an
+    /// offset within the internal `fields` vector.
+    pub fn field(&self, i: usize) -> &Field {
+        &self.fields[i]
+    }
+
+    /// Returns an immutable reference of a specific [`Field`] instance selected by name.
+    pub fn field_with_name(&self, name: &str) -> Result<&Field, ArrowError> {
+        Ok(&self.fields[self.index_of(name)?])
+    }
+
+    /// Returns a vector of immutable references to all [`Field`] instances selected by
+    /// the dictionary ID they use.
+    pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> {
+        self.fields
+            .iter()
+            .flat_map(|f| f.fields_with_dict_id(dict_id))
+            .collect()
+    }
+
+    /// Find the index of the column with the given name.
+    pub fn index_of(&self, name: &str) -> Result<usize, ArrowError> {
+        (0..self.fields.len())
+            .find(|idx| self.fields[*idx].name() == name)
+            .ok_or_else(|| {
+                let valid_fields: Vec<String> =
+                    self.fields.iter().map(|f| f.name().clone()).collect();
+                ArrowError::SchemaError(format!(
+                    "Unable to get field named \"{}\". Valid fields: {:?}",
+                    name, valid_fields
+                ))
+            })
+    }
+
+    /// Returns an immutable reference to the Map of custom metadata key-value pairs.
+    #[inline]
+    pub const fn metadata(&self) -> &HashMap<String, String> {
+        &self.metadata
+    }
+
+    /// Look up a column by name and return an immutable reference to the column along with
+    /// its index.
+    pub fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> {
+        self.fields
+            .iter()
+            .enumerate()
+            .find(|&(_, c)| c.name() == name)
+    }
+
+    /// Check to see if `self` is a superset of `other` schema. Here are the comparison rules:
+    ///
+    /// * `self` and `other` should contain the same number of fields
+    /// * for every field `f` in `other`, the field in `self` with corresponding index should be a
+    ///   superset of `f`.
+    /// * self.metadata is a superset of other.metadata
+    ///
+    /// In other words, any record that conforms to `other` should also conform to `self`.
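+    ///
+    /// An illustrative sketch, relying on the rule that a nullable field is a
+    /// superset of a non-nullable field with the same name and type:
+    ///
+    /// ```
+    /// # use arrow_schema::*;
+    /// let nullable = Schema::new(vec![Field::new("c1", DataType::Int64, true)]);
+    /// let non_nullable = Schema::new(vec![Field::new("c1", DataType::Int64, false)]);
+    /// // a record that satisfies the non-nullable schema also satisfies the nullable one
+    /// assert!(nullable.contains(&non_nullable));
+    /// assert!(!non_nullable.contains(&nullable));
+    /// ```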
+    pub fn contains(&self, other: &Schema) -> bool {
+        self.fields.len() == other.fields.len()
+            && self.fields.iter().zip(other.fields.iter()).all(|(f1, f2)| f1.contains(f2))
+            // make sure self.metadata is a superset of other.metadata
+            && other.metadata.iter().all(|(k, v1)| match self.metadata.get(k) {
+                Some(v2) => v1 == v2,
+                _ => false,
+            })
+    }
+}
+
+impl fmt::Display for Schema {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str(
+            &self
+                .fields
+                .iter()
+                .map(|c| c.to_string())
+                .collect::<Vec<String>>()
+                .join(", "),
+        )
+    }
+}
+
+// need to implement `Hash` manually because `HashMap` implements `Eq` but not `Hash`
+#[allow(clippy::derive_hash_xor_eq)]
+impl Hash for Schema {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.fields.hash(state);
+
+        // ensure deterministic key order
+        let mut keys: Vec<&String> = self.metadata.keys().collect();
+        keys.sort();
+        for k in keys {
+            k.hash(state);
+            self.metadata.get(k).expect("key valid").hash(state);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::datatype::DataType;
+    use crate::{TimeUnit, UnionMode};
+    use std::collections::BTreeMap;
+
+    #[test]
+    #[cfg(feature = "serde")]
+    fn test_ser_de_metadata() {
+        // ser/de with empty metadata
+        let schema = Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("address", DataType::Utf8, false),
+            Field::new("priority", DataType::UInt8, false),
+        ]);
+
+        let json = serde_json::to_string(&schema).unwrap();
+        let de_schema = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(schema, de_schema);
+
+        // ser/de with non-empty metadata
+        let schema = schema
+            .with_metadata([("key".to_owned(), "val".to_owned())].into_iter().collect());
+        let json = serde_json::to_string(&schema).unwrap();
+        let de_schema = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(schema, de_schema);
+    }
+
+    #[test]
+    fn test_projection() {
+        let mut metadata = HashMap::new();
+        metadata.insert("meta".to_string(), "data".to_string());
+
+        let schema = Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("address", DataType::Utf8, false),
+            Field::new("priority", DataType::UInt8, false),
+        ])
+        .with_metadata(metadata);
+
+        let projected: Schema = schema.project(&[0, 2]).unwrap();
+
+        assert_eq!(projected.fields().len(), 2);
+        assert_eq!(projected.fields()[0].name(), "name");
+        assert_eq!(projected.fields()[1].name(), "priority");
+        assert_eq!(projected.metadata.get("meta").unwrap(), "data")
+    }
+
+    #[test]
+    fn test_oob_projection() {
+        let mut metadata = HashMap::new();
+        metadata.insert("meta".to_string(), "data".to_string());
+
+        let schema = Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("address", DataType::Utf8, false),
+            Field::new("priority", DataType::UInt8, false),
+        ])
+        .with_metadata(metadata);
+
+        let projected = schema.project(&[0, 3]);
+
+        assert!(projected.is_err());
+        if let Err(e) = projected {
+            assert_eq!(
+                e.to_string(),
+                "Schema error: project index 3 out of bounds, max field 3".to_string()
+            )
+        }
+    }
+
+    #[test]
+    fn test_schema_contains() {
+        let mut metadata1 = HashMap::new();
+        metadata1.insert("meta".to_string(), "data".to_string());
+
+        let schema1 = Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("address", DataType::Utf8, false),
+            Field::new("priority", DataType::UInt8, false),
+        ])
+        .with_metadata(metadata1.clone());
+
+        let mut metadata2 = HashMap::new();
+        metadata2.insert("meta".to_string(), "data".to_string());
+        metadata2.insert("meta2".to_string(), "data".to_string());
+        let schema2 = Schema::new(vec![
+            Field::new("name", DataType::Utf8, false),
+            Field::new("address", DataType::Utf8, false),
+            Field::new("priority", DataType::UInt8, false),
+        ])
+        .with_metadata(metadata2);
+
+        // reflexivity
+        assert!(schema1.contains(&schema1));
+        assert!(schema2.contains(&schema2));
+
+        assert!(!schema1.contains(&schema2));
+        assert!(schema2.contains(&schema1));
+    }
+
+    #[test]
+    fn schema_equality() {
+        let schema1 = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::Float64, true),
+            Field::new("c3", DataType::LargeBinary, true),
+        ]);
+        let schema2 = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::Float64, true),
+            Field::new("c3", DataType::LargeBinary, true),
+        ]);
+
+        assert_eq!(schema1, schema2);
+
+        let schema3 = Schema::new(vec![
+            Field::new("c1", DataType::Utf8, false),
+            Field::new("c2", DataType::Float32, true),
+        ]);
+        let schema4 = Schema::new(vec![
+            Field::new("C1", DataType::Utf8, false),
+            Field::new("C2", DataType::Float64, true),
+        ]);
+
+        assert_ne!(schema1, schema3);
+        assert_ne!(schema1, schema4);
+        assert_ne!(schema2, schema3);
+        assert_ne!(schema2, schema4);
+        assert_ne!(schema3, schema4);
+
+        let f = Field::new("c1", DataType::Utf8, false).with_metadata(Some(
+            [("foo".to_string(), "bar".to_string())]
+                .iter()
+                .cloned()
+                .collect(),
+        ));
+        let schema5 = Schema::new(vec![
+            f,
+            Field::new("c2", DataType::Float64, true),
+            Field::new("c3", DataType::LargeBinary, true),
+        ]);
+        assert_ne!(schema1, schema5);
+    }
+
+    #[test]
+    fn create_schema_string() {
+        let schema = person_schema();
+        assert_eq!(schema.to_string(),
+        "Field { name: \"first_name\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: Some({\"k\": \"v\"}) }, \
+        Field { name: \"last_name\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
+        Field { name: \"address\", data_type: Struct([\
+        Field { name: \"street\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
+        Field { name: \"zip\", data_type: UInt16, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }\
+        ]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \
+        Field { name: \"interests\", data_type: Dictionary(Int32, Utf8), nullable: true, dict_id: 123, dict_is_ordered: true, metadata: None }")
+    }
+
+    #[test]
+    fn schema_field_accessors() {
+        let schema = person_schema();
+
+        // test schema accessors
+        assert_eq!(schema.fields().len(), 4);
+
+        // test field accessors
+        let first_name = &schema.fields()[0];
+        assert_eq!(first_name.name(), "first_name");
+        assert_eq!(first_name.data_type(), &DataType::Utf8);
+        assert!(!first_name.is_nullable());
+        assert_eq!(first_name.dict_id(), None);
+        assert_eq!(first_name.dict_is_ordered(), None);
+
+        let metadata = first_name.metadata();
+        assert!(metadata.is_some());
+        let md = metadata.as_ref().unwrap();
+        assert_eq!(md.len(), 1);
+        let key = md.get("k");
+        assert!(key.is_some());
+        assert_eq!(key.unwrap(), "v");
+
+        let interests = &schema.fields()[3];
+        assert_eq!(interests.name(), "interests");
+        assert_eq!(
+            interests.data_type(),
+            &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8))
+        );
+        assert_eq!(interests.dict_id(), Some(123));
+        assert_eq!(interests.dict_is_ordered(), Some(true));
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
+    )]
+    fn schema_index_of() {
+        let schema = person_schema();
+        assert_eq!(schema.index_of("first_name").unwrap(), 0);
+        assert_eq!(schema.index_of("last_name").unwrap(), 1);
+        schema.index_of("nickname").unwrap();
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
+    )]
+    fn schema_field_with_name() {
+        let schema = person_schema();
+        assert_eq!(
+            schema.field_with_name("first_name").unwrap().name(),
+            "first_name"
+        );
+        assert_eq!(
+            schema.field_with_name("last_name").unwrap().name(),
+            "last_name"
+        );
+        schema.field_with_name("nickname").unwrap();
+    }
+
+    #[test]
+    fn schema_field_with_dict_id() {
+        let schema = person_schema();
+
+        let fields_dict_123: Vec<_> = schema
+            .fields_with_dict_id(123)
+            .iter()
+            .map(|f| f.name())
+            .collect();
+        assert_eq!(fields_dict_123, vec!["interests"]);
+
+        assert!(schema.fields_with_dict_id(456).is_empty());
+    }
+
+    fn person_schema() -> Schema {
+        let kv_array = [("k".to_string(), "v".to_string())];
+        let field_metadata: BTreeMap<String, String> = kv_array.iter().cloned().collect();
+        let first_name = Field::new("first_name", DataType::Utf8, false)
+            .with_metadata(Some(field_metadata));
+
+        Schema::new(vec![
+            first_name,
+            Field::new("last_name", DataType::Utf8, false),
+            Field::new(
+                "address",
+                DataType::Struct(vec![
+                    Field::new("street", DataType::Utf8, false),
+                    Field::new("zip", DataType::UInt16, false),
+                ]),
+                false,
+            ),
+            Field::new_dict(
+                "interests",
+                DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
+                true,
+                123,
+                true,
+            ),
+        ])
+    }
+
+    #[test]
+    fn test_try_merge_field_with_metadata() {
+        // 1. Different values for the same key should cause error.
+        let metadata1: BTreeMap<String, String> =
+            [("foo".to_string(), "bar".to_string())]
+                .iter()
+                .cloned()
+                .collect();
+        let f1 = Field::new("first_name", DataType::Utf8, false)
+            .with_metadata(Some(metadata1));
+
+        let metadata2: BTreeMap<String, String> =
+            [("foo".to_string(), "baz".to_string())]
+                .iter()
+                .cloned()
+                .collect();
+        let f2 = Field::new("first_name", DataType::Utf8, false)
+            .with_metadata(Some(metadata2));
+
+        assert!(
+            Schema::try_merge(vec![Schema::new(vec![f1]), Schema::new(vec![f2])])
+                .is_err()
+        );
+
+        // 2. None + Some
+        let mut f1 = Field::new("first_name", DataType::Utf8, false);
+        let metadata2: BTreeMap<String, String> =
+            [("missing".to_string(), "value".to_string())]
+                .iter()
+                .cloned()
+                .collect();
+        let f2 = Field::new("first_name", DataType::Utf8, false)
+            .with_metadata(Some(metadata2));
+
+        assert!(f1.try_merge(&f2).is_ok());
+        assert!(f1.metadata().is_some());
+        assert_eq!(
+            f1.metadata().as_ref().unwrap(),
+            f2.metadata().as_ref().unwrap()
+        );
+
+        // 3. Some + Some
+        let mut f1 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some(
+            [("foo".to_string(), "bar".to_string())]
+                .iter()
+                .cloned()
+                .collect(),
+        ));
+        let f2 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some(
+            [("foo2".to_string(), "bar2".to_string())]
+                .iter()
+                .cloned()
+                .collect(),
+        ));
+
+        assert!(f1.try_merge(&f2).is_ok());
+        assert!(f1.metadata().is_some());
+        assert_eq!(
+            f1.metadata().cloned().unwrap(),
+            [
+                ("foo".to_string(), "bar".to_string()),
+                ("foo2".to_string(), "bar2".to_string())
+            ]
+            .iter()
+            .cloned()
+            .collect()
+        );
+
+        // 4. Some + None.
+        let mut f1 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some(
+            [("foo".to_string(), "bar".to_string())]
+                .iter()
+                .cloned()
+                .collect(),
+        ));
+        let f2 = Field::new("first_name", DataType::Utf8, false);
+        assert!(f1.try_merge(&f2).is_ok());
+        assert!(f1.metadata().is_some());
+        assert_eq!(
+            f1.metadata().cloned().unwrap(),
+            [("foo".to_string(), "bar".to_string())]
+                .iter()
+                .cloned()
+                .collect()
+        );
+
+        // 5. None + None.
+        let mut f1 = Field::new("first_name", DataType::Utf8, false);
+        let f2 = Field::new("first_name", DataType::Utf8, false);
+        assert!(f1.try_merge(&f2).is_ok());
+        assert!(f1.metadata().is_none());
+    }
+
+    #[test]
+    fn test_schema_merge() {
+        let merged = Schema::try_merge(vec![
+            Schema::new(vec![
+                Field::new("first_name", DataType::Utf8, false),
+                Field::new("last_name", DataType::Utf8, false),
+                Field::new(
+                    "address",
+                    DataType::Struct(vec![Field::new("zip", DataType::UInt16, false)]),
+                    false,
+                ),
+            ]),
+            Schema::new_with_metadata(
+                vec![
+                    // nullable merge
+                    Field::new("last_name", DataType::Utf8, true),
+                    Field::new(
+                        "address",
+                        DataType::Struct(vec![
+                            // add new nested field
+                            Field::new("street", DataType::Utf8, false),
+                            // nullable merge on nested field
+                            Field::new("zip", DataType::UInt16, true),
+                        ]),
+                        false,
+                    ),
+                    // new field
+                    Field::new("number", DataType::Utf8, true),
+                ],
+                [("foo".to_string(), "bar".to_string())]
+                    .iter()
+                    .cloned()
+                    .collect::<HashMap<String, String>>(),
+            ),
+        ])
+        .unwrap();
+
+        assert_eq!(
+            merged,
+            Schema::new_with_metadata(
+                vec![
+                    Field::new("first_name", DataType::Utf8, false),
+                    Field::new("last_name", DataType::Utf8, true),
+                    Field::new(
+                        "address",
+                        DataType::Struct(vec![
+                            Field::new("zip", DataType::UInt16, true),
+                            Field::new("street", DataType::Utf8, false),
+                        ]),
+                        false,
+                    ),
+                    Field::new("number", DataType::Utf8, true),
+                ],
+                [("foo".to_string(), "bar".to_string())]
+                    .iter()
+                    .cloned()
+                    .collect::<HashMap<String, String>>()
+            )
+        );
+
+        // support merge union fields
+        assert_eq!(
+            Schema::try_merge(vec![
+                Schema::new(vec![Field::new(
+                    "c1",
+                    DataType::Union(
+                        vec![
+                            Field::new("c11", DataType::Utf8, true),
+                            Field::new("c12", DataType::Utf8, true),
+                        ],
+                        vec![0, 1],
+                        UnionMode::Dense
+                    ),
+                    false
+                ),]),
+                Schema::new(vec![Field::new(
+                    "c1",
+                    DataType::Union(
+                        vec![
+                            Field::new("c12", DataType::Utf8, true),
+                            Field::new("c13", DataType::Time64(TimeUnit::Second), true),
+                        ],
+                        vec![1, 2],
+                        UnionMode::Dense
+                    ),
+                    false
+                ),])
+            ])
+            .unwrap(),
+            Schema::new(vec![Field::new(
+                "c1",
+                DataType::Union(
+                    vec![
+                        Field::new("c11", DataType::Utf8, true),
+                        Field::new("c12", DataType::Utf8, true),
+                        Field::new("c13", DataType::Time64(TimeUnit::Second), true),
+                    ],
+                    vec![0, 1, 2],
+                    UnionMode::Dense
+                ),
+                false
+            ),]),
+        );
+
+        // incompatible field should throw error
+        assert!(Schema::try_merge(vec![
+            Schema::new(vec![
+                Field::new("first_name", DataType::Utf8, false),
+                Field::new("last_name", DataType::Utf8, false),
+            ]),
+            Schema::new(vec![Field::new("last_name", DataType::Int64, false),])
+        ])
+        .is_err());
+
+        // incompatible metadata should throw error
+        let res = Schema::try_merge(vec![
+            Schema::new_with_metadata(
+                vec![Field::new("first_name", DataType::Utf8, false)],
+                [("foo".to_string(), "bar".to_string())]
+                    .iter()
+                    .cloned()
+                    .collect::<HashMap<String, String>>(),
+            ),
+            Schema::new_with_metadata(
+                vec![Field::new("last_name", DataType::Utf8, false)],
+                [("foo".to_string(), "baz".to_string())]
+                    .iter()
+                    .cloned()
+                    .collect::<HashMap<String, String>>(),
+            ),
+        ])
+        .unwrap_err();
+
+        let expected = "Fail to merge schema due to \
+            conflicting metadata. Key 'foo' has different values 'bar' and 'baz'";
+        assert!(
+            res.to_string().contains(expected),
+            "Could not find expected string '{}' in '{}'",
+            expected,
+            res
+        );
+    }
+}
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index 7391ffcf827..ebe2daca9c1 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -44,9 +44,8 @@ ahash = { version = "0.8", default-features = false, features = ["compile-time-rng"] }
 ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
 
 [dependencies]
-arrow-buffer = { path = "../arrow-buffer", version = "23.0.0" }
-
-serde = { version = "1.0", default-features = false, features = ["derive"], optional = true }
+arrow-buffer = { version = "23.0.0", path = "../arrow-buffer" }
+arrow-schema = { version = "23.0.0", path = "../arrow-schema" }
 serde_json = { version = "1.0", default-features = false, features = ["std"], optional = true }
 indexmap = { version = "1.9", default-features = false, features = ["std"] }
 rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true }
@@ -77,7 +76,7 @@ default = ["csv", "ipc", "json"]
 ipc_compression = ["ipc", "zstd", "lz4"]
 csv = ["csv_crate"]
 ipc = ["flatbuffers"]
-json = ["serde", "serde_json"]
+json = ["serde_json"]
 simd = ["packed_simd"]
 prettyprint = ["comfy-table"]
 # The test utils feature enables code used in benchmarks and tests but
diff --git a/arrow/src/csv/mod.rs b/arrow/src/csv/mod.rs
index ffe82f33580..46ba7d71e20 100644
--- a/arrow/src/csv/mod.rs
+++ b/arrow/src/csv/mod.rs
@@ -25,3 +25,22 @@ pub use self::reader::Reader;
 pub use self::reader::ReaderBuilder;
 pub use self::writer::Writer;
 pub use self::writer::WriterBuilder;
+use arrow_schema::ArrowError;
+
+fn map_csv_error(error: csv_crate::Error) -> ArrowError {
+    match error.kind() {
+        csv_crate::ErrorKind::Io(error) => ArrowError::CsvError(error.to_string()),
+        csv_crate::ErrorKind::Utf8 { pos: _, err } => ArrowError::CsvError(format!(
+            "Encountered UTF-8 error while reading CSV file: {}",
+            err
+        )),
+        csv_crate::ErrorKind::UnequalLengths {
+            expected_len, len, ..
+        } => ArrowError::CsvError(format!(
+            "Encountered unequal lengths between records on CSV file. Expected {} \
+             records, found {} records",
+            expected_len, len
+        )),
+        _ => ArrowError::CsvError("Error reading CSV file".to_string()),
+    }
+}
diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs
index d164d35c3c8..3ec605dd048 100644
--- a/arrow/src/csv/reader.rs
+++ b/arrow/src/csv/reader.rs
@@ -58,6 +58,7 @@ use crate::error::{ArrowError, Result};
 use crate::record_batch::{RecordBatch, RecordBatchOptions};
 use crate::util::reader_parser::Parser;
 
+use crate::csv::map_csv_error;
 use csv_crate::{ByteRecord, StringRecord};
 use std::ops::Neg;
 
@@ -187,10 +188,10 @@ fn infer_reader_schema_with_csv_options(
     // get or create header names
     // when has_header is false, creates default column names with column_ prefix
     let headers: Vec<String> = if roptions.has_header {
-        let headers = &csv_reader.headers()?.clone();
+        let headers = &csv_reader.headers().map_err(map_csv_error)?.clone();
         headers.iter().map(|s| s.to_string()).collect()
     } else {
-        let first_record_count = &csv_reader.headers()?.len();
+        let first_record_count = &csv_reader.headers().map_err(map_csv_error)?.len();
         (0..*first_record_count)
             .map(|i| format!("column_{}", i + 1))
             .collect()
@@ -208,7 +209,7 @@ fn infer_reader_schema_with_csv_options(
     let mut record = StringRecord::new();
     let max_records = roptions.max_read_records.unwrap_or(usize::MAX);
     while records_count < max_records {
-        if !csv_reader.read_record(&mut record)? {
+        if !csv_reader.read_record(&mut record).map_err(map_csv_error)? {
             break;
         }
         records_count += 1;
diff --git a/arrow/src/csv/writer.rs b/arrow/src/csv/writer.rs
index 7097706ba5f..1b377c38b37 100644
--- a/arrow/src/csv/writer.rs
+++ b/arrow/src/csv/writer.rs
@@ -70,11 +70,13 @@ use crate::compute::kernels::temporal::using_chrono_tz_and_utc_naive_date_time;
 #[cfg(feature = "chrono-tz")]
 use chrono::{DateTime, Utc};
 
+use crate::csv::map_csv_error;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 use crate::record_batch::RecordBatch;
 use crate::util::display::make_string_from_decimal;
 use crate::{array::*, util::serialization::lexical_to_string};
+
 const DEFAULT_DATE_FORMAT: &str = "%F";
 const DEFAULT_TIME_FORMAT: &str = "%T";
 const DEFAULT_TIMESTAMP_FORMAT: &str = "%FT%H:%M:%S.%9f";
@@ -343,7 +345,9 @@ impl<W: Write> Writer<W> {
                 .fields()
                 .iter()
                 .for_each(|field| headers.push(field.name().to_string()));
-            self.writer.write_record(&headers[..])?;
+            self.writer
+                .write_record(&headers[..])
+                .map_err(map_csv_error)?;
         }
         self.beginning = false;
     }
@@ -364,7 +368,7 @@
         for row_index in 0..batch.num_rows() {
             self.convert(columns.as_slice(), row_index, &mut buffer)?;
-            self.writer.write_record(&buffer)?;
+            self.writer.write_record(&buffer).map_err(map_csv_error)?;
         }
 
         self.writer.flush()?;
diff --git a/arrow/src/datatypes/datatype.rs b/arrow/src/datatypes/decimal.rs
similarity index 68%
rename from arrow/src/datatypes/datatype.rs
rename to arrow/src/datatypes/decimal.rs
index d3189b8b18c..ffdb04e0d77 100644
--- a/arrow/src/datatypes/datatype.rs
+++ b/arrow/src/datatypes/decimal.rs
@@ -15,256 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use num::BigInt;
-use std::cmp::Ordering;
-use std::fmt;
-
 use crate::error::{ArrowError, Result};
 use crate::util::decimal::singed_cmp_le_bytes;
-
-use super::Field;
-
-/// The set of datatypes that are supported by this implementation of Apache Arrow.
-///
-/// The Arrow specification on data types includes some more types.
-/// See also [`Schema.fbs`](https://github.com/apache/arrow/blob/master/format/Schema.fbs) -/// for Arrow's specification. -/// -/// The variants of this enum include primitive fixed size types as well as parametric or -/// nested types. -/// Currently the Rust implementation supports the following nested types: -/// - `List` -/// - `Struct` -/// -/// Nested types can themselves be nested within other arrays. -/// For more information on these types please see -/// [the physical memory layout of Apache Arrow](https://arrow.apache.org/docs/format/Columnar.html#physical-memory-layout). -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum DataType { - /// Null type - Null, - /// A boolean datatype representing the values `true` and `false`. - Boolean, - /// A signed 8-bit integer. - Int8, - /// A signed 16-bit integer. - Int16, - /// A signed 32-bit integer. - Int32, - /// A signed 64-bit integer. - Int64, - /// An unsigned 8-bit integer. - UInt8, - /// An unsigned 16-bit integer. - UInt16, - /// An unsigned 32-bit integer. - UInt32, - /// An unsigned 64-bit integer. - UInt64, - /// A 16-bit floating point number. - Float16, - /// A 32-bit floating point number. - Float32, - /// A 64-bit floating point number. - Float64, - /// A timestamp with an optional timezone. - /// - /// Time is measured as a Unix epoch, counting the seconds from - /// 00:00:00.000 on 1 January 1970, excluding leap seconds, - /// as a 64-bit integer. - /// - /// The time zone is a string indicating the name of a time zone, one of: - /// - /// * As used in the Olson time zone database (the "tz database" or - /// "tzdata"), such as "America/New_York" - /// * An absolute time zone offset of the form +XX:XX or -XX:XX, such as +07:30 - /// - /// Timestamps with a non-empty timezone - /// ------------------------------------ - /// - /// If a Timestamp column has a non-empty timezone value, its epoch is - /// 1970-01-01 00:00:00 (January 1st 1970, midnight) in the *UTC* timezone - /// (the Unix epoch), regardless of the Timestamp's own timezone. - /// - /// Therefore, timestamp values with a non-empty timezone correspond to - /// physical points in time together with some additional information about - /// how the data was obtained and/or how to display it (the timezone). - /// - /// For example, the timestamp value 0 with the timezone string "Europe/Paris" - /// corresponds to "January 1st 1970, 00h00" in the UTC timezone, but the - /// application may prefer to display it as "January 1st 1970, 01h00" in - /// the Europe/Paris timezone (which is the same physical point in time). - /// - /// One consequence is that timestamp values with a non-empty timezone - /// can be compared and ordered directly, since they all share the same - /// well-known point of reference (the Unix epoch). - /// - /// Timestamps with an unset / empty timezone - /// ----------------------------------------- - /// - /// If a Timestamp column has no timezone value, its epoch is - /// 1970-01-01 00:00:00 (January 1st 1970, midnight) in an *unknown* timezone. - /// - /// Therefore, timestamp values without a timezone cannot be meaningfully - /// interpreted as physical points in time, but only as calendar / clock - /// indications ("wall clock time") in an unspecified timezone. 
- /// - /// For example, the timestamp value 0 with an empty timezone string - /// corresponds to "January 1st 1970, 00h00" in an unknown timezone: there - /// is not enough information to interpret it as a well-defined physical - /// point in time. - /// - /// One consequence is that timestamp values without a timezone cannot - /// be reliably compared or ordered, since they may have different points of - /// reference. In particular, it is *not* possible to interpret an unset - /// or empty timezone as the same as "UTC". - /// - /// Conversion between timezones - /// ---------------------------- - /// - /// If a Timestamp column has a non-empty timezone, changing the timezone - /// to a different non-empty value is a metadata-only operation: - /// the timestamp values need not change as their point of reference remains - /// the same (the Unix epoch). - /// - /// However, if a Timestamp column has no timezone value, changing it to a - /// non-empty value requires to think about the desired semantics. - /// One possibility is to assume that the original timestamp values are - /// relative to the epoch of the timezone being set; timestamp values should - /// then adjusted to the Unix epoch (for example, changing the timezone from - /// empty to "Europe/Paris" would require converting the timestamp values - /// from "Europe/Paris" to "UTC", which seems counter-intuitive but is - /// nevertheless correct). - Timestamp(TimeUnit, Option), - /// A 32-bit date representing the elapsed time since UNIX epoch (1970-01-01) - /// in days (32 bits). - Date32, - /// A 64-bit date representing the elapsed time since UNIX epoch (1970-01-01) - /// in milliseconds (64 bits). Values are evenly divisible by 86400000. - Date64, - /// A 32-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. - Time32(TimeUnit), - /// A 64-bit time representing the elapsed time since midnight in the unit of `TimeUnit`. - Time64(TimeUnit), - /// Measure of elapsed time in either seconds, milliseconds, microseconds or nanoseconds. - Duration(TimeUnit), - /// A "calendar" interval which models types that don't necessarily - /// have a precise duration without the context of a base timestamp (e.g. - /// days can differ in length during day light savings time transitions). - Interval(IntervalUnit), - /// Opaque binary data of variable length. - Binary, - /// Opaque binary data of fixed size. - /// Enum parameter specifies the number of bytes per value. - FixedSizeBinary(i32), - /// Opaque binary data of variable length and 64-bit offsets. - LargeBinary, - /// A variable-length string in Unicode with UTF-8 encoding. - Utf8, - /// A variable-length string in Unicode with UFT-8 encoding and 64-bit offsets. - LargeUtf8, - /// A list of some logical data type with variable length. - List(Box), - /// A list of some logical data type with fixed length. - FixedSizeList(Box, i32), - /// A list of some logical data type with variable length and 64-bit offsets. - LargeList(Box), - /// A nested datatype that contains a number of sub-fields. - Struct(Vec), - /// A nested datatype that can represent slots of differing types. Components: - /// - /// 1. [`Field`] for each possible child type the Union can hold - /// 2. The corresponding `type_id` used to identify which Field - /// 3. The type of union (Sparse or Dense) - Union(Vec, Vec, UnionMode), - /// A dictionary encoded array (`key_type`, `value_type`), where - /// each array element is an index of `key_type` into an - /// associated dictionary of `value_type`. 
- /// - /// Dictionary arrays are used to store columns of `value_type` - /// that contain many repeated values using less memory, but with - /// a higher CPU overhead for some operations. - /// - /// This type mostly used to represent low cardinality string - /// arrays or a limited set of primitive types as integers. - Dictionary(Box, Box), - /// Exact 128-bit width decimal value with precision and scale - /// - /// * precision is the total number of digits - /// * scale is the number of digits past the decimal - /// - /// For example the number 123.45 has precision 5 and scale 2. - Decimal128(u8, u8), - /// Exact 256-bit width decimal value with precision and scale - /// - /// * precision is the total number of digits - /// * scale is the number of digits past the decimal - /// - /// For example the number 123.45 has precision 5 and scale 2. - Decimal256(u8, u8), - /// A Map is a logical nested type that is represented as - /// - /// `List>` - /// - /// The keys and values are each respectively contiguous. - /// The key and value types are not constrained, but keys should be - /// hashable and unique. - /// Whether the keys are sorted can be set in the `bool` after the `Field`. - /// - /// In a field with Map type, the field has a child Struct field, which then - /// has two children: key type and the second the value type. The names of the - /// child fields may be respectively "entries", "key", and "value", but this is - /// not enforced. - Map(Box, bool), -} - -/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum TimeUnit { - /// Time in seconds. - Second, - /// Time in milliseconds. - Millisecond, - /// Time in microseconds. - Microsecond, - /// Time in nanoseconds. - Nanosecond, -} - -/// YEAR_MONTH, DAY_TIME, MONTH_DAY_NANO interval in SQL style. -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum IntervalUnit { - /// Indicates the number of elapsed whole months, stored as 4-byte integers. - YearMonth, - /// Indicates the number of elapsed days and milliseconds, - /// stored as 2 contiguous 32-bit integers (days, milliseconds) (8-bytes in total). - DayTime, - /// A triple of the number of elapsed months, days, and nanoseconds. - /// The values are stored contiguously in 16 byte blocks. Months and - /// days are encoded as 32 bit integers and nanoseconds is encoded as a - /// 64 bit integer. All integers are signed. Each field is independent - /// (e.g. there is no constraint that nanoseconds have the same sign - /// as days or that the quantity of nanoseconds represents less - /// than a day's worth of time). - MonthDayNano, -} - -// Sparse or Dense union layouts -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub enum UnionMode { - Sparse, - Dense, -} - -impl fmt::Display for DataType { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{:?}", self) - } -} +use num::BigInt; +use std::cmp::Ordering; // MAX decimal256 value of little-endian format for each precision. 
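Pausing the decimal tables for a moment: the definitions deleted above (DataType, TimeUnit, IntervalUnit, UnionMode) now live in the new arrow-schema crate and are re-exported through arrow::datatypes, as the mod.rs hunk further down shows, so downstream code keeps compiling unchanged. A small source-compatibility sketch; the field names are illustrative:

use arrow_schema::{DataType, Field, TimeUnit};

fn example() {
    // Constructors are unchanged; only the defining crate moved.
    let ts = DataType::Timestamp(TimeUnit::Millisecond, Some("+07:30".to_string()));
    let _field = Field::new("event_time", ts, true);

    // Structural comparison still ignores list field names.
    let a = DataType::List(Box::new(Field::new("item", DataType::Int32, true)));
    let b = DataType::List(Box::new(Field::new("element", DataType::Int32, true)));
    assert!(a.equals_datatype(&b));
}

The MAX/MIN decimal byte tables resume below.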
// Each element is the max value of signed 256-bit integer for the specified precision which @@ -887,7 +641,7 @@ pub(crate) const MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION: [[u8; 32]; 76] = [ ]; /// `MAX_DECIMAL_FOR_EACH_PRECISION[p]` holds the maximum `i128` value -/// that can be stored in [DataType::Decimal128] value of precision `p` +/// that can be stored in [arrow_schema::DataType::Decimal128] value of precision `p` pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ 9, 99, @@ -930,7 +684,7 @@ pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ ]; /// `MIN_DECIMAL_FOR_EACH_PRECISION[p]` holds the minimum `i128` value -/// that can be stored in a [DataType::Decimal128] value of precision `p` +/// that can be stored in a [arrow_schema::DataType::Decimal128] value of precision `p` pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -9, -99, @@ -972,19 +726,20 @@ pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -99999999999999999999999999999999999999, ]; -/// The maximum precision for [DataType::Decimal128] values +/// The maximum precision for [arrow_schema::DataType::Decimal128] values pub const DECIMAL128_MAX_PRECISION: u8 = 38; -/// The maximum scale for [DataType::Decimal128] values +/// The maximum scale for [arrow_schema::DataType::Decimal128] values pub const DECIMAL128_MAX_SCALE: u8 = 38; -/// The maximum precision for [DataType::Decimal256] values +/// The maximum precision for [arrow_schema::DataType::Decimal256] values pub const DECIMAL256_MAX_PRECISION: u8 = 76; -/// The maximum scale for [DataType::Decimal256] values +/// The maximum scale for [arrow_schema::DataType::Decimal256] values pub const DECIMAL256_MAX_SCALE: u8 = 76; -/// The default scale for [DataType::Decimal128] and [DataType::Decimal256] values +/// The default scale for [arrow_schema::DataType::Decimal128] and +/// [arrow_schema::DataType::Decimal256] values pub const DECIMAL_DEFAULT_SCALE: u8 = 10; /// Validates that the specified `i128` value can be properly @@ -1051,124 +806,9 @@ pub(crate) fn validate_decimal256_precision_with_lt_bytes( } } -impl DataType { - /// Returns true if this type is numeric: (UInt*, Int*, or Float*). - pub fn is_numeric(t: &DataType) -> bool { - use DataType::*; - matches!( - t, - UInt8 - | UInt16 - | UInt32 - | UInt64 - | Int8 - | Int16 - | Int32 - | Int64 - | Float32 - | Float64 - ) - } - - /// Returns true if the type is primitive: (numeric, temporal). - pub fn is_primitive(t: &DataType) -> bool { - use DataType::*; - matches!( - t, - Int8 | Int16 - | Int32 - | Int64 - | UInt8 - | UInt16 - | UInt32 - | UInt64 - | Float32 - | Float64 - | Date32 - | Date64 - | Time32(_) - | Time64(_) - | Timestamp(_, _) - | Interval(_) - | Duration(_) - ) - } - - /// Returns true if this type is temporal: (Date*, Time*, Duration, or Interval). - pub fn is_temporal(t: &DataType) -> bool { - use DataType::*; - matches!( - t, - Date32 - | Date64 - | Timestamp(_, _) - | Time32(_) - | Time64(_) - | Duration(_) - | Interval(_) - ) - } - - /// Returns true if this type is valid as a dictionary key - /// (e.g. 
[`super::ArrowDictionaryKeyType`] - pub fn is_dictionary_key_type(t: &DataType) -> bool { - use DataType::*; - matches!( - t, - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 - ) - } - - /// Returns true if this type is nested (List, FixedSizeList, LargeList, Struct, Union, or Map) - pub fn is_nested(t: &DataType) -> bool { - use DataType::*; - matches!( - t, - List(_) - | FixedSizeList(_, _) - | LargeList(_) - | Struct(_) - | Union(_, _, _) - | Map(_, _) - ) - } - - /// Compares the datatype with another, ignoring nested field names - /// and metadata. - pub fn equals_datatype(&self, other: &DataType) -> bool { - match (&self, other) { - (DataType::List(a), DataType::List(b)) - | (DataType::LargeList(a), DataType::LargeList(b)) => { - a.is_nullable() == b.is_nullable() - && a.data_type().equals_datatype(b.data_type()) - } - (DataType::FixedSizeList(a, a_size), DataType::FixedSizeList(b, b_size)) => { - a_size == b_size - && a.is_nullable() == b.is_nullable() - && a.data_type().equals_datatype(b.data_type()) - } - (DataType::Struct(a), DataType::Struct(b)) => { - a.len() == b.len() - && a.iter().zip(b).all(|(a, b)| { - a.is_nullable() == b.is_nullable() - && a.data_type().equals_datatype(b.data_type()) - }) - } - ( - DataType::Map(a_field, a_is_sorted), - DataType::Map(b_field, b_is_sorted), - ) => a_field == b_field && a_is_sorted == b_is_sorted, - _ => self == other, - } - } -} - #[cfg(test)] mod test { - use crate::datatypes::datatype::{ - MAX_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION, - MIN_DECIMAL_BYTES_FOR_LARGER_EACH_PRECISION, - }; + use super::*; use crate::util::decimal::Decimal256; use num::{BigInt, Num}; diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 1586d563cd3..2f83871127f 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -26,17 +26,15 @@ use std::sync::Arc; mod native; pub use native::*; -mod field; -pub use field::*; -mod schema; -pub use schema::*; mod numeric; pub use numeric::*; mod types; pub use types::*; -mod datatype; -pub use datatype::*; +mod decimal; mod delta; +pub use decimal::*; + +pub use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode}; #[cfg(feature = "ffi")] mod ffi; @@ -45,550 +43,3 @@ pub use ffi::*; /// A reference-counted reference to a [`Schema`](crate::datatypes::Schema). 
pub type SchemaRef = Arc; - -#[cfg(test)] -mod tests { - use super::*; - use crate::error::Result; - use std::collections::{BTreeMap, HashMap}; - - #[cfg(feature = "json")] - use crate::json::JsonSerializable; - - #[cfg(feature = "json")] - use serde_json::{ - Number, - Value::{Bool, Number as VNumber, String as VString}, - }; - - #[test] - fn test_list_datatype_equality() { - // tests that list type equality is checked while ignoring list names - let list_a = DataType::List(Box::new(Field::new("item", DataType::Int32, true))); - let list_b = DataType::List(Box::new(Field::new("array", DataType::Int32, true))); - let list_c = DataType::List(Box::new(Field::new("item", DataType::Int32, false))); - let list_d = DataType::List(Box::new(Field::new("item", DataType::UInt32, true))); - assert!(list_a.equals_datatype(&list_b)); - assert!(!list_a.equals_datatype(&list_c)); - assert!(!list_b.equals_datatype(&list_c)); - assert!(!list_a.equals_datatype(&list_d)); - - let list_e = - DataType::FixedSizeList(Box::new(Field::new("item", list_a, false)), 3); - let list_f = - DataType::FixedSizeList(Box::new(Field::new("array", list_b, false)), 3); - let list_g = DataType::FixedSizeList( - Box::new(Field::new("item", DataType::FixedSizeBinary(3), true)), - 3, - ); - assert!(list_e.equals_datatype(&list_f)); - assert!(!list_e.equals_datatype(&list_g)); - assert!(!list_f.equals_datatype(&list_g)); - - let list_h = DataType::Struct(vec![Field::new("f1", list_e, true)]); - let list_i = DataType::Struct(vec![Field::new("f1", list_f.clone(), true)]); - let list_j = DataType::Struct(vec![Field::new("f1", list_f.clone(), false)]); - let list_k = DataType::Struct(vec![ - Field::new("f1", list_f.clone(), false), - Field::new("f2", list_g.clone(), false), - Field::new("f3", DataType::Utf8, true), - ]); - let list_l = DataType::Struct(vec![ - Field::new("ff1", list_f.clone(), false), - Field::new("ff2", list_g.clone(), false), - Field::new("ff3", DataType::LargeUtf8, true), - ]); - let list_m = DataType::Struct(vec![ - Field::new("ff1", list_f, false), - Field::new("ff2", list_g, false), - Field::new("ff3", DataType::Utf8, true), - ]); - assert!(list_h.equals_datatype(&list_i)); - assert!(!list_h.equals_datatype(&list_j)); - assert!(!list_k.equals_datatype(&list_l)); - assert!(list_k.equals_datatype(&list_m)); - } - - #[test] - #[cfg(feature = "json")] - fn create_struct_type() { - let _person = DataType::Struct(vec![ - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new( - "address", - DataType::Struct(vec![ - Field::new("street", DataType::Utf8, false), - Field::new("zip", DataType::UInt16, false), - ]), - false, - ), - ]); - } - - #[test] - #[cfg(feature = "json")] - fn serde_struct_type() { - let kv_array = [("k".to_string(), "v".to_string())]; - let field_metadata: BTreeMap = kv_array.iter().cloned().collect(); - - // Non-empty map: should be converted as JSON obj { ... } - let first_name = Field::new("first_name", DataType::Utf8, false) - .with_metadata(Some(field_metadata)); - - // Empty map: should be omitted. 
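A note on the tests being deleted here: their coverage presumably moves alongside the relocated types, and serde support itself now sits behind arrow-schema's optional serde feature, which is why the arrow json feature above no longer enables serde. A minimal round-trip sketch, assuming arrow-schema is compiled with the serde feature and serde_json is available:

use arrow_schema::{DataType, Field};

fn serde_roundtrip() -> serde_json::Result<()> {
    let field = Field::new("first_name", DataType::Utf8, false);
    let json = serde_json::to_string(&field)?; // derived Serialize
    let back: Field = serde_json::from_str(&json)?; // derived Deserialize
    assert_eq!(field, back);
    Ok(())
}

The deleted serde_struct_type test continues below.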
- let last_name = Field::new("last_name", DataType::Utf8, false) - .with_metadata(Some(BTreeMap::default())); - - let person = DataType::Struct(vec![ - first_name, - last_name, - Field::new( - "address", - DataType::Struct(vec![ - Field::new("street", DataType::Utf8, false), - Field::new("zip", DataType::UInt16, false), - ]), - false, - ), - ]); - - let serialized = serde_json::to_string(&person).unwrap(); - - // NOTE that this is testing the default (derived) serialization format, not the - // JSON format specified in metadata.md - - assert_eq!( - "{\"Struct\":[\ - {\"name\":\"first_name\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false,\"metadata\":{\"k\":\"v\"}},\ - {\"name\":\"last_name\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false},\ - {\"name\":\"address\",\"data_type\":{\"Struct\":\ - [{\"name\":\"street\",\"data_type\":\"Utf8\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false},\ - {\"name\":\"zip\",\"data_type\":\"UInt16\",\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false}\ - ]},\"nullable\":false,\"dict_id\":0,\"dict_is_ordered\":false}]}", - serialized - ); - - let deserialized = serde_json::from_str(&serialized).unwrap(); - - assert_eq!(person, deserialized); - } - - #[test] - fn create_schema_string() { - let schema = person_schema(); - assert_eq!(schema.to_string(), - "Field { name: \"first_name\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: Some({\"k\": \"v\"}) }, \ - Field { name: \"last_name\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \ - Field { name: \"address\", data_type: Struct([\ - Field { name: \"street\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \ - Field { name: \"zip\", data_type: UInt16, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }\ - ]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: None }, \ - Field { name: \"interests\", data_type: Dictionary(Int32, Utf8), nullable: true, dict_id: 123, dict_is_ordered: true, metadata: None }") - } - - #[test] - fn schema_field_accessors() { - let schema = person_schema(); - - // test schema accessors - assert_eq!(schema.fields().len(), 4); - - // test field accessors - let first_name = &schema.fields()[0]; - assert_eq!(first_name.name(), "first_name"); - assert_eq!(first_name.data_type(), &DataType::Utf8); - assert!(!first_name.is_nullable()); - assert_eq!(first_name.dict_id(), None); - assert_eq!(first_name.dict_is_ordered(), None); - - let metadata = first_name.metadata(); - assert!(metadata.is_some()); - let md = metadata.as_ref().unwrap(); - assert_eq!(md.len(), 1); - let key = md.get("k"); - assert!(key.is_some()); - assert_eq!(key.unwrap(), "v"); - - let interests = &schema.fields()[3]; - assert_eq!(interests.name(), "interests"); - assert_eq!( - interests.data_type(), - &DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) - ); - assert_eq!(interests.dict_id(), Some(123)); - assert_eq!(interests.dict_is_ordered(), Some(true)); - } - - #[test] - #[should_panic( - expected = "Unable to get field named \\\"nickname\\\". 
Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]" - )] - fn schema_index_of() { - let schema = person_schema(); - assert_eq!(schema.index_of("first_name").unwrap(), 0); - assert_eq!(schema.index_of("last_name").unwrap(), 1); - schema.index_of("nickname").unwrap(); - } - - #[test] - #[should_panic( - expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]" - )] - fn schema_field_with_name() { - let schema = person_schema(); - assert_eq!( - schema.field_with_name("first_name").unwrap().name(), - "first_name" - ); - assert_eq!( - schema.field_with_name("last_name").unwrap().name(), - "last_name" - ); - schema.field_with_name("nickname").unwrap(); - } - - #[test] - fn schema_field_with_dict_id() { - let schema = person_schema(); - - let fields_dict_123: Vec<_> = schema - .fields_with_dict_id(123) - .iter() - .map(|f| f.name()) - .collect(); - assert_eq!(fields_dict_123, vec!["interests"]); - - assert!(schema.fields_with_dict_id(456).is_empty()); - } - - #[test] - fn schema_equality() { - let schema1 = Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::LargeBinary, true), - ]); - let schema2 = Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::LargeBinary, true), - ]); - - assert_eq!(schema1, schema2); - - let schema3 = Schema::new(vec![ - Field::new("c1", DataType::Utf8, false), - Field::new("c2", DataType::Float32, true), - ]); - let schema4 = Schema::new(vec![ - Field::new("C1", DataType::Utf8, false), - Field::new("C2", DataType::Float64, true), - ]); - - assert!(schema1 != schema3); - assert!(schema1 != schema4); - assert!(schema2 != schema3); - assert!(schema2 != schema4); - assert!(schema3 != schema4); - - let f = Field::new("c1", DataType::Utf8, false).with_metadata(Some( - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect(), - )); - let schema5 = Schema::new(vec![ - f, - Field::new("c2", DataType::Float64, true), - Field::new("c3", DataType::LargeBinary, true), - ]); - assert!(schema1 != schema5); - } - - #[test] - #[cfg(feature = "json")] - fn test_arrow_native_type_to_json() { - assert_eq!(Some(Bool(true)), true.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1i8.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1i16.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1i32.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1i64.into_json_value()); - assert_eq!(Some(VString("1".to_string())), 1i128.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1u8.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1u16.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1u32.into_json_value()); - assert_eq!(Some(VNumber(Number::from(1))), 1u64.into_json_value()); - assert_eq!( - Some(VNumber(Number::from_f64(0.01f64).unwrap())), - 0.01.into_json_value() - ); - assert_eq!( - Some(VNumber(Number::from_f64(0.01f64).unwrap())), - 0.01f64.into_json_value() - ); - assert_eq!(None, f32::NAN.into_json_value()); - } - - fn person_schema() -> Schema { - let kv_array = [("k".to_string(), "v".to_string())]; - let field_metadata: BTreeMap = kv_array.iter().cloned().collect(); - let first_name = Field::new("first_name", DataType::Utf8, false) - .with_metadata(Some(field_metadata)); - - 
Schema::new(vec![ - first_name, - Field::new("last_name", DataType::Utf8, false), - Field::new( - "address", - DataType::Struct(vec![ - Field::new("street", DataType::Utf8, false), - Field::new("zip", DataType::UInt16, false), - ]), - false, - ), - Field::new_dict( - "interests", - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - true, - 123, - true, - ), - ]) - } - - #[test] - fn test_try_merge_field_with_metadata() { - // 1. Different values for the same key should cause error. - let metadata1: BTreeMap = - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect(); - let f1 = Field::new("first_name", DataType::Utf8, false) - .with_metadata(Some(metadata1)); - - let metadata2: BTreeMap = - [("foo".to_string(), "baz".to_string())] - .iter() - .cloned() - .collect(); - let f2 = Field::new("first_name", DataType::Utf8, false) - .with_metadata(Some(metadata2)); - - assert!( - Schema::try_merge(vec![Schema::new(vec![f1]), Schema::new(vec![f2])]) - .is_err() - ); - - // 2. None + Some - let mut f1 = Field::new("first_name", DataType::Utf8, false); - let metadata2: BTreeMap = - [("missing".to_string(), "value".to_string())] - .iter() - .cloned() - .collect(); - let f2 = Field::new("first_name", DataType::Utf8, false) - .with_metadata(Some(metadata2)); - - assert!(f1.try_merge(&f2).is_ok()); - assert!(f1.metadata().is_some()); - assert_eq!( - f1.metadata().as_ref().unwrap(), - f2.metadata().as_ref().unwrap() - ); - - // 3. Some + Some - let mut f1 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some( - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect(), - )); - let f2 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some( - [("foo2".to_string(), "bar2".to_string())] - .iter() - .cloned() - .collect(), - )); - - assert!(f1.try_merge(&f2).is_ok()); - assert!(f1.metadata().is_some()); - assert_eq!( - f1.metadata().cloned().unwrap(), - [ - ("foo".to_string(), "bar".to_string()), - ("foo2".to_string(), "bar2".to_string()) - ] - .iter() - .cloned() - .collect() - ); - - // 4. Some + None. - let mut f1 = Field::new("first_name", DataType::Utf8, false).with_metadata(Some( - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect(), - )); - let f2 = Field::new("first_name", DataType::Utf8, false); - assert!(f1.try_merge(&f2).is_ok()); - assert!(f1.metadata().is_some()); - assert_eq!( - f1.metadata().cloned().unwrap(), - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect() - ); - - // 5. None + None. 
- let mut f1 = Field::new("first_name", DataType::Utf8, false); - let f2 = Field::new("first_name", DataType::Utf8, false); - assert!(f1.try_merge(&f2).is_ok()); - assert!(f1.metadata().is_none()); - } - - #[test] - fn test_schema_merge() -> Result<()> { - let merged = Schema::try_merge(vec![ - Schema::new(vec![ - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - Field::new( - "address", - DataType::Struct(vec![Field::new("zip", DataType::UInt16, false)]), - false, - ), - ]), - Schema::new_with_metadata( - vec![ - // nullable merge - Field::new("last_name", DataType::Utf8, true), - Field::new( - "address", - DataType::Struct(vec![ - // add new nested field - Field::new("street", DataType::Utf8, false), - // nullable merge on nested field - Field::new("zip", DataType::UInt16, true), - ]), - false, - ), - // new field - Field::new("number", DataType::Utf8, true), - ], - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect::>(), - ), - ])?; - - assert_eq!( - merged, - Schema::new_with_metadata( - vec![ - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, true), - Field::new( - "address", - DataType::Struct(vec![ - Field::new("zip", DataType::UInt16, true), - Field::new("street", DataType::Utf8, false), - ]), - false, - ), - Field::new("number", DataType::Utf8, true), - ], - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect::>() - ) - ); - - // support merge union fields - assert_eq!( - Schema::try_merge(vec![ - Schema::new(vec![Field::new( - "c1", - DataType::Union( - vec![ - Field::new("c11", DataType::Utf8, true), - Field::new("c12", DataType::Utf8, true), - ], - vec![0, 1], - UnionMode::Dense - ), - false - ),]), - Schema::new(vec![Field::new( - "c1", - DataType::Union( - vec![ - Field::new("c12", DataType::Utf8, true), - Field::new("c13", DataType::Time64(TimeUnit::Second), true), - ], - vec![1, 2], - UnionMode::Dense - ), - false - ),]) - ])?, - Schema::new(vec![Field::new( - "c1", - DataType::Union( - vec![ - Field::new("c11", DataType::Utf8, true), - Field::new("c12", DataType::Utf8, true), - Field::new("c13", DataType::Time64(TimeUnit::Second), true), - ], - vec![0, 1, 2], - UnionMode::Dense - ), - false - ),]), - ); - - // incompatible field should throw error - assert!(Schema::try_merge(vec![ - Schema::new(vec![ - Field::new("first_name", DataType::Utf8, false), - Field::new("last_name", DataType::Utf8, false), - ]), - Schema::new(vec![Field::new("last_name", DataType::Int64, false),]) - ]) - .is_err()); - - // incompatible metadata should throw error - let res = Schema::try_merge(vec![ - Schema::new_with_metadata( - vec![Field::new("first_name", DataType::Utf8, false)], - [("foo".to_string(), "bar".to_string())] - .iter() - .cloned() - .collect::>(), - ), - Schema::new_with_metadata( - vec![Field::new("last_name", DataType::Utf8, false)], - [("foo".to_string(), "baz".to_string())] - .iter() - .cloned() - .collect::>(), - ), - ]) - .unwrap_err(); - - let expected = "Fail to merge schema due to conflicting metadata. 
Key 'foo' has different values 'bar' and 'baz'"; - assert!( - res.to_string().contains(expected), - "Could not find expected string '{}' in '{}'", - expected, - res - ); - - Ok(()) - } -} diff --git a/arrow/src/datatypes/schema.rs b/arrow/src/datatypes/schema.rs deleted file mode 100644 index b0eca611474..00000000000 --- a/arrow/src/datatypes/schema.rs +++ /dev/null @@ -1,386 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::collections::HashMap; -use std::fmt; -use std::hash::Hash; - -use crate::error::{ArrowError, Result}; - -use super::Field; - -/// Describes the meta-data of an ordered sequence of relative types. -/// -/// Note that this information is only part of the meta-data and not part of the physical -/// memory layout. -#[derive(Debug, Clone, PartialEq, Eq)] -#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Schema { - pub fields: Vec, - /// A map of key-value pairs containing additional meta data. - #[cfg_attr( - feature = "serde", - serde(skip_serializing_if = "HashMap::is_empty", default) - )] - pub metadata: HashMap, -} - -impl Schema { - /// Creates an empty `Schema` - pub fn empty() -> Self { - Self { - fields: vec![], - metadata: HashMap::new(), - } - } - - /// Creates a new [`Schema`] from a sequence of [`Field`] values. - /// - /// # Example - /// - /// ``` - /// # use arrow::datatypes::{Field, DataType, Schema}; - /// let field_a = Field::new("a", DataType::Int64, false); - /// let field_b = Field::new("b", DataType::Boolean, false); - /// - /// let schema = Schema::new(vec![field_a, field_b]); - /// ``` - pub fn new(fields: Vec) -> Self { - Self::new_with_metadata(fields, HashMap::new()) - } - - /// Creates a new [`Schema`] from a sequence of [`Field`] values - /// and adds additional metadata in form of key value pairs. 
- /// - /// # Example - /// - /// ``` - /// # use arrow::datatypes::{Field, DataType, Schema}; - /// # use std::collections::HashMap; - /// let field_a = Field::new("a", DataType::Int64, false); - /// let field_b = Field::new("b", DataType::Boolean, false); - /// - /// let mut metadata: HashMap = HashMap::new(); - /// metadata.insert("row_count".to_string(), "100".to_string()); - /// - /// let schema = Schema::new_with_metadata(vec![field_a, field_b], metadata); - /// ``` - #[inline] - pub const fn new_with_metadata( - fields: Vec, - metadata: HashMap, - ) -> Self { - Self { fields, metadata } - } - - /// Sets the metadata of this `Schema` to be `metadata` and returns self - pub fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; - self - } - - /// Returns a new schema with only the specified columns in the new schema - /// This carries metadata from the parent schema over as well - pub fn project(&self, indices: &[usize]) -> Result { - let new_fields = indices - .iter() - .map(|i| { - self.fields.get(*i).cloned().ok_or_else(|| { - ArrowError::SchemaError(format!( - "project index {} out of bounds, max field {}", - i, - self.fields().len() - )) - }) - }) - .collect::>>()?; - Ok(Self::new_with_metadata(new_fields, self.metadata.clone())) - } - - /// Merge schema into self if it is compatible. Struct fields will be merged recursively. - /// - /// Example: - /// - /// ``` - /// use arrow::datatypes::*; - /// - /// let merged = Schema::try_merge(vec![ - /// Schema::new(vec![ - /// Field::new("c1", DataType::Int64, false), - /// Field::new("c2", DataType::Utf8, false), - /// ]), - /// Schema::new(vec![ - /// Field::new("c1", DataType::Int64, true), - /// Field::new("c2", DataType::Utf8, false), - /// Field::new("c3", DataType::Utf8, false), - /// ]), - /// ]).unwrap(); - /// - /// assert_eq!( - /// merged, - /// Schema::new(vec![ - /// Field::new("c1", DataType::Int64, true), - /// Field::new("c2", DataType::Utf8, false), - /// Field::new("c3", DataType::Utf8, false), - /// ]), - /// ); - /// ``` - pub fn try_merge(schemas: impl IntoIterator) -> Result { - schemas - .into_iter() - .try_fold(Self::empty(), |mut merged, schema| { - let Schema { metadata, fields } = schema; - for (key, value) in metadata.into_iter() { - // merge metadata - if let Some(old_val) = merged.metadata.get(&key) { - if old_val != &value { - return Err(ArrowError::SchemaError(format!( - "Fail to merge schema due to conflicting metadata. \ - Key '{}' has different values '{}' and '{}'", - key, old_val, value - ))); - } - } - merged.metadata.insert(key, value); - } - // merge fields - for field in fields.into_iter() { - let merged_field = - merged.fields.iter_mut().find(|f| f.name() == field.name()); - match merged_field { - Some(merged_field) => merged_field.try_merge(&field)?, - // found a new field, add to field list - None => merged.fields.push(field), - } - } - Ok(merged) - }) - } - - /// Returns an immutable reference of the vector of `Field` instances. - #[inline] - pub const fn fields(&self) -> &Vec { - &self.fields - } - - /// Returns a vector with references to all fields (including nested fields) - #[inline] - #[cfg(feature = "ipc")] - pub(crate) fn all_fields(&self) -> Vec<&Field> { - self.fields.iter().flat_map(|f| f.fields()).collect() - } - - /// Returns an immutable reference of a specific [`Field`] instance selected using an - /// offset within the internal `fields` vector. 
- pub fn field(&self, i: usize) -> &Field { - &self.fields[i] - } - - /// Returns an immutable reference of a specific [`Field`] instance selected by name. - pub fn field_with_name(&self, name: &str) -> Result<&Field> { - Ok(&self.fields[self.index_of(name)?]) - } - - /// Returns a vector of immutable references to all [`Field`] instances selected by - /// the dictionary ID they use. - pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> { - self.fields - .iter() - .flat_map(|f| f.fields_with_dict_id(dict_id)) - .collect() - } - - /// Find the index of the column with the given name. - pub fn index_of(&self, name: &str) -> Result { - (0..self.fields.len()) - .find(|idx| self.fields[*idx].name() == name) - .ok_or_else(|| { - let valid_fields: Vec = - self.fields.iter().map(|f| f.name().clone()).collect(); - ArrowError::InvalidArgumentError(format!( - "Unable to get field named \"{}\". Valid fields: {:?}", - name, valid_fields - )) - }) - } - - /// Returns an immutable reference to the Map of custom metadata key-value pairs. - #[inline] - pub const fn metadata(&self) -> &HashMap { - &self.metadata - } - - /// Look up a column by name and return a immutable reference to the column along with - /// its index. - pub fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> { - self.fields - .iter() - .enumerate() - .find(|&(_, c)| c.name() == name) - } - - /// Check to see if `self` is a superset of `other` schema. Here are the comparison rules: - /// - /// * `self` and `other` should contain the same number of fields - /// * for every field `f` in `other`, the field in `self` with corresponding index should be a - /// superset of `f`. - /// * self.metadata is a superset of other.metadata - /// - /// In other words, any record conforms to `other` should also conform to `self`. 
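Interrupting the deleted schema.rs briefly: the whole Schema API (project, try_merge, index_of, contains) moves to arrow-schema without behavioral change. As one example, a sketch of the nullability-widening merge that the try_merge docs above describe:

use arrow_schema::{ArrowError, DataType, Field, Schema};

fn widen() -> Result<Schema, ArrowError> {
    // Merging the same column with differing nullability yields a nullable column.
    Schema::try_merge(vec![
        Schema::new(vec![Field::new("c1", DataType::Int64, false)]),
        Schema::new(vec![Field::new("c1", DataType::Int64, true)]),
    ])
}

The contains implementation follows below.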
- pub fn contains(&self, other: &Schema) -> bool { - self.fields.len() == other.fields.len() - && self.fields.iter().zip(other.fields.iter()).all(|(f1, f2)| f1.contains(f2)) - // make sure self.metadata is a superset of other.metadata - && other.metadata.iter().all(|(k, v1)| match self.metadata.get(k) { - Some(v2) => v1 == v2, - _ => false, - }) - } -} - -impl fmt::Display for Schema { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str( - &self - .fields - .iter() - .map(|c| c.to_string()) - .collect::>() - .join(", "), - ) - } -} - -// need to implement `Hash` manually because `HashMap` implement Eq but no `Hash` -#[allow(clippy::derive_hash_xor_eq)] -impl Hash for Schema { - fn hash(&self, state: &mut H) { - self.fields.hash(state); - - // ensure deterministic key order - let mut keys: Vec<&String> = self.metadata.keys().collect(); - keys.sort(); - for k in keys { - k.hash(state); - self.metadata.get(k).expect("key valid").hash(state); - } - } -} - -#[cfg(test)] -mod tests { - use crate::datatypes::DataType; - - use super::*; - - #[test] - #[cfg(feature = "json")] - fn test_ser_de_metadata() { - // ser/de with empty metadata - let schema = Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("address", DataType::Utf8, false), - Field::new("priority", DataType::UInt8, false), - ]); - - let json = serde_json::to_string(&schema).unwrap(); - let de_schema = serde_json::from_str(&json).unwrap(); - - assert_eq!(schema, de_schema); - - // ser/de with non-empty metadata - let schema = schema - .with_metadata([("key".to_owned(), "val".to_owned())].into_iter().collect()); - let json = serde_json::to_string(&schema).unwrap(); - let de_schema = serde_json::from_str(&json).unwrap(); - - assert_eq!(schema, de_schema); - } - - #[test] - fn test_projection() { - let mut metadata = HashMap::new(); - metadata.insert("meta".to_string(), "data".to_string()); - - let schema = Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("address", DataType::Utf8, false), - Field::new("priority", DataType::UInt8, false), - ]) - .with_metadata(metadata); - - let projected: Schema = schema.project(&[0, 2]).unwrap(); - - assert_eq!(projected.fields().len(), 2); - assert_eq!(projected.fields()[0].name(), "name"); - assert_eq!(projected.fields()[1].name(), "priority"); - assert_eq!(projected.metadata.get("meta").unwrap(), "data") - } - - #[test] - fn test_oob_projection() { - let mut metadata = HashMap::new(); - metadata.insert("meta".to_string(), "data".to_string()); - - let schema = Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("address", DataType::Utf8, false), - Field::new("priority", DataType::UInt8, false), - ]) - .with_metadata(metadata); - - let projected: Result = schema.project(&[0, 3]); - - assert!(projected.is_err()); - if let Err(e) = projected { - assert_eq!( - e.to_string(), - "Schema error: project index 3 out of bounds, max field 3".to_string() - ) - } - } - - #[test] - fn test_schema_contains() { - let mut metadata1 = HashMap::new(); - metadata1.insert("meta".to_string(), "data".to_string()); - - let schema1 = Schema::new(vec![ - Field::new("name", DataType::Utf8, false), - Field::new("address", DataType::Utf8, false), - Field::new("priority", DataType::UInt8, false), - ]) - .with_metadata(metadata1.clone()); - - let mut metadata2 = HashMap::new(); - metadata2.insert("meta".to_string(), "data".to_string()); - metadata2.insert("meta2".to_string(), "data".to_string()); - let schema2 = Schema::new(vec![ - 
Field::new("name", DataType::Utf8, false), - Field::new("address", DataType::Utf8, false), - Field::new("priority", DataType::UInt8, false), - ]) - .with_metadata(metadata2); - - // reflexivity - assert!(schema1.contains(&schema1)); - assert!(schema2.contains(&schema2)); - - assert!(!schema1.contains(&schema2)); - assert!(schema2.contains(&schema1)); - } -} diff --git a/arrow/src/error.rs b/arrow/src/error.rs index 5d92fb93017..f7acec0b34d 100644 --- a/arrow/src/error.rs +++ b/arrow/src/error.rs @@ -16,120 +16,7 @@ // under the License. //! Defines `ArrowError` for representing failures in various Arrow operations. -use std::fmt::{Debug, Display, Formatter}; -use std::io::Write; -use std::error::Error; - -/// Many different operations in the `arrow` crate return this error type. -#[derive(Debug)] -pub enum ArrowError { - /// Returned when functionality is not yet available. - NotYetImplemented(String), - ExternalError(Box), - CastError(String), - MemoryError(String), - ParseError(String), - SchemaError(String), - ComputeError(String), - DivideByZero, - CsvError(String), - JsonError(String), - IoError(String), - InvalidArgumentError(String), - ParquetError(String), - /// Error during import or export to/from the C Data Interface - CDataInterface(String), - DictionaryKeyOverflowError, -} - -impl ArrowError { - /// Wraps an external error in an `ArrowError`. - pub fn from_external_error( - error: Box, - ) -> Self { - Self::ExternalError(error) - } -} - -impl From<::std::io::Error> for ArrowError { - fn from(error: std::io::Error) -> Self { - ArrowError::IoError(error.to_string()) - } -} - -#[cfg(feature = "csv")] -impl From for ArrowError { - fn from(error: csv_crate::Error) -> Self { - match error.kind() { - csv_crate::ErrorKind::Io(error) => ArrowError::CsvError(error.to_string()), - csv_crate::ErrorKind::Utf8 { pos: _, err } => ArrowError::CsvError(format!( - "Encountered UTF-8 error while reading CSV file: {}", - err - )), - csv_crate::ErrorKind::UnequalLengths { - expected_len, len, .. - } => ArrowError::CsvError(format!( - "Encountered unequal lengths between records on CSV file. 
Expected {} \ - records, found {} records", - len, expected_len - )), - _ => ArrowError::CsvError("Error reading CSV file".to_string()), - } - } -} - -impl From<::std::string::FromUtf8Error> for ArrowError { - fn from(error: std::string::FromUtf8Error) -> Self { - ArrowError::ParseError(error.to_string()) - } -} - -#[cfg(feature = "json")] -impl From for ArrowError { - fn from(error: serde_json::Error) -> Self { - ArrowError::JsonError(error.to_string()) - } -} - -impl From<::std::io::IntoInnerError> for ArrowError { - fn from(error: std::io::IntoInnerError) -> Self { - ArrowError::IoError(error.to_string()) - } -} - -impl Display for ArrowError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - ArrowError::NotYetImplemented(source) => { - write!(f, "Not yet implemented: {}", &source) - } - ArrowError::ExternalError(source) => write!(f, "External error: {}", &source), - ArrowError::CastError(desc) => write!(f, "Cast error: {}", desc), - ArrowError::MemoryError(desc) => write!(f, "Memory error: {}", desc), - ArrowError::ParseError(desc) => write!(f, "Parser error: {}", desc), - ArrowError::SchemaError(desc) => write!(f, "Schema error: {}", desc), - ArrowError::ComputeError(desc) => write!(f, "Compute error: {}", desc), - ArrowError::DivideByZero => write!(f, "Divide by zero error"), - ArrowError::CsvError(desc) => write!(f, "Csv error: {}", desc), - ArrowError::JsonError(desc) => write!(f, "Json error: {}", desc), - ArrowError::IoError(desc) => write!(f, "Io error: {}", desc), - ArrowError::InvalidArgumentError(desc) => { - write!(f, "Invalid argument error: {}", desc) - } - ArrowError::ParquetError(desc) => { - write!(f, "Parquet argument error: {}", desc) - } - ArrowError::CDataInterface(desc) => { - write!(f, "C Data interface error: {}", desc) - } - ArrowError::DictionaryKeyOverflowError => { - write!(f, "Dictionary key bigger than the key type") - } - } - } -} - -impl Error for ArrowError {} +pub use arrow_schema::ArrowError; pub type Result = std::result::Result; diff --git a/arrow/src/json/mod.rs b/arrow/src/json/mod.rs index 836145bb08e..21f96d90a5d 100644 --- a/arrow/src/json/mod.rs +++ b/arrow/src/json/mod.rs @@ -80,3 +80,36 @@ impl JsonSerializable for f64 { Number::from_f64(self).map(Value::Number) } } + +#[cfg(test)] +mod tests { + use super::*; + + use serde_json::{ + Number, + Value::{Bool, Number as VNumber, String as VString}, + }; + + #[test] + fn test_arrow_native_type_to_json() { + assert_eq!(Some(Bool(true)), true.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1i8.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1i16.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1i32.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1i64.into_json_value()); + assert_eq!(Some(VString("1".to_string())), 1i128.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1u8.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1u16.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1u32.into_json_value()); + assert_eq!(Some(VNumber(Number::from(1))), 1u64.into_json_value()); + assert_eq!( + Some(VNumber(Number::from_f64(0.01f64).unwrap())), + 0.01.into_json_value() + ); + assert_eq!( + Some(VNumber(Number::from_f64(0.01f64).unwrap())), + 0.01f64.into_json_value() + ); + assert_eq!(None, f32::NAN.into_json_value()); + } +} diff --git a/arrow/src/json/writer.rs b/arrow/src/json/writer.rs index bf40b31b494..beee02582ff 100644 --- a/arrow/src/json/writer.rs +++ 
b/arrow/src/json/writer.rs @@ -700,7 +700,10 @@ where } self.format.start_row(&mut self.writer, is_first_row)?; - self.writer.write_all(&serde_json::to_vec(row)?)?; + self.writer.write_all( + &serde_json::to_vec(row) + .map_err(|error| ArrowError::JsonError(error.to_string()))?, + )?; self.format.end_row(&mut self.writer)?; Ok(()) } diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 89463e4c8fd..90caa2e3a5c 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -39,10 +39,8 @@ use crate::record_batch::RecordBatch; import_exception!(pyarrow, ArrowException); pub type PyArrowException = ArrowException; -impl From for PyErr { - fn from(err: ArrowError) -> PyErr { - PyArrowException::new_err(err.to_string()) - } +fn to_py_err(err: ArrowError) -> PyErr { + PyArrowException::new_err(err.to_string()) } pub trait PyArrowConvert: Sized { @@ -55,12 +53,12 @@ impl PyArrowConvert for DataType { let c_schema = FFI_ArrowSchema::empty(); let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?; - let dtype = DataType::try_from(&c_schema)?; + let dtype = DataType::try_from(&c_schema).map_err(to_py_err)?; Ok(dtype) } fn to_pyarrow(&self, py: Python) -> PyResult { - let c_schema = FFI_ArrowSchema::try_from(self)?; + let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?; let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; let module = py.import("pyarrow")?; let class = module.getattr("DataType")?; @@ -75,12 +73,12 @@ impl PyArrowConvert for Field { let c_schema = FFI_ArrowSchema::empty(); let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?; - let field = Field::try_from(&c_schema)?; + let field = Field::try_from(&c_schema).map_err(to_py_err)?; Ok(field) } fn to_pyarrow(&self, py: Python) -> PyResult { - let c_schema = FFI_ArrowSchema::try_from(self)?; + let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?; let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; let module = py.import("pyarrow")?; let class = module.getattr("Field")?; @@ -95,12 +93,12 @@ impl PyArrowConvert for Schema { let c_schema = FFI_ArrowSchema::empty(); let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?; - let schema = Schema::try_from(&c_schema)?; + let schema = Schema::try_from(&c_schema).map_err(to_py_err)?; Ok(schema) } fn to_pyarrow(&self, py: Python) -> PyResult { - let c_schema = FFI_ArrowSchema::try_from(self)?; + let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?; let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; let module = py.import("pyarrow")?; let class = module.getattr("Schema")?; @@ -127,15 +125,17 @@ impl PyArrowConvert for ArrayData { ), )?; - let ffi_array = - unsafe { ffi::ArrowArray::try_from_raw(array_pointer, schema_pointer)? }; - let data = ArrayData::try_from(ffi_array)?; + let ffi_array = unsafe { + ffi::ArrowArray::try_from_raw(array_pointer, schema_pointer) + .map_err(to_py_err)? 
+ }; + let data = ArrayData::try_from(ffi_array).map_err(to_py_err)?; Ok(data) } fn to_pyarrow(&self, py: Python) -> PyResult { - let array = ffi::ArrowArray::try_from(self.clone())?; + let array = ffi::ArrowArray::try_from(self.clone()).map_err(to_py_err)?; let (array_pointer, schema_pointer) = ffi::ArrowArray::into_raw(array); let module = py.import("pyarrow")?; @@ -151,6 +151,21 @@ impl PyArrowConvert for ArrayData { } } +impl PyArrowConvert for Vec { + fn from_pyarrow(value: &PyAny) -> PyResult { + let list = value.downcast::()?; + list.iter().map(|x| T::from_pyarrow(&x)).collect() + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let values = self + .iter() + .map(|v| v.to_pyarrow(py)) + .collect::>>()?; + Ok(values.to_object(py)) + } +} + impl PyArrowConvert for T where T: Array + From, @@ -176,7 +191,7 @@ impl PyArrowConvert for RecordBatch { .map(ArrayRef::from_pyarrow) .collect::>()?; - let batch = RecordBatch::try_new(schema, arrays)?; + let batch = RecordBatch::try_new(schema, arrays).map_err(to_py_err)?; Ok(batch) } @@ -237,25 +252,25 @@ impl PyArrowConvert for ArrowArrayStreamReader { } } -macro_rules! add_conversion { - ($typ:ty) => { - impl<'source> FromPyObject<'source> for $typ { - fn extract(value: &'source PyAny) -> PyResult { - Self::from_pyarrow(value) - } - } +/// A newtype wrapper around a `T: PyArrowConvert` that implements +/// [`FromPyObject`] and [`IntoPy`] allowing usage with pyo3 macros +#[derive(Debug)] +pub struct PyArrowType(pub T); - impl<'a> IntoPy for $typ { - fn into_py(self, py: Python) -> PyObject { - self.to_pyarrow(py).unwrap() - } - } - }; +impl<'source, T: PyArrowConvert> FromPyObject<'source> for PyArrowType { + fn extract(value: &'source PyAny) -> PyResult { + Ok(Self(T::from_pyarrow(value)?)) + } +} + +impl<'a, T: PyArrowConvert> IntoPy for PyArrowType { + fn into_py(self, py: Python) -> PyObject { + self.0.to_pyarrow(py).unwrap() + } } -add_conversion!(DataType); -add_conversion!(Field); -add_conversion!(Schema); -add_conversion!(ArrayData); -add_conversion!(RecordBatch); -add_conversion!(ArrowArrayStreamReader); +impl From for PyArrowType { + fn from(s: T) -> Self { + Self(s) + } +} diff --git a/integration-testing/src/util/mod.rs b/integration-testing/src/util/mod.rs index 9ecd301360f..f9ddc0e6f4b 100644 --- a/integration-testing/src/util/mod.rs +++ b/integration-testing/src/util/mod.rs @@ -265,7 +265,8 @@ impl ArrowJsonField { /// TODO: convert to use an Into fn to_arrow_field(&self) -> Result { // a bit regressive, but we have to convert the field to JSON in order to convert it - let field = serde_json::to_value(self)?; + let field = serde_json::to_value(self) + .map_err(|error| ArrowError::JsonError(error.to_string()))?; field_from_json(&field) } }
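Taken together, the pyarrow.rs changes drop the blanket FromPyObject and IntoPy impls, which would run afoul of Rust's orphan rule once DataType, Field, and Schema are defined in arrow-schema rather than in arrow itself, and replace them with the explicit PyArrowType newtype. A sketch of a downstream #[pyfunction] under the new convention, assuming arrow is built with the pyarrow feature; schema_roundtrip is an illustrative name:

use arrow::datatypes::Schema;
use arrow::pyarrow::PyArrowType;
use pyo3::prelude::*;

#[pyfunction]
fn schema_roundtrip(schema: PyArrowType<Schema>) -> PyResult<PyArrowType<Schema>> {
    // `.0` unwraps the plain arrow Schema; wrap it again on the way out.
    Ok(PyArrowType(schema.0))
}

And because of the new Vec impl of PyArrowConvert, a PyArrowType<Vec<Field>> parameter converts a Python list of pyarrow fields in the same way.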