Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IR playground #5644

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
142 changes: 142 additions & 0 deletions onnx/ir/_attributes.py
@@ -0,0 +1,142 @@
from __future__ import annotations

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.

Check warning

Code scanning / lintrunner

BLACK-ISORT/format Warning

Run lintrunner -a to apply this patch.

import enum
from typing import (
Any,
Generic,
Protocol,
Sequence,
TypeVar,
)



S = TypeVar("S")
T = TypeVar("T")
AttrT = TypeVar("AttrT", bound="Attribute")
AttrIterableT = TypeVar("AttrIterableT", bound="_AttrIterable")

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name \_AttrIterable.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "_AttrIterable" is not defined To disable, use # type: ignore[name-defined]


class AttributeType(enum.IntEnum):
"""Enum for the types of ONNX attributes."""

# NOTE: Naming follows python conventions.
# C++ names can follow C++ conventions and rename when binding.

# TODO: Should we code gen this? We just need to get rid of protoc
# We can code gen with https://github.com/recap-build/proto-schema-parser/tree/main

# NOTE: We can assume the build tool chain has python, just not protoc, right?

UNDEFINED = 0
FLOAT = 1
INT = 2
STRING = 3
TENSOR = 4
GRAPH = 5
FLOATS = 6
INTS = 7
STRINGS = 8
TENSORS = 9
GRAPHS = 10
SPARSE_TENSOR = 11
SPARSE_TENSORS = 12
TYPE_PROTO = 13
TYPE_PROTOS = 14


# NOTE: None of these classes will have a "to_onnx" method because
# We cannot assume that the build tool chain has protoc installed.

class AttributeProtocol(Protocol):
"""Protocol for ONNX attributes."""

name: str
type: AttributeType
value: Any
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not work for C++, but it would be nice in Python

ref_attr_name: str
doc_string: str


class Attribute(Generic[T], AttributeProtocol):
"""Base class for ONNX attributes."""

# NOTE: We use primitive types for T

def __init__(self,
name: str,
type: AttributeType,
value: T,
*,
ref_attr_name: str,
doc_string: str
):
self.name = name
self.type = type
self.value = value
self.ref_attr_name = ref_attr_name
self.doc_string = doc_string

# TODO: How do we represent ref attributes? Do we need to?


# NOTE: The following classes are just supporting classes (partially applied) for convenience
# But I think they would be useful to have in the IR by having the type info
# explicitly in the class type.
# Arguably, they can also be functions that return Attribute objects.
# TODO: We need something that is easy for pattern matchers, otherwise they
# can just be make_attribute functions.
class AttrFloat32(Attribute[float]):
def __init__(self, name: str, value: float, ref_attr_name: str = "", doc_string: str = ""):
super().__init__(name, AttributeType.FLOAT, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrInt64(Attribute[int]):
def __init__(self, name: str, value: int, ref_attr_name: str = "", doc_string: str = ""):
super().__init__(name, AttributeType.INT, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrString(Attribute[str]):
def __init__(self, name: str, value: str, ref_attr_name: str = "", doc_string: str = ""):
super().__init__(name, AttributeType.STRING, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

# NOTE: Tensor should be a tensor proto
class AttrTensor(Attribute[Tensor]):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Tensor" is not defined To disable, use # type: ignore[name-defined]
def __init__(self, name: str, value: Tensor, ref_attr_name: str = "", doc_string: str = ""):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Tensor" is not defined To disable, use # type: ignore[name-defined]
super().__init__(name, AttributeType.TENSOR, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrGraph(Attribute[Graph]):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Graph.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Graph" is not defined To disable, use # type: ignore[name-defined]
def __init__(self, name: str, value: Graph, ref_attr_name: str = "", doc_string: str = ""):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Graph.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Graph" is not defined To disable, use # type: ignore[name-defined]
super().__init__(name, AttributeType.GRAPH, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrFloat32s(Attribute[Sequence[float]]):
def __init__(self, name: str, value: Sequence[float], ref_attr_name: str = "", doc_string: str = ""):
super().__init__(name, AttributeType.FLOATS, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrInt64s(Attribute[Sequence[int]]):
def __init__(self, name: str, value: Sequence[int], ref_attr_name: str = "", doc_string: str = ""):
super().__init__(name, AttributeType.INTS, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrStrings(Attribute[Sequence[str]]):
def __init__(self, name: str, value: Sequence[str], ref_attr_name: str = "", doc_string: str = ""):
super().__init__(name, AttributeType.STRINGS, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrTensors(Attribute[Sequence[Tensor]]):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Tensor" is not defined To disable, use # type: ignore[name-defined]
def __init__(self, name: str, value: Sequence[Tensor], ref_attr_name: str = "", doc_string: str = ""):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Tensor" is not defined To disable, use # type: ignore[name-defined]
super().__init__(name, AttributeType.TENSORS, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrGraphs(Attribute[Sequence[Graph]]):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Graph.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Graph" is not defined To disable, use # type: ignore[name-defined]
def __init__(self, name: str, value: Sequence[Graph], ref_attr_name: str = "", doc_string: str = ""):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name Graph.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "Graph" is not defined To disable, use # type: ignore[name-defined]
super().__init__(name, AttributeType.GRAPHS, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

# NOTE: SparseTensor should be a sparse tensor proto
class AttrSparseTensor(Attribute[Sequence[SparseTensor]]):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "SparseTensor" is not defined To disable, use # type: ignore[name-defined]
def __init__(self, name: str, value: Sequence[SparseTensor], ref_attr_name: str = "", doc_string: str = ""):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "SparseTensor" is not defined To disable, use # type: ignore[name-defined]
super().__init__(name, AttributeType.SPARSE_TENSOR, value, ref_attr_name=ref_attr_name, doc_string=doc_string)

class AttrSparseTensors(Attribute[Sequence[SparseTensor]]):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "SparseTensor" is not defined To disable, use # type: ignore[name-defined]
def __init__(self, name: str, value: Sequence[SparseTensor], ref_attr_name: str = "", doc_string: str = ""):

Check failure

Code scanning / lintrunner

RUFF/F821 Error

Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/

Check failure

Code scanning / lintrunner

MYPY/name-defined Error

Name "SparseTensor" is not defined To disable, use # type: ignore[name-defined]
super().__init__(name, AttributeType.SPARSE_TENSORS, value, ref_attr_name=ref_attr_name, doc_string=doc_string)


# class AttrTypeProto(Attribute[Sequence[TypeProto]]):
# def __init__(self, name: str, value: Sequence[TypeProto], ref_attr_name: str = "", doc_string: str = ""):
# super().__init__(name, AttributeType.TYPE_PROTO, value, ref_attr_name=ref_attr_name, doc_string=doc_string)
Fixed Show fixed Hide fixed