Skip to content

Commit

Permalink
[ Dy2Static ] Unify tensor.size and Variable.size() by jst (#45144)
Browse files Browse the repository at this point in the history
* unify the size and size() by jst

* fix bugs

* bug fix.

* fix bugs

* change all_close -> np.testing.assert_allclose
  • Loading branch information
2742195759 committed Aug 29, 2022
1 parent a237ff8 commit 163cd15
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
from paddle.fluid.dygraph.dygraph_to_static.base_transformer import BaseTransformer
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code


class BasicApiTransformer(BaseTransformer):
Expand All @@ -37,6 +38,8 @@ def __init__(self, wrapper_root):
def transform(self):
to_tensor_transformer = ToTensorTransformer(self.root)
to_tensor_transformer.transform()
attribute_transformer = AttributeJstTransformer(self.root)
attribute_transformer.transform()
self.visit(self.root)

return self.wrapper_root
Expand Down Expand Up @@ -122,6 +125,42 @@ def visit_Call(self, node):
return node


class AttributeJstTransformer(BaseTransformer):
"""
change some special attribute into __jst.XXX(obj, "attr_name") format.
for example:
a.size --> __jst.attr(a, "size")
because `size` have different behavier when in dygraph / static mode
NOTE: we only deal with ctx=Load() case.
"""

def __init__(self, node):
assert isinstance(
node, gast.AST
), "Input non-gast.AST node for the initialization of ToTensorTransformer."
self.interested_name = set([
'size',
])
self.root = node

def transform(self):
self.visit(self.root)
return self.root

def visit_Attribute(self, node):
assert isinstance(node, gast.Attribute)
assert isinstance(node.attr, str)
if isinstance(node.ctx,
gast.Load) and node.attr in self.interested_name:
attr = node.attr
value = node.value
node = gast.parse("_jst.Attr({}, \"{}\")".format(
ast_to_source_code(value).strip(), attr)).body[0].value
self.generic_visit(node)
return node


def is_to_variable(node):
assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
from paddle.fluid.layers.utils import copy_mutable_vars


def convert_attr(x, attr):
if isinstance(x, Variable) and attr == "size":
return x.size()
else:
return getattr(value, attr)


def indexable(x, code=None):
if isinstance(x, Variable): return x
if hasattr(x, '__len__') and hasattr(x, '__getitem__'): return x
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,29 @@ def test_to_static_numpy_report_error(self):
static_res = self._run(to_static=True)


@paddle.jit.to_static
def tensor_size(x):
x = paddle.to_tensor(x)
x = paddle.reshape(x, paddle.shape(x)) # dynamic shape
y = x.size
return y


class TestTensorSize(unittest.TestCase):

def _run(self, to_static):
prog_trans = paddle.jit.ProgramTranslator()
prog_trans.enable(to_static)
x = paddle.ones([1, 2, 3])
if to_static == False:
return tensor_size(x)
return tensor_size(x).numpy()

def test_tensor_clone(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-5)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions python/paddle/jit/dy2static/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .convert_operators import convert_shape as Shape # noqa: F401
from .convert_operators import convert_while_loop as While # noqa: F401
from .convert_operators import unpack_by_structure as Unpack # noqa: F401
from .convert_operators import convert_attr as Attr # noqa: F401
from .convert_operators import indexable as Indexable # noqa: F401
from .variable_trans_func import create_bool_as_type # noqa: F401
from .variable_trans_func import to_static_variable # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_var_dtype # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_shape # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_while_loop # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable # noqa: F401
from ...fluid.dygraph.dygraph_to_static.convert_operators import unpack_by_structure, indexable, convert_attr # noqa: F401

__all__ = []

0 comments on commit 163cd15

Please sign in to comment.