Skip to content

Commit

Permalink
[Dy2Stat]Replace paddle.jit.dy2stat with _jst (#42947)
Browse files Browse the repository at this point in the history
* [Dy2Stat]Replace paddle.jit.dy2stat with _jst

* [Dy2Stat]Replace paddle.jit.dy2stat with _jst

* refine code style

* refine code style
  • Loading branch information
Aurelius84 committed May 27, 2022
1 parent a76f2b3 commit 2d87300
Show file tree
Hide file tree
Showing 16 changed files with 62 additions and 66 deletions.
Expand Up @@ -37,7 +37,7 @@ def transform(self):

def visit_Assert(self, node):
convert_assert_node = gast.parse(
'paddle.jit.dy2static.convert_assert({test}, {msg})'.format(
'_jst.convert_assert({test}, {msg})'.format(
test=ast_to_source_code(node.test),
msg=ast_to_source_code(node.msg)
if node.msg else "")).body[0].value
Expand Down
Expand Up @@ -71,7 +71,7 @@ def visit_Call(self, node):
if PDB_SET in func_str:
return node

new_func_str = "paddle.jit.dy2static.convert_call({})".format(func_str)
new_func_str = "_jst.convert_call({})".format(func_str)
new_func_ast = gast.parse(new_func_str).body[0].value
node.func = new_func_ast

Expand Down
Expand Up @@ -39,8 +39,8 @@ def visit_Call(self, node):
func_str = ast_to_source_code(node.func).strip()
if func_str in self._castable_type and len(node.args) > 0:
args_str = ast_to_source_code(node.args[0]).strip()
new_func_str = "paddle.jit.dy2static.convert_var_dtype({}, '{}')".format(
args_str, func_str)
new_func_str = "_jst.convert_var_dtype({}, '{}')".format(args_str,
func_str)
new_node = gast.parse(new_func_str).body[0].value
return new_node

Expand Down
Expand Up @@ -536,7 +536,7 @@ def create_name_nodes(name_ids):
return_vars = create_name_nodes(return_name_ids)

convert_ifelse_layer = gast.parse(
'paddle.jit.dy2static.convert_ifelse('
'_jst.convert_ifelse('
'{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'.
format(
pred=ast_to_source_code(pred),
Expand Down
Expand Up @@ -129,7 +129,7 @@ def _transform_slice_to_tensor_write(self, node):
elif slice_is_num(target_node):
value_code = ast_to_source_code(node.value)
i = "paddle.cast(" \
"x=paddle.jit.dy2static.to_static_variable({})," \
"x=_jst.to_static_variable({})," \
"dtype='int64')".format(ast_to_source_code(slice_node))
assign_code = "{} = paddle.tensor.array_write(x={}, i={}, array={})" \
.format(target_name, value_code, i, target_name)
Expand Down Expand Up @@ -252,7 +252,7 @@ def _replace_pop(self, node):
# 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2
if len(args_str) <= 2:
new_pop_str = "paddle.jit.dy2static.convert_pop({}, {})"\
new_pop_str = "_jst.convert_pop({}, {})"\
.format(target_str, ",".join(args_str))
new_pop_node = gast.parse(new_pop_str).body[0].value
return new_pop_node
Expand Down
Expand Up @@ -57,8 +57,7 @@ def visit_UnaryOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.Not):
arg = ast_to_source_code(node.operand)
new_node_str = "paddle.jit.dy2static.convert_logical_not({})".format(
arg)
new_node_str = "_jst.convert_logical_not({})".format(arg)
# NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node = gast.parse(new_node_str).body[0].value
return new_node
Expand All @@ -67,21 +66,20 @@ def visit_UnaryOp(self, node):
def visit_Compare(self, node):
self.generic_visit(node)
left_str = ast_to_source_code(node.left).strip()
if left_str.startswith("paddle.jit.dy2static.convert_var_shape"):
if left_str.startswith("_jst.convert_var_shape"):
# check left and comparators are all converted var shape
compare_arg_strs = left_str
for i, comparator in enumerate(node.comparators):
comparator_str = ast_to_source_code(comparator).strip()
if not comparator_str.startswith(
"paddle.jit.dy2static.convert_var_shape"):
if not comparator_str.startswith("_jst.convert_var_shape"):
return node
op_str = cmpop_node_to_str(node.ops[i])
compare_arg_strs += (", '" + op_str + "', " + comparator_str)

# Now all left and comparators are converted shape
            # Replace some comparison operations because of differences between
            # Python and Paddle
new_node_str = "paddle.jit.dy2static.convert_shape_compare({})".format(
new_node_str = "_jst.convert_shape_compare({})".format(
compare_arg_strs)
new_node = gast.parse(new_node_str).body[0].value
return new_node
Expand Down Expand Up @@ -119,7 +117,7 @@ def _create_bool_op_node(self, nodes, api_type):
nodes = [pre_logic_node] + [post_logic_node]

args = [ast_to_source_code(child) for child in nodes]
new_node_str = "paddle.jit.dy2static.convert_logical_{}(lambda:{}, lambda:{})".format(
new_node_str = "_jst.convert_logical_{}(lambda:{}, lambda:{})".format(
api_type, args[0], args[1])
        # NOTE: gast.parse returns Module(body=[expr(...)])
new_node = gast.parse(new_node_str).body[0].value
Expand Down
Expand Up @@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
else:
assign_loop_var_names.append(name)

while_func_name = "paddle.jit.dy2static.convert_while_loop"
while_func_name = "_jst.convert_while_loop"
while_node_str = "[{}] = {}({}, {}, [{}])".format(
",".join(assign_loop_var_names), while_func_name, condition_name,
body_name, ",".join(loop_var_names))
Expand Down
Expand Up @@ -50,6 +50,5 @@ def visit_Print(self, node):
return gast.Expr(value=convert_print_node)

def _create_print_node(self, print_args):
convert_print_func = gast.parse(
'paddle.jit.dy2static.convert_print').body[0].value
convert_print_func = gast.parse('_jst.convert_print').body[0].value
return gast.Call(func=convert_print_func, args=print_args, keywords=[])
Expand Up @@ -336,7 +336,7 @@ def _replace_return_in_stmt_list(self, stmt_list, return_node, return_name,
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = True'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, True)".format(
node_str = "{} = _jst.create_bool_as_type({}, True)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())

Expand Down Expand Up @@ -449,7 +449,7 @@ def _replace_after_node_to_if_in_stmt_list(
# Here assume that the parent node of return is gast.If
if isinstance(parent_node_of_return, gast.If):
# Prepend control flow boolean nodes such as '__return@1 = False'
node_str = "{} = paddle.jit.dy2static.create_bool_as_type({}, False)".format(
node_str = "{} = _jst.create_bool_as_type({}, False)".format(
return_name,
ast_to_source_code(parent_node_of_return.test).strip())
assign_false_node = gast.parse(node_str).body[0]
Expand Down
Expand Up @@ -42,7 +42,7 @@ def create_convert_shape_node(var_shape_node,
if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())

convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})".format(
convert_var_shape_func = "_jst.convert_var_shape({}, in_control_flow={})".format(
",".join(args), in_control_flow)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value

Expand All @@ -59,14 +59,14 @@ def create_convert_shape_node(var_shape_node,


def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', globals())".format(
eval_exist_func = "_jst.eval_if_exist_else_none('{}', globals())".format(
api_shape_name)
args = [attr_shape_name, eval_exist_func]

if slice_node is not None and slice_is_num(slice_node):
args.append(ast_to_source_code(slice_node.slice).strip())
choose_shape_func = "paddle.jit.dy2static.choose_shape_attr_or_api({})".format(
",".join(args))
choose_shape_func = "_jst.choose_shape_attr_or_api({})".format(",".join(
args))
choose_shape_node = gast.parse(choose_shape_func).body[0].value
if slice_node is not None and not slice_is_num(slice_node):
return gast.Subscript(
Expand All @@ -84,7 +84,7 @@ class ShapeAttributeTransformer(gast.NodeTransformer):
def visit_Attribute(self, node):
if node.attr == 'shape':
args = ast_to_source_code(node.value).strip()
convert_var_shape_func = "paddle.jit.dy2static.convert_var_shape_simple({})".format(
convert_var_shape_func = "_jst.convert_var_shape_simple({})".format(
args)
api_shape_node = gast.parse(convert_var_shape_func).body[0].value
return api_shape_node
Expand Down
7 changes: 4 additions & 3 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Expand Up @@ -185,6 +185,7 @@ def is_api_in_module(node, module_prefix):
import paddle.fluid as fluid
import paddle.fluid.dygraph as dygraph
import paddle.fluid.layers as layers
import paddle.jit.dy2static as _jst

from paddle.fluid.dygraph import to_variable
from paddle import to_tensor
Expand Down Expand Up @@ -521,8 +522,8 @@ def remove_if_exit(filepath):
def _inject_import_statements():
import_statements = [
"import paddle", "from paddle import Tensor",
"import paddle.fluid as fluid", "from typing import *",
"import numpy as np"
"import paddle.fluid as fluid", "import paddle.jit.dy2static as _jst",
"from typing import *", "import numpy as np"
]
return '\n'.join(import_statements) + '\n'

Expand Down Expand Up @@ -1168,7 +1169,7 @@ def _build_var_len_assign_node(self):
else:
iter_var_name = ast_to_source_code(self.iter_node).strip()

convert_len_node_source_str = '{} = paddle.jit.dy2static.convert_len({})'.format(
convert_len_node_source_str = '{} = _jst.convert_len({})'.format(
self.iter_var_len_name, iter_var_name)

convert_len_node = gast.parse(convert_len_node_source_str).body[0]
Expand Down
Expand Up @@ -77,14 +77,12 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):


def to_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dy2static.to_static_variable({})".format(name,
name)
func_code = "{} = _jst.to_static_variable({})".format(name, name)
return gast.parse(func_code).body[0]


def create_static_variable_gast_node(name):
func_code = "{} = paddle.jit.dy2static\
.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
func_code = "{} = _jst.data_layer_not_check(name='{}', shape=[-1], dtype='float32')".format(
name, unique_name.generate(name))
return gast.parse(func_code).body[0]

Expand Down
Expand Up @@ -24,6 +24,7 @@
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import CONVERSION_OPTIONS
from test_program_translator import get_source_code
import paddle.jit.dy2static as _jst

program_translator = ProgramTranslator()

Expand Down Expand Up @@ -255,7 +256,7 @@ def _get_answer_code(self):
return get_source_code(self.answer_func)

def _get_transformed_code(self):
transformed_func = paddle.jit.dy2static.convert_call(self.func)
transformed_func = _jst.convert_call(self.func)
return get_source_code(transformed_func)

def test_code(self):
Expand All @@ -275,7 +276,7 @@ def set_func(self):
def set_answer_func(self):
class StaticCode():
def func_convert_then_not_to_static(x):
y = paddle.jit.dy2static.convert_call(func_not_to_static)(x)
y = _jst.convert_call(func_not_to_static)(x)
return y

self.answer_func = StaticCode.func_convert_then_not_to_static
Expand Down
Expand Up @@ -65,7 +65,7 @@ def set_test_func(self):
self.func = simple_func

def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7, 8]
self.static_abs_lineno_list = [7, 8, 9]

def set_dygraph_info(self):
self.line_num = 3
Expand Down Expand Up @@ -149,7 +149,7 @@ def set_test_func(self):
self.func = nested_func

def set_static_lineno(self):
self.static_abs_lineno_list = [6, 8, 9, 10, 11]
self.static_abs_lineno_list = [7, 9, 10, 11, 12]

def set_dygraph_info(self):
self.line_num = 5
Expand All @@ -174,7 +174,7 @@ def set_test_func(self):
self.func = decorated_func

def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7]
self.static_abs_lineno_list = [7, 8]

def set_dygraph_info(self):
self.line_num = 2
Expand Down Expand Up @@ -208,7 +208,7 @@ def set_test_func(self):
self.func = decorated_func2

def set_static_lineno(self):
self.static_abs_lineno_list = [6, 7]
self.static_abs_lineno_list = [7, 8]

def set_dygraph_info(self):
self.line_num = 2
Expand Down
Expand Up @@ -27,6 +27,7 @@
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
import paddle.jit.dy2static as _jst

from ifelse_simple_func import dyfunc_with_if_else

Expand Down Expand Up @@ -76,40 +77,38 @@ def false_fn_0(x_v):
x_v = x_v + 1
return x_v

x_v = paddle.jit.dy2static.convert_ifelse(
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ), (x_v, ))
__return_0 = paddle.jit.dy2static.create_bool_as_type(label is not None,
False)
__return_0 = _jst.create_bool_as_type(label is not None, False)

def true_fn_1(__return_0, __return_value_0, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = paddle.jit.dy2static.create_bool_as_type(
label is not None, True)
__return_0 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = loss
return __return_0, __return_value_0

def false_fn_1(__return_0, __return_value_0):
return __return_0, __return_value_0

__return_0, __return_value_0 = (paddle.jit.dy2static.convert_ifelse(
__return_0, __return_value_0 = _jst.convert_ifelse(
label is not None, true_fn_1, false_fn_1,
(__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0), (__return_0, __return_value_0)))
(__return_0, __return_value_0), (__return_0, __return_value_0))

def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = paddle.jit.dy2static.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_0), True)
__return_1 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_0), True)
__return_value_0 = x_v
return __return_value_0

def false_fn_2(__return_value_0):
return __return_value_0

__return_value_0 = paddle.jit.dy2static.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_0), true_fn_2,
false_fn_2, (__return_0, __return_value_0,
x_v), (__return_value_0, ), (__return_value_0, ))
__return_value_0 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
(__return_0, __return_value_0,
x_v), (__return_value_0, ), (__return_value_0, ))
return __return_value_0


Expand All @@ -128,40 +127,38 @@ def false_fn_3(x_v):
x_v = x_v + 1
return x_v

x_v = paddle.jit.dy2static.convert_ifelse(
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
(x_v, ), (x_v, ))
__return_2 = paddle.jit.dy2static.create_bool_as_type(label is not None,
False)
__return_2 = _jst.create_bool_as_type(label is not None, False)

def true_fn_4(__return_2, __return_value_1, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = paddle.jit.dy2static.create_bool_as_type(
label is not None, True)
__return_2 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = loss
return __return_2, __return_value_1

def false_fn_4(__return_2, __return_value_1):
return __return_2, __return_value_1

__return_2, __return_value_1 = paddle.jit.dy2static.convert_ifelse(
label is not None, true_fn_4, false_fn_4, (
__return_2, __return_value_1, label, x_v),
__return_2, __return_value_1 = _jst.convert_ifelse(
label is not None, true_fn_4, false_fn_4,
(__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1), (__return_2, __return_value_1))

def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = paddle.jit.dy2static.create_bool_as_type(
paddle.jit.dy2static.convert_logical_not(__return_2), True)
__return_3 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_2), True)
__return_value_1 = x_v
return __return_value_1

def false_fn_5(__return_value_1):
return __return_value_1

__return_value_1 = paddle.jit.dy2static.convert_ifelse(
paddle.jit.dy2static.convert_logical_not(__return_2), true_fn_5,
false_fn_5, (__return_2, __return_value_1,
x_v), (__return_value_1, ), (__return_value_1, ))
__return_value_1 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
(__return_2, __return_value_1,
x_v), (__return_value_1, ), (__return_value_1, ))
return __return_value_1


Expand Down

0 comments on commit 2d87300

Please sign in to comment.