From 787e2cc0067921c104cc1ab9c58981aa6bf259e3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 3 Aug 2021 15:02:46 -0600 Subject: [PATCH 1/3] BUG: Fix NaT handling in the PyArray_CompareFunc for datetime and timedelta In #12658 and #15068 the sort ordering for datetime and timedelta was changed so that NaT sorts to the end, but the internal compare array function was not updated. Fixes #19574. --- numpy/core/src/multiarray/arraytypes.c.src | 43 ++++++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index ad74612272b2..20fe2ce0d37d 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -2805,11 +2805,9 @@ BOOL_compare(npy_bool *ip1, npy_bool *ip2, PyArrayObject *NPY_UNUSED(ap)) /**begin repeat * #TYPE = BYTE, UBYTE, SHORT, USHORT, INT, UINT, - * LONG, ULONG, LONGLONG, ULONGLONG, - * DATETIME, TIMEDELTA# + * LONG, ULONG, LONGLONG, ULONGLONG# * #type = npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint, - * npy_long, npy_ulong, npy_longlong, npy_ulonglong, - * npy_datetime, npy_timedelta# + * npy_long, npy_ulong, npy_longlong, npy_ulonglong# */ static int @@ -2920,6 +2918,43 @@ C@TYPE@_compare(@type@ *pa, @type@ *pb) /**end repeat**/ +/**begin repeat + * #TYPE = DATETIME, TIMEDELTA# + * #type = npy_datetime, npy_timedelta# + */ + +#define LT(a,b) ((a) < (b) || ((b) != (b) && (a) ==(a))) + +static int +@TYPE@_compare(@type@ *pa, @type@ *pb) +{ + const @type@ a = *pa; + const @type@ b = *pb; + int ret; + + printf("Calling @TYPE@_compare\n"); + if (a == NPY_DATETIME_NAT) { + ret = 1; + } + else if (b == NPY_DATETIME_NAT) { + ret = -1; + } + else if (LT(a,b)) { + ret = -1; + } + else if (LT(b,a)) { + ret = 1; + } + else { + ret = 0; + } + return ret; +} + +#undef LT + +/**end repeat**/ + static int HALF_compare (npy_half *pa, npy_half *pb, PyArrayObject *NPY_UNUSED(ap)) { From 9416e593bf8029bbdc05c55c27c84dfd39effb2d Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 3 Aug 2021 16:19:23 -0600 Subject: [PATCH 2/3] Remove debug print --- numpy/core/src/multiarray/arraytypes.c.src | 1 - 1 file changed, 1 deletion(-) diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 20fe2ce0d37d..58dcb7bc023a 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -2932,7 +2932,6 @@ static int const @type@ b = *pb; int ret; - printf("Calling @TYPE@_compare\n"); if (a == NPY_DATETIME_NAT) { ret = 1; } From 73e7a1293745cbb8054fd019c1e03ade9e27638b Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 3 Aug 2021 16:23:38 -0600 Subject: [PATCH 3/3] Fix the implementation of DATETIME_compare I accidentally based it off the float compare template instead of the integer compare template. It also now properly handles the case when both arguments are NaT. --- numpy/core/src/multiarray/arraytypes.c.src | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 58dcb7bc023a..b3ea7544d974 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -2923,8 +2923,6 @@ C@TYPE@_compare(@type@ *pa, @type@ *pb) * #type = npy_datetime, npy_timedelta# */ -#define LT(a,b) ((a) < (b) || ((b) != (b) && (a) ==(a))) - static int @TYPE@_compare(@type@ *pa, @type@ *pb) { @@ -2933,25 +2931,22 @@ static int int ret; if (a == NPY_DATETIME_NAT) { - ret = 1; + if (b == NPY_DATETIME_NAT) { + ret = 0; + } + else { + ret = 1; + } } else if (b == NPY_DATETIME_NAT) { ret = -1; } - else if (LT(a,b)) { - ret = -1; - } - else if (LT(b,a)) { - ret = 1; - } else { - ret = 0; + ret = a < b ? -1 : a == b ? 0 : 1; } return ret; } -#undef LT - /**end repeat**/ static int