Skip to content

Commit

Permalink
Initial proposal for ParseDateTime
Browse files Browse the repository at this point in the history
See #5309 for details.

Signed-off-by: Christian Bourjau <christian.bourjau@quantco.com>
  • Loading branch information
cbourjau committed Jul 14, 2023
1 parent e9474bc commit fc034e6
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 4 deletions.
58 changes: 58 additions & 0 deletions onnx/backend/test/case/node/parsedatetime.py
@@ -0,0 +1,58 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from datetime import datetime

import numpy as np

import onnx
from onnx import numpy_helper
from onnx.backend.sample.ops.abs import abs

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'abs' is not used.
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


class ParseDateTime(Base):
    """Backend test cases for the ``ParseDateTime`` operator.

    Expected values are computed relative to the Unix epoch with naive
    ``datetime`` arithmetic, i.e. the parsed wall-clock time is interpreted as
    UTC. ``datetime.timestamp()`` is deliberately avoided: for naive datetimes
    it uses the host's local time zone, which would make the generated test
    data differ from machine to machine.
    """

    @staticmethod
    def export_float_nan_default() -> None:
        """Unparsable strings map to a NaN default; the output is float64."""
        fmt = "%d/%m/%y %H:%M"
        default = float("NaN")
        node = onnx.helper.make_node(
            "ParseDateTime",
            inputs=["x"],
            outputs=["y"],
            format=fmt,
            unit="s",
            default=numpy_helper.from_array(np.array(default, np.float64)),
        )
        x = np.array(["21/11/06 16:30", "foobar"], dtype=object)
        epoch = datetime(1970, 1, 1)
        y = []
        for s in x:
            try:
                # Seconds since the epoch as a float, independent of local time.
                y.append((datetime.strptime(s, fmt) - epoch).total_seconds())
            except ValueError:
                y.append(default)
        expect(
            node,
            inputs=[x],
            # Pin the dtype so it always matches the dtype of 'default'.
            outputs=[np.array(y, np.float64)],
            # Each case needs its own name; a shared name makes the cases collide.
            name="test_parsedatetime_float_nan_default",
        )

    @staticmethod
    def export_int_default() -> None:
        """Unparsable strings map to an int64 default; the output is int64."""
        fmt = "%d/%m/%y %H:%M"
        default = np.iinfo(np.int64).min
        node = onnx.helper.make_node(
            "ParseDateTime",
            inputs=["x"],
            outputs=["y"],
            format=fmt,
            unit="s",
            default=numpy_helper.from_array(np.array(default, np.int64)),
        )
        x = np.array(["21/11/06 16:30", "foobar"], dtype=object)
        epoch = datetime(1970, 1, 1)
        y = []
        for s in x:
            try:
                # Convert to int explicitly so the list holds integers only.
                y.append(int((datetime.strptime(s, fmt) - epoch).total_seconds()))
            except ValueError:
                y.append(default)
        expect(
            node,
            inputs=[x],
            outputs=[np.array(y, np.int64)],
            name="test_parsedatetime_int_default",
        )
Binary file not shown.
@@ -0,0 +1 @@
221/11/06 16:302foobarBx
Binary file not shown.
10 changes: 6 additions & 4 deletions onnx/defs/operator_sets.h
Expand Up @@ -1102,18 +1102,20 @@ class OpSet_Onnx_ver19 {
};

// Forward declarations for ai.onnx version 20
// Kept in alphabetical order, one declaration per operator.
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ParseDateTime);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, StringConcat);

// Iterate over schema from ai.onnx version 20
class OpSet_Onnx_ver20 {
 public:
  // Calls `fn` once with the schema of each operator listed for ai.onnx
  // version 20. Every operator must appear exactly once: a repeated entry
  // would hand the same schema to `fn` a second time.
  static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ParseDateTime)>());
    fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, StringConcat)>());
  }
};
Expand Down
39 changes: 39 additions & 0 deletions onnx/defs/text/defs.cc
Expand Up @@ -5,6 +5,45 @@
#include "onnx/defs/schema.h"

namespace ONNX_NAMESPACE {
static const char* ParseDateTime_doc = R"DOC(Parse a datetime string into a (floating point) Unix time stamp.)DOC";
// Schema of ParseDateTime (opset 20): element-wise conversion of datetime
// strings into Unix time stamps. The output element type is taken from the
// optional 'default' tensor attribute and falls back to int64 without it.
ONNX_OPERATOR_SET_SCHEMA(
    ParseDateTime,
    20,
    OpSchema()
        .Input(0, "X", "Tensor with datetime strings", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
        .Output(0, "y", "Unix time stamps", "T2", OpSchema::Single, true, 1, OpSchema::NonDifferentiable)
        .Attr("format", "Format description in the syntax of C's `strptime`.", AttributeProto::STRING)
        .Attr(
            "unit",
            "Unit of the returned time stamp. Allowed values are: 's' (second), 'ms' (millisecond), 'us' (microsecond) or 'ns' (nanosecond).",
            AttributeProto::STRING)
        .Attr(
            "default",
            "Default value to be used if the parsing fails. The tensor must be of rank 0 and either of type `tensor(int64)` or `tensor(double)`. The tensor type is the output type. If 'default' is not specified, the output type is `tensor(int64)` and the behavior for failing to parse an input element is implementation defined.",
            AttributeProto::TENSOR,
            OPTIONAL_VALUE)
        .TypeConstraint("T1", {"tensor(string)"}, "UTF-8 datetime strings")
        .TypeConstraint("T2", {"tensor(double)", "tensor(int64)"}, "Output type depends on 'default' attribute.")
        .SetDoc(ParseDateTime_doc)
        .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
          // The attribute that drives the output type is named 'default';
          // looking up any other name always yields nullptr and would force
          // the int64 fallback even when a double default is supplied.
          const auto* default_attr = ctx.getAttribute("default");

          // Set the element type before copying the shape so that the output
          // is already a tensor type when the shape is propagated.
          if (nullptr != default_attr) {
            // OpSchema::Verify check ensures that the attribute value has_t():
            const TensorProto& tensor_proto = default_attr->t();
            updateOutputElemType(ctx, 0, tensor_proto.data_type());
          } else {
            updateOutputElemType(ctx, 0, TensorProto::INT64);
          }

          // The operator is element-wise: the output has the input's shape.
          if (hasInputShape(ctx, 0)) {
            propagateShapeFromInputToOutput(ctx, 0, 0);
          }
        }));

static const char* StringConcat_doc =
R"DOC(StringConcat concatenates string tensors elementwise (with NumPy-style broadcasting support))DOC";
ONNX_OPERATOR_SET_SCHEMA(
Expand Down
1 change: 1 addition & 0 deletions onnx/reference/ops/_op_list.py
Expand Up @@ -148,6 +148,7 @@
from onnx.reference.ops.op_optional_has_element import OptionalHasElement
from onnx.reference.ops.op_or import Or
from onnx.reference.ops.op_pad import Pad_1, Pad_2, Pad_11, Pad_18
from onnx.reference.ops.op_parsedatetime import ParseDateTime

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'ParseDateTime' is not used.
from onnx.reference.ops.op_pow import Pow
from onnx.reference.ops.op_prelu import PRelu
from onnx.reference.ops.op_qlinear_conv import QLinearConv
Expand Down
17 changes: 17 additions & 0 deletions onnx/reference/ops/op_parsedatetime.py
@@ -0,0 +1,17 @@
# Copyright (c) ONNX Project Contributors

# SPDX-License-Identifier: Apache-2.0
# pylint: disable=W0221
from datetime import datetime

import numpy as np

from onnx.reference.ops._op import OpRunUnaryNum


class ParseDateTime(OpRunUnaryNum):
    """Reference implementation of the ``ParseDateTime`` operator.

    Every string of the input tensor is parsed with ``datetime.strptime`` and
    converted into a Unix time stamp expressed in the requested unit.

    NOTE(review): ``OpRunUnaryNum`` is presumably meant for numeric operators
    whose output dtype equals the input dtype; confirm that it is a suitable
    base class for a string -> number operator.
    """

    # Number of output ticks per second for each supported ``unit`` value.
    _TICKS_PER_SECOND = {"s": 1, "ms": 10**3, "us": 10**6, "ns": 10**9}

    @staticmethod
    def _epoch_microseconds(parsed: datetime) -> int:
        """Return the exact number of microseconds between the epoch and *parsed*.

        Naive values are interpreted as UTC so that the result does not depend
        on the time zone of the host; aware values (e.g. parsed with ``%z``)
        are shifted by their UTC offset first.
        """
        offset = parsed.utcoffset()
        wall = parsed.replace(tzinfo=None)
        if offset is not None:
            wall -= offset
        delta = wall - datetime(1970, 1, 1)
        return (delta.days * 86400 + delta.seconds) * 10**6 + delta.microseconds

    def _run(self, x, format: str, unit: str, default=None):  # type: ignore
        """Parse the datetime strings in *x* into Unix time stamps.

        Args:
            x: Array of datetime strings (any shape).
            format: ``strptime`` format description.
            unit: One of ``'s'``, ``'ms'``, ``'us'`` or ``'ns'``.
            default: Optional single-element array used for elements that
                cannot be parsed; its dtype is the output dtype. Without it
                the output is int64 and parse failures are propagated.

        Returns:
            A one-element tuple holding an array with the shape of *x*.

        Raises:
            ValueError: If *unit* is not supported, or if an element cannot be
                parsed and no *default* was given.
        """
        if unit not in self._TICKS_PER_SECOND:
            raise ValueError(
                f"Unsupported unit {unit!r}; expected one of "
                f"{sorted(self._TICKS_PER_SECOND)}."
            )
        ticks = self._TICKS_PER_SECOND[unit]

        if default is None:
            fill = None
            out_dtype = np.dtype(np.int64)
        else:
            default = np.asarray(default)
            fill = default.item()
            out_dtype = default.dtype
        integral = np.issubdtype(out_dtype, np.integer)

        strings = np.asarray(x)
        out = np.empty(strings.shape, dtype=out_dtype)
        for index in np.ndindex(strings.shape):
            try:
                parsed = datetime.strptime(strings[index], format)
            except ValueError:
                # strptime signals a parse failure by raising, never by
                # returning NaN, so the default has to be applied here.
                if fill is None:
                    raise
                out[index] = fill
                continue
            scaled = self._epoch_microseconds(parsed) * ticks
            if integral:
                # Exact integer arithmetic, truncating toward zero like a
                # float -> int cast; going through float64 would lose
                # precision for 'ns' time stamps.
                quotient = abs(scaled) // 10**6
                out[index] = quotient if scaled >= 0 else -quotient
            else:
                out[index] = scaled / 10**6
        return (out,)

0 comments on commit fc034e6

Please sign in to comment.