-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
IR playground #5644
base: main
Are you sure you want to change the base?
IR playground #5644
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
from __future__ import annotations | ||
Check warning Code scanning / lintrunner BLACK-ISORT/format Warning
Run lintrunner -a to apply this patch.
|
||
|
||
import enum | ||
from typing import ( | ||
Any, | ||
Generic, | ||
Protocol, | ||
Sequence, | ||
TypeVar, | ||
) | ||
|
||
|
||
|
||
S = TypeVar("S") | ||
T = TypeVar("T") | ||
AttrT = TypeVar("AttrT", bound="Attribute") | ||
AttrIterableT = TypeVar("AttrIterableT", bound="_AttrIterable") | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name \_AttrIterable.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "_AttrIterable" is not defined
To disable, use # type: ignore[name-defined]
|
||
|
||
|
||
class AttributeType(enum.IntEnum): | ||
"""Enum for the types of ONNX attributes.""" | ||
|
||
# NOTE: Naming follows python conventions. | ||
# C++ names can follow C++ conventions and rename when binding. | ||
|
||
# TODO: Should we code gen this? We just need to get rid of protoc | ||
# We can code gen with https://github.com/recap-build/proto-schema-parser/tree/main | ||
|
||
# NOTE: We can assume the build tool chain has python, just not protoc, right? | ||
|
||
UNDEFINED = 0 | ||
FLOAT = 1 | ||
INT = 2 | ||
STRING = 3 | ||
TENSOR = 4 | ||
GRAPH = 5 | ||
FLOATS = 6 | ||
INTS = 7 | ||
STRINGS = 8 | ||
TENSORS = 9 | ||
GRAPHS = 10 | ||
SPARSE_TENSOR = 11 | ||
SPARSE_TENSORS = 12 | ||
TYPE_PROTO = 13 | ||
TYPE_PROTOS = 14 | ||
|
||
|
||
# NOTE: None of these classes will have a "to_onnx" method because | ||
# We cannot assume that the build tool chain has protoc installed. | ||
|
||
class AttributeProtocol(Protocol): | ||
"""Protocol for ONNX attributes.""" | ||
|
||
name: str | ||
type: AttributeType | ||
value: Any | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will not work for C++, but it would be nice in Python |
||
ref_attr_name: str | ||
doc_string: str | ||
|
||
|
||
class Attribute(Generic[T], AttributeProtocol): | ||
"""Base class for ONNX attributes.""" | ||
|
||
# NOTE: We use primitive types for T | ||
|
||
def __init__(self, | ||
name: str, | ||
type: AttributeType, | ||
value: T, | ||
*, | ||
ref_attr_name: str, | ||
doc_string: str | ||
): | ||
self.name = name | ||
self.type = type | ||
self.value = value | ||
self.ref_attr_name = ref_attr_name | ||
self.doc_string = doc_string | ||
|
||
# TODO: How do we represent ref attributes? Do we need to? | ||
|
||
|
||
# NOTE: The following classes are just supporting classes (partially applied) for convenience | ||
# But I think they would be useful to have in the IR by having the type info | ||
# explicitly in the class type. | ||
# Arguably, they can also be functions that return Attribute objects. | ||
# TODO: We need something that is easy for pattern matchers, otherwise they | ||
# can just be make_attribute functions. | ||
class AttrFloat32(Attribute[float]): | ||
def __init__(self, name: str, value: float, ref_attr_name: str = "", doc_string: str = ""): | ||
super().__init__(name, AttributeType.FLOAT, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrInt64(Attribute[int]): | ||
def __init__(self, name: str, value: int, ref_attr_name: str = "", doc_string: str = ""): | ||
super().__init__(name, AttributeType.INT, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrString(Attribute[str]): | ||
def __init__(self, name: str, value: str, ref_attr_name: str = "", doc_string: str = ""): | ||
super().__init__(name, AttributeType.STRING, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
# NOTE: Tensor should be a tensor proto | ||
class AttrTensor(Attribute[Tensor]): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Tensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
def __init__(self, name: str, value: Tensor, ref_attr_name: str = "", doc_string: str = ""): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Tensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
super().__init__(name, AttributeType.TENSOR, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrGraph(Attribute[Graph]): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Graph.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Graph" is not defined
To disable, use # type: ignore[name-defined]
|
||
def __init__(self, name: str, value: Graph, ref_attr_name: str = "", doc_string: str = ""): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Graph.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Graph" is not defined
To disable, use # type: ignore[name-defined]
|
||
super().__init__(name, AttributeType.GRAPH, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrFloat32s(Attribute[Sequence[float]]): | ||
def __init__(self, name: str, value: Sequence[float], ref_attr_name: str = "", doc_string: str = ""): | ||
super().__init__(name, AttributeType.FLOATS, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrInt64s(Attribute[Sequence[int]]): | ||
def __init__(self, name: str, value: Sequence[int], ref_attr_name: str = "", doc_string: str = ""): | ||
super().__init__(name, AttributeType.INTS, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrStrings(Attribute[Sequence[str]]): | ||
def __init__(self, name: str, value: Sequence[str], ref_attr_name: str = "", doc_string: str = ""): | ||
super().__init__(name, AttributeType.STRINGS, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrTensors(Attribute[Sequence[Tensor]]): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Tensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
def __init__(self, name: str, value: Sequence[Tensor], ref_attr_name: str = "", doc_string: str = ""): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Tensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Tensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
super().__init__(name, AttributeType.TENSORS, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrGraphs(Attribute[Sequence[Graph]]): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Graph.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Graph" is not defined
To disable, use # type: ignore[name-defined]
|
||
def __init__(self, name: str, value: Sequence[Graph], ref_attr_name: str = "", doc_string: str = ""): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name Graph.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "Graph" is not defined
To disable, use # type: ignore[name-defined]
|
||
super().__init__(name, AttributeType.GRAPHS, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
# NOTE: SparseTensor should be a sparse tensor proto | ||
class AttrSparseTensor(Attribute[Sequence[SparseTensor]]): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "SparseTensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
def __init__(self, name: str, value: Sequence[SparseTensor], ref_attr_name: str = "", doc_string: str = ""): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "SparseTensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
super().__init__(name, AttributeType.SPARSE_TENSOR, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
class AttrSparseTensors(Attribute[Sequence[SparseTensor]]): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "SparseTensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
def __init__(self, name: str, value: Sequence[SparseTensor], ref_attr_name: str = "", doc_string: str = ""): | ||
Check failure Code scanning / lintrunner RUFF/F821 Error
Undefined name SparseTensor.
See https://beta.ruff.rs/docs/rules/ Check failure Code scanning / lintrunner MYPY/name-defined Error
Name "SparseTensor" is not defined
To disable, use # type: ignore[name-defined]
|
||
super().__init__(name, AttributeType.SPARSE_TENSORS, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
||
|
||
# class AttrTypeProto(Attribute[Sequence[TypeProto]]): | ||
# def __init__(self, name: str, value: Sequence[TypeProto], ref_attr_name: str = "", doc_string: str = ""): | ||
# super().__init__(name, AttributeType.TYPE_PROTO, value, ref_attr_name=ref_attr_name, doc_string=doc_string) | ||
|
Check warning
Code scanning / lintrunner
BLACK-ISORT/format Warning