Skip to content

Commit

Permalink
Round-tripping and operations on decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
FiV0 committed Jun 20, 2023
1 parent a2a40a7 commit 44e7e3a
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 15 deletions.
7 changes: 4 additions & 3 deletions core/src/main/clojure/xtdb/expression.clj
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@
:f32 'float
:f64 'double
:timestamp-tz 'long
:duration 'long}
:duration 'long
:decimal 'bigdec}
types/col-type-head))

(def idx-sym (gensym 'idx))
Expand Down Expand Up @@ -227,7 +228,7 @@
:hierarchy #'types/col-type-hierarchy)

(def ^:private col-type->rw-fn
'{:bool Boolean, :i8 Byte, :i16 Short, :i32 Int, :i64 Long, :f32 Float, :f64 Double
'{:bool Boolean, :i8 Byte, :i16 Short, :i32 Int, :i64 Long, :f32 Float, :f64 Double,
:date Long, :time-local Long, :timestamp-tz Long, :timestamp-local Long, :duration Long
:utf8 Bytes, :varbinary Bytes, :keyword Bytes, :uuid Bytes, :uri Bytes

Expand All @@ -243,7 +244,7 @@
(defmethod read-value-code k [_ & args] `(~(symbol (str ".read" rw-fn)) ~@args))
(defmethod write-value-code k [_ & args] `(~(symbol (str ".write" rw-fn)) ~@args)))

(doseq [[k tag] {:interval PeriodDuration}]
(doseq [[k tag] {:interval PeriodDuration :decimal BigDecimal}]
(defmethod read-value-code k [_ & args]
(-> `(.readObject ~@args) (with-tag tag)))

Expand Down
20 changes: 15 additions & 5 deletions core/src/main/clojure/xtdb/types.clj
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
(java.time Duration Instant LocalDate LocalTime Period ZoneId)
java.util.concurrent.ConcurrentHashMap
java.util.function.Function
(org.apache.arrow.vector BitVector DateDayVector DateMilliVector IntervalDayVector IntervalMonthDayNanoVector IntervalYearVector TimeMicroVector TimeMilliVector TimeNanoVector TimeSecVector TimeStampMicroTZVector TimeStampMilliTZVector TimeStampNanoTZVector TimeStampSecTZVector ValueVector VarBinaryVector VarCharVector)
(org.apache.arrow.vector BitVector DateDayVector DateMilliVector DecimalVector Decimal256Vector IntervalDayVector IntervalMonthDayNanoVector IntervalYearVector TimeMicroVector TimeMilliVector TimeNanoVector TimeSecVector TimeStampMicroTZVector TimeStampMilliTZVector TimeStampNanoTZVector TimeStampSecTZVector ValueVector VarBinaryVector VarCharVector)
(org.apache.arrow.vector.complex DenseUnionVector FixedSizeListVector ListVector StructVector)
(org.apache.arrow.vector.holders NullableIntervalDayHolder NullableIntervalMonthDayNanoHolder)
(org.apache.arrow.vector.types DateUnit FloatingPointPrecision IntervalUnit TimeUnit Types$MinorType UnionMode)
(org.apache.arrow.vector.types.pojo ArrowType ArrowType$Binary ArrowType$Bool ArrowType$Date ArrowType$Duration ArrowType$FixedSizeBinary ArrowType$FixedSizeList ArrowType$FloatingPoint ArrowType$Int ArrowType$Interval ArrowType$List ArrowType$Null ArrowType$Struct ArrowType$Time ArrowType$Time ArrowType$Timestamp ArrowType$Union ArrowType$Utf8 Field FieldType)
(org.apache.arrow.vector.types.pojo ArrowType ArrowType$Binary ArrowType$Bool ArrowType$Date ArrowType$Decimal ArrowType$Duration ArrowType$FixedSizeBinary ArrowType$FixedSizeList ArrowType$FloatingPoint ArrowType$Int ArrowType$Interval ArrowType$List ArrowType$Null ArrowType$Struct ArrowType$Time ArrowType$Time ArrowType$Timestamp ArrowType$Union ArrowType$Utf8 Field FieldType)
(xtdb.types IntervalDayTime IntervalMonthDayNano IntervalYearMonth)
(xtdb.vector.extensions AbsentType AbsentVector ClojureFormType KeywordType SetType SetVector UriType UuidType)))



(set! *unchecked-math* :warn-on-boxed)

(def struct-type (.getType Types$MinorType/STRUCT))
Expand Down Expand Up @@ -216,7 +218,7 @@
(derive :i8 :int) (derive :i16 :int) (derive :i32 :int) (derive :i64 :int)
(derive :u8 :uint) (derive :u16 :uint) (derive :u32 :uint) (derive :u64 :uint)

(derive :uint :num) (derive :int :num) (derive :float :num)
(derive :uint :num) (derive :int :num) (derive :float :num) (derive :decimal :num)
(derive :num :any)

(derive :date-time :any)
Expand Down Expand Up @@ -312,9 +314,15 @@
:null Types$MinorType/NULL, :bool Types$MinorType/BIT
:f32 Types$MinorType/FLOAT4, :f64 Types$MinorType/FLOAT8
:i8 Types$MinorType/TINYINT, :i16 Types$MinorType/SMALLINT, :i32 Types$MinorType/INT, :i64 Types$MinorType/BIGINT
:utf8 Types$MinorType/VARCHAR, :varbinary Types$MinorType/VARBINARY)]
:utf8 Types$MinorType/VARCHAR, :varbinary Types$MinorType/VARBINARY
:decimal Types$MinorType/DECIMAL)]
(->field col-name (.getType minor-type) (or nullable? (= col-type :null)))))

(defmethod col-type->field* :decimal [col-name nullable? _col-type]
;; TODO decide on how deal with precision that is out of this range but still fits into
;; a 128 bit decimal
(->field col-name (ArrowType$Decimal/createDecimal 38 19 (int 128)) nullable?))

(defmethod col-type->field* :keyword [col-name nullable? _col-type]
(->field col-name KeywordType/INSTANCE nullable?))

Expand Down Expand Up @@ -349,6 +357,7 @@
(defmethod arrow-type->col-type ArrowType$Bool [_] :bool)
(defmethod arrow-type->col-type ArrowType$Utf8 [_] :utf8)
(defmethod arrow-type->col-type ArrowType$Binary [_] :varbinary)
(defmethod arrow-type->col-type ArrowType$Decimal [_] :decimal)

(defn- col-type->nullable-col-type [col-type]
(zmatch col-type
Expand Down Expand Up @@ -564,7 +573,8 @@
(-> col-type-hierarchy
(derive :i8 :i16) (derive :i16 :i32) (derive :i32 :i64)
(derive :f32 :f64)
(derive :int :f64) (derive :int :f32)))
(derive :int :f64) (derive :int :f32)
(derive :int :decimal) (derive :uint :decimal)))

(defmethod least-upper-bound2 [:num :num] [x-type y-type]
(cond
Expand Down
24 changes: 22 additions & 2 deletions core/src/main/clojure/xtdb/vector/writer.clj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
[xtdb.vector :as vec]
[xtdb.vector.indirect :as iv])
(:import (clojure.lang Keyword)
(java.math BigDecimal)
(java.lang AutoCloseable)
java.net.URI
(java.nio ByteBuffer CharBuffer)
Expand All @@ -15,7 +16,7 @@
(java.util.function Function)
java.util.function.Function
(org.apache.arrow.memory BufferAllocator)
(org.apache.arrow.vector BigIntVector BitVector DateDayVector DateMilliVector DurationVector ExtensionTypeVector FixedSizeBinaryVector Float4Vector Float8Vector IntVector IntervalDayVector IntervalMonthDayNanoVector IntervalYearVector NullVector PeriodDuration SmallIntVector TimeMicroVector TimeMilliVector TimeNanoVector TimeSecVector TimeStampVector TinyIntVector ValueVector VarBinaryVector VarCharVector VectorSchemaRoot)
(org.apache.arrow.vector BigIntVector BitVector DecimalVector Decimal256Vector DateDayVector DateMilliVector DurationVector ExtensionTypeVector FixedSizeBinaryVector Float4Vector Float8Vector IntVector IntervalDayVector IntervalMonthDayNanoVector IntervalYearVector NullVector PeriodDuration SmallIntVector TimeMicroVector TimeMilliVector TimeNanoVector TimeSecVector TimeStampVector TinyIntVector ValueVector VarBinaryVector VarCharVector VectorSchemaRoot)
(org.apache.arrow.vector.complex DenseUnionVector ListVector StructVector)
(org.apache.arrow.vector.types.pojo ArrowType$List ArrowType$Struct ArrowType$Union Field FieldType)
xtdb.api.protocols.ClojureForm
Expand Down Expand Up @@ -140,6 +141,21 @@
(writeNull [_ _] (.setNull arrow-vec (.getPositionAndIncrement wp)))
(writeLong [_ days] (.setSafe arrow-vec (.getPositionAndIncrement wp) (* days 86400000)))))))

(extend-protocol WriterFactory
DecimalVector
(->writer [arrow-vec]
(let [wp (IWriterPosition/build (.getValueCount arrow-vec))]
(reify IVectorWriter
(getVector [_] (doto arrow-vec (.setValueCount (.getPosition wp))))
(clear [_] (.clear arrow-vec) (.setPosition wp 0))
(rowCopier [this src-vec] (scalar-copier this src-vec))
(writerPosition [_] wp)
(writeNull [_ _] (.setNull arrow-vec (.getPositionAndIncrement wp)))
(writeObject [_ decimal]
(let [new-decimal (.setScale decimal (.getScale arrow-vec))]
(.setSafe arrow-vec (.getPositionAndIncrement wp) new-decimal)))
(writerForType [this _col-type] this)))))

(extend-protocol ArrowWriteable
nil
(value->col-type [_] :null)
Expand Down Expand Up @@ -171,7 +187,11 @@

Double
(value->col-type [_] :f64)
(write-value! [v ^IVectorWriter w] (.writeDouble w v)))
(write-value! [v ^IVectorWriter w] (.writeDouble w v))

BigDecimal
(value->col-type [_] :decimal)
(write-value! [v ^IVectorWriter w] (.writeObject w v)))

(extend-protocol WriterFactory
DurationVector
Expand Down
44 changes: 44 additions & 0 deletions src/test/clojure/xtdb/expression_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,50 @@
(t/is (= {:res 2.0, :res-type :f32}
(run-test '/ (float 4) (int 2))))))


(t/deftest test-decimal-arithmetic-and-coersion
(letfn [(run-test [f x y]
(with-open [rel (tu/open-rel [(tu/open-vec "x" [x])
(tu/open-vec "y" [y])])]
(-> (run-projection rel (list f 'x 'y))
(update :res first))))]

;; standard ops
(t/is (= {:res 2.0M, :res-type :decimal}
(run-test '+ (bigdec 1.0) (bigdec 1.0))))
(t/is (= {:res 0.0M, :res-type :decimal}
(run-test '- (bigdec 1.0) (bigdec 1.0))))
(t/is (= {:res 0.5M, :res-type :decimal}
(run-test '/ (bigdec 1.0) (bigdec 2.0))))
(t/is (= {:res 2.0M, :res-type :decimal}
(run-test '* (bigdec 1.0) (bigdec 2.0))))

;; TODO LUB
;; some of these work because of implicit conversion by Clojure
;; (t/is (= {:res 2.0M, :res-type :decimal}
;; (run-test '* (byte 1) (bigdec 2.0))))
;; (t/is (= {:res 2.0M, :res-type :decimal}
;; (run-test '* (short 1) (bigdec 2.0))))
;; (t/is (= {:res 2.0M, :res-type :decimal}
;; (run-test '* (int 1) (bigdec 2.0))))
;; (t/is (= {:res true, :res-type :bool}
;; (run-test '< (int 1) (bigdec 2.0))))
;; (t/is (= {:res true, :res-type :bool}
;; (run-test '> (int 1) (bigdec 2.0))))
;; (t/is (= {:res 2.0, :res-type :double}
;; (run-test '+ (float 1.0) (bigdec 2.0))))

;; TODO decide on behaviour
;; out of fixed precision range
(t/is (thrown? UnsupportedOperationException
(run-test '+ (bigdec 1E+19M) (bigdec 1E+19M))))
;; overflow
(t/is (thrown? UnsupportedOperationException
(run-test '+ (bigdec 1E+35M) (bigdec 1E-35M))))
;; underflow
(t/is (thrown? ArithmeticException
(run-test '/ (bigdec 1E-35M) (bigdec 1E+35M))))))

(t/deftest test-throws-on-overflow
(letfn [(run-unary-test [f x]
(with-open [rel (tu/open-rel [(tu/open-vec "x" [x])])]
Expand Down
18 changes: 13 additions & 5 deletions src/test/clojure/xtdb/types_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
[xtdb.vector :as vec]
[xtdb.vector.writer :as vw])
(:import java.net.URI
(java.math BigDecimal)
java.nio.ByteBuffer
(java.time Instant LocalDate LocalTime OffsetDateTime ZonedDateTime)
(org.apache.arrow.vector BigIntVector BitVector DateDayVector Float4Vector Float8Vector IntVector IntervalMonthDayNanoVector NullVector SmallIntVector TimeNanoVector TimeStampMicroTZVector TinyIntVector VarBinaryVector VarCharVector)
(org.apache.arrow.vector BigIntVector BitVector DateDayVector DecimalVector Decimal256Vector Float4Vector Float8Vector IntVector IntervalMonthDayNanoVector NullVector SmallIntVector TimeNanoVector TimeStampMicroTZVector TinyIntVector VarBinaryVector VarCharVector)
(org.apache.arrow.vector.complex DenseUnionVector ListVector StructVector)
(xtdb.types IntervalDayTime IntervalYearMonth)
(xtdb.vector IVectorWriter)
Expand All @@ -18,7 +19,7 @@


(defn- test-read [col-type-fn write-fn vs]
;; TODO no longer types, but there are other things in here that depend on `test-read`
;; TODO no longer types, but there are other things in here that depend on `test-read`
(with-open [duv (DenseUnionVector/empty "" tu/*allocator*)]
(let [duv-writer (vw/->writer duv)]
(doseq [v vs]
Expand All @@ -33,9 +34,9 @@
(test-read vw/value->col-type #(vw/write-value! %2 %1) vs))

(t/deftest round-trips-values
(t/is (= {:vs [false nil 2 1 6 4 3.14 2.0]
:vec-types [BitVector NullVector BigIntVector TinyIntVector SmallIntVector IntVector Float8Vector Float4Vector]}
(test-round-trip [false nil (long 2) (byte 1) (short 6) (int 4) (double 3.14) (float 2)]))
(t/is (= {:vs [false nil 2 1 6 4 3.14 2.0 BigDecimal/ONE]
:vec-types [BitVector NullVector BigIntVector TinyIntVector SmallIntVector IntVector Float8Vector Float4Vector DecimalVector]}
(test-round-trip [false nil (long 2) (byte 1) (short 6) (int 4) (double 3.14) (float 2) BigDecimal/ONE]))
"primitives")

(t/is (= {:vs ["Hello"
Expand Down Expand Up @@ -79,6 +80,13 @@
(test-round-trip vs))
"extension types")))

(t/deftest decimal-vector-test
(let [vs [BigDecimal/ONE 123.45M 12.3M]]
(->> "BigDecimal can be round tripped"
(t/is (= {:vs vs
:vec-types [DecimalVector DecimalVector DecimalVector]}
(test-round-trip vs))))))

(t/deftest date-vector-test
(let [vs [(LocalDate/of 2007 12 11)]]
(->> "LocalDate can be round tripped through DAY date vectors"
Expand Down

0 comments on commit 44e7e3a

Please sign in to comment.