Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR playground #5644

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
248 changes: 248 additions & 0 deletions onnx/ir/_attributes.py
@@ -0,0 +1,248 @@
from __future__ import annotations

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.

from onnx.ir import _enums, _protocols
from typing import Sequence


# NOTE: None of these classes will have a "to_onnx" method because
# We cannot assume that the build tool chain has protoc installed.


class Attribute(_protocols.AttributeProtocol):
"""Base class for ONNX attributes."""

# NOTE: We use primitive types for T

def __init__(
self,
name: str,
typ: _enums.AttributeType,
value,
*,
ref_attr_name: str,
doc_string: str,
):
self.name = name
self.type = typ
self.value = value
self.ref_attr_name = ref_attr_name
self.doc_string = doc_string

# TODO: How do we represent ref attributes? Do we need to?


# NOTE: The following classes are just supporting classes (partially applied) for convenience
# But I think they would be useful to have in the IR by having the type info
# explicitly in the class type.
# Arguably, they can also be functions that return Attribute objects.
# TODO: We need something that is easy for pattern matchers, otherwise they
# can just be make_attribute functions.
class AttrFloat32(Attribute):
def __init__(
self, name: str, value: float, ref_attr_name: str = "", doc_string: str = ""
):
super().__init__(
name,
_enums.AttributeType.FLOAT,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrInt64(Attribute):
def __init__(
self, name: str, value: int, ref_attr_name: str = "", doc_string: str = ""
):
super().__init__(
name,
_enums.AttributeType.INT,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrString(Attribute):
def __init__(
self, name: str, value: str, ref_attr_name: str = "", doc_string: str = ""
):
super().__init__(
name,
_enums.AttributeType.STRING,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


# NOTE: Tensor should be a tensor proto
class AttrTensor(Attribute):
def __init__(
self,
name: str,
value: _protocols.TensorProtocol,
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.TENSOR,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrGraph(Attribute):
def __init__(
self,
name: str,
value: _protocols.GraphProtocol,

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_protocols.GraphProtocol" is not defined To disable, use # type: ignore[name-defined]
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.GRAPH,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrFloat32s(Attribute):
def __init__(
self,
name: str,
value: Sequence[float],
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.FLOATS,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrInt64s(Attribute):
def __init__(
self,
name: str,
value: Sequence[int],
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.INTS,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrStrings(Attribute):
def __init__(
self,
name: str,
value: Sequence[str],
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.STRINGS,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrTensors(Attribute):
def __init__(
self,
name: str,
value: Sequence[_protocols.TensorProtocol],
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.TENSORS,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrGraphs(Attribute):
def __init__(
self,
name: str,
value: Sequence[_protocols.GraphProtocol],

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_protocols.GraphProtocol" is not defined To disable, use # type: ignore[name-defined]
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.GRAPHS,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


# NOTE: SparseTensor should be a sparse tensor proto
class AttrSparseTensor(Attribute):
def __init__(
self,
name: str,
value: Sequence[_protocols.SparseTensorProtocol],
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.SPARSE_TENSOR,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrSparseTensors(Attribute):
def __init__(
self,
name: str,
value: Sequence[_protocols.SparseTensorProtocol],
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.SPARSE_TENSORS,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)


class AttrTypeProto(Attribute):
def __init__(
self,
name: str,
value: _protocols.TypeProtocol,
ref_attr_name: str = "",
doc_string: str = "",
):
super().__init__(
name,
_enums.AttributeType.TYPE_PROTO,
value,
ref_attr_name=ref_attr_name,
doc_string=doc_string,
)
57 changes: 57 additions & 0 deletions onnx/ir/_enums.py
@@ -0,0 +1,57 @@
import enum

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.

class AttributeType(enum.IntEnum):
"""Enum for the types of ONNX attributes."""

# NOTE: Naming follows python conventions.
# C++ names can follow C++ conventions and rename when binding.

# TODO: Should we code gen this? We just need to get rid of protoc
# We can code gen with https://github.com/recap-build/proto-schema-parser/tree/main

# NOTE: We can assume the build tool chain has python, just not protoc, right?
# NOTE: We should alias OpSchema::AttrType as well
UNDEFINED = 0
FLOAT = 1
INT = 2
STRING = 3
TENSOR = 4
GRAPH = 5
FLOATS = 6
INTS = 7
STRINGS = 8
TENSORS = 9
GRAPHS = 10
SPARSE_TENSOR = 11
SPARSE_TENSORS = 12
TYPE_PROTO = 13
TYPE_PROTOS = 14


class DataType(enum.IntEnum):
UNDEFINED = 0
FLOAT = 1
UINT8 = 2
INT8 = 3
UINT16 = 4
INT16 = 5
INT32 = 6
INT64 = 7
STRING = 8
BOOL = 9
FLOAT16 = 10
DOUBLE = 11
UINT32 = 12
UINT64 = 13
COMPLEX64 = 14
COMPLEX128 = 15
BFLOAT16 = 16
FLOAT8E4M3FN = 17
FLOAT8E4M3FNUZ = 18
FLOAT8E5M2 = 19
FLOAT8E5M2FNUZ = 20


class TensorDataLocation(enum.IntEnum):
DEFAULT = 0
EXTERNAL = 1