Skip to content

Commit

Permalink
Merge pull request #58576 from pak-laura/c2.99f03a9d3bafe902c1e6beb10…
Browse files Browse the repository at this point in the history
…5b2f24172f238645

Replace CHECK with returning an InternalError on failing to create py…
  • Loading branch information
mihaimaruseac committed Nov 14, 2022
2 parents 5dbe90a + 6fc67e4 commit 4f34ec8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tensorflow/python/lib/core/py_func.cc
Expand Up @@ -83,8 +83,8 @@ bool IsCPUDevice(const Device* d) {
return d == nullptr || d->tensorflow_accelerator_device_info() == nullptr;
}

// Givens the 'call', prepares the token and inputs as a python tuple
// that is appropriate for calling the trampoline.
// Given the 'call', prepares the token and inputs as a python tuple that is
// appropriate for calling the trampoline.
Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
int64_t n = call->ins.size();
PyObject* lst = PyList_New(n);
Expand Down Expand Up @@ -119,8 +119,12 @@ Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) {
PyList_SetItem(lst, i, arg);
}
*tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst);
CHECK(*tuple);
return Status::OK();
if (*tuple == nullptr) {
return errors::Internal(
"Failed to create python tuple. Please make sure `token` is a "
"well-formed UTF-8 string.");
}
return OkStatus();
}

bool IsSingleNone(PyObject* obj) {
Expand Down
28 changes: 28 additions & 0 deletions tensorflow/python/ops/script_ops_test.py
Expand Up @@ -16,8 +16,11 @@

from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import gen_script_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops.script_ops import numpy_function
from tensorflow.python.platform import test
Expand Down Expand Up @@ -87,5 +90,30 @@ def func_stateful(a, b):
2) # as stateful, func is guaranteed to execute twice


class PyFunctionTest(test.TestCase):

@test_util.run_in_graph_and_eager_modes
def test_variable_arguments(self):

def plus(a, b):
return a + b

v1 = resource_variable_ops.ResourceVariable(1)
self.evaluate(v1.initializer)

actual_result = script_ops.eager_py_func(plus, [v1, 2], dtypes.int32)
expect_result = constant_op.constant(3, dtypes.int32)
self.assertAllEqual(actual_result, expect_result)

@test_util.run_in_graph_and_eager_modes
def test_fail_on_non_utf8_token(self):
value = constant_op.constant(value=[1, 2])
token = b"\xb0"
data_type = [dtypes.int32]
with self.assertRaises((errors.InternalError, UnicodeDecodeError)):
self.evaluate(
gen_script_ops.py_func(input=[value], token=token, Tout=data_type))


if __name__ == "__main__":
test.main()

0 comments on commit 4f34ec8

Please sign in to comment.