diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc index 04d50ea3debe68..89bcbcbfbe4c46 100644 --- a/tensorflow/python/lib/core/py_func.cc +++ b/tensorflow/python/lib/core/py_func.cc @@ -83,8 +83,8 @@ bool IsCPUDevice(const Device* d) { return d == nullptr || d->tensorflow_accelerator_device_info() == nullptr; } -// Givens the 'call', prepares the token and inputs as a python tuple -// that is appropriate for calling the trampoline. +// Given the 'call', prepares the token and inputs as a python tuple that is +// appropriate for calling the trampoline. Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) { int64_t n = call->ins.size(); PyObject* lst = PyList_New(n); @@ -119,7 +119,11 @@ Status MakeArgTuple(const PyCall* call, TFE_Context* ctx, PyObject** tuple) { PyList_SetItem(lst, i, arg); } *tuple = Py_BuildValue("(ssN)", call->token.c_str(), device_name, lst); - CHECK(*tuple); + if (*tuple == nullptr) { + return errors::Internal( + "Failed to create python tuple. Please make sure `token` is a " + "well-formed UTF-8 string."); + } return OkStatus(); } diff --git a/tensorflow/python/ops/script_ops_test.py b/tensorflow/python/ops/script_ops_test.py index 45689b9c897e78..92339c434464e4 100644 --- a/tensorflow/python/ops/script_ops_test.py +++ b/tensorflow/python/ops/script_ops_test.py @@ -17,7 +17,9 @@ from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_script_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops.script_ops import numpy_function @@ -103,6 +105,15 @@ def plus(a, b): expect_result = constant_op.constant(3, dtypes.int32) self.assertAllEqual(actual_result, expect_result) + @test_util.run_in_graph_and_eager_modes + def test_fail_on_non_utf8_token(self): + value = constant_op.constant(value=[1, 2]) + token = b"\xb0" + data_type = [dtypes.int32] + with self.assertRaises((errors.InternalError, UnicodeDecodeError)): + self.evaluate( + gen_script_ops.py_func(input=[value], token=token, Tout=data_type)) + if __name__ == "__main__": test.main()