Skip to content

Commit

Permalink
Add c++ utility function that converts AttrDef values to the expected…
Browse files Browse the repository at this point in the history
… types. This will replace methods such as `make_bool` (in `eager/execute.py`) and `_MakeBool` (in `op_def_library.py`).

PiperOrigin-RevId: 325453895
Change-Id: I3a454a0365b08c545944ef7528c357b5d8bf2c02
  • Loading branch information
edloper authored and tensorflower-gardener committed Aug 7, 2020
1 parent 3186465 commit 17cdd71
Show file tree
Hide file tree
Showing 6 changed files with 546 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tensorflow/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1617,11 +1617,60 @@ py_library(
],
)

cc_library(
name = "op_def_util_cc",
srcs = ["framework/op_def_util.cc"],
hdrs = ["framework/op_def_util.h"],
deps = [
":cpp_python_util",
":safe_ptr",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
],
)

# Note: this target is only used for op_def_util_test. It includes op_def_util.cc
# directly in its srcs (rather than depending on the `op_def_util_cc` target) because
# depending on that target adds dependencies that register objects; and since the
# extension is built as a shared object in some kokoro tests, this causes those objects
# to get registered multiple times (which fails).
tf_python_pybind_extension(
name = "_op_def_util",
srcs = [
"framework/op_def_util.cc",
"framework/op_def_util_pybind.cc",
],
hdrs = [
"framework/op_def_util.h",
"lib/core/safe_ptr.h",
"util/util.h",
"//tensorflow/c:headers",
"//tensorflow/c/eager:headers",
],
module_name = "_op_def_util",
deps = [
":pybind11_status",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"//third_party/python_runtime:headers",
"@com_google_absl//absl/strings",
"@pybind11",
],
)

tf_py_test(
name = "op_def_util_test",
srcs = ["framework/op_def_util_test.py"],
python_version = "PY3",
tags = ["no_pip"],
)

py_library(
name = "framework_ops", # "ops" is already the name of a deprecated target
srcs = ["framework/ops.py"],
srcs_version = "PY2AND3",
deps = [
":_op_def_util",
":c_api_util",
":control_flow_util",
":device",
Expand Down
12 changes: 12 additions & 0 deletions tensorflow/python/framework/op_def_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

import six

from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python import _pywrap_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import op_def_registry
Expand Down Expand Up @@ -788,3 +790,13 @@ def _apply_op_helper(op_type_name, name=None, **keywords): # pylint: disable=in
outputs = callback_outputs

return output_structure, op_def.is_stateful, op, outputs


# The following symbols are used by op_def_util.cc.
_pywrap_utils.RegisterPyObject("tf.dtypes.DType", dtypes.DType)
_pywrap_utils.RegisterPyObject("tf.dtypes.as_dtype", dtypes.as_dtype)
_pywrap_utils.RegisterPyObject("tf.TensorShape", tensor_shape.TensorShape)
_pywrap_utils.RegisterPyObject("tf.as_shape", tensor_shape.as_shape)
_pywrap_utils.RegisterPyObject("tf.TensorProto", tensor_pb2.TensorProto)
_pywrap_utils.RegisterPyObject("text_format.Parse", text_format.Parse)
_pywrap_utils.RegisterPyObject("tf.convert_to_tensor", ops.convert_to_tensor)
270 changes: 270 additions & 0 deletions tensorflow/python/framework/op_def_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/framework/op_def_util.h"

#include <map>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/python/util/util.h"

using ::tensorflow::swig::GetRegisteredPyObject;

#if PY_MAJOR_VERSION < 3
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_INT_CHECK(x) (PyInt_Check(x))
#define PY_INT_TYPE PyInt_Type
#else
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_INT_CHECK(x) (PyLong_Check(x))
#define PY_INT_TYPE PyLong_Type
#endif

namespace tensorflow {

namespace {

const std::map<std::string, AttributeType>* AttributeTypeNameMap() {
static auto* type_map = new std::map<std::string, AttributeType>(
{{"any", AttributeType::ANY},
{"float", AttributeType::FLOAT},
{"int", AttributeType::INT},
{"string", AttributeType::STRING},
{"bool", AttributeType::BOOL},
{"shape", AttributeType::SHAPE},
{"type", AttributeType::DTYPE},
{"tensor", AttributeType::TENSOR},
{"list(any)", AttributeType::LIST_ANY},
{"list(float)", AttributeType::LIST_FLOAT},
{"list(int)", AttributeType::LIST_INT},
{"list(string)", AttributeType::LIST_STRING},
{"list(bool)", AttributeType::LIST_BOOL},
{"list(type)", AttributeType::LIST_DTYPE},
{"list(shape)", AttributeType::LIST_SHAPE},
{"list(tensor)", AttributeType::LIST_TENSOR}});
return type_map;
}

// Note: we define functors for converting value types (rather than simple
// functions) so we can define a generic ConvertListAttr method. These
// functors all return a new reference on success, or nullptr on failure.
// They do not (necessarily) call PyErr_SetString.

struct ConvertAnyFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Py_INCREF(value);
return Safe_PyObjectPtr(value);
}
};

struct ConvertFloatFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
if (PyFloat_Check(value)) {
Py_INCREF(value);
result.reset(value);
} else if (!PY_STRING_CHECK(value)) {
result.reset(PyObject_CallFunctionObjArgs(
reinterpret_cast<PyObject*>(&PyFloat_Type), value, nullptr));
}
return result;
}
};

struct ConvertIntFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
if (PY_INT_CHECK(value)) {
Py_INCREF(value);
result.reset(value);
} else if (!PY_STRING_CHECK(value)) {
result.reset(PyObject_CallFunctionObjArgs(
reinterpret_cast<PyObject*>(&PY_INT_TYPE), value, nullptr));
}
return result;
}
};

struct ConvertStringFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
if (PY_STRING_CHECK(value)) {
Py_INCREF(value);
result.reset(value);
}
return result;
}
};

// TODO(edloper): Should we allow ints (or any other values) to be converted
// to booleans? Currently, TensorFlow does not do this conversion for attribute
// values in _MakeBool or make_bool.
struct ConvertBoolFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
if (PyBool_Check(value)) {
Py_INCREF(value);
result.reset(value);
}
return result;
}
};

struct ConvertDTypeFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
// The following symbols are registered in op_def_library.py
static PyObject* dtype = GetRegisteredPyObject("tf.dtypes.DType");
static PyObject* as_dtype = GetRegisteredPyObject("tf.dtypes.as_dtype");
if (reinterpret_cast<PyObject*>(value->ob_type) == dtype) {
Py_INCREF(value);
result.reset(value);
} else {
result.reset(PyObject_CallFunctionObjArgs(as_dtype, value, nullptr));
}
return result;
}
};

struct ConvertTensorShapeFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
// The following symbols are registered in op_def_library.py
static PyObject* shape = GetRegisteredPyObject("tf.TensorShape");
static PyObject* as_shape = GetRegisteredPyObject("tf.as_shape");
if (reinterpret_cast<PyObject*>(value->ob_type) == shape) {
Py_INCREF(value);
result.reset(value);
} else {
result.reset(PyObject_CallFunctionObjArgs(as_shape, value, nullptr));
}
return result;
}
};

struct ConvertTensorProtoFunctor {
Safe_PyObjectPtr operator()(PyObject* value) {
Safe_PyObjectPtr result;
// The following symbols are registered in op_def_library.py
static PyObject* tensor_proto = GetRegisteredPyObject("tf.TensorProto");
static PyObject* text_format_parse =
GetRegisteredPyObject("text_format.Parse");
if (reinterpret_cast<PyObject*>(value->ob_type) == tensor_proto) {
Py_INCREF(value);
result.reset(value);
} else if (PY_STRING_CHECK(value)) {
result.reset(PyObject_CallObject(tensor_proto, nullptr));
if (result) {
PyObject_CallFunctionObjArgs(text_format_parse, value, result.get(),
nullptr);
}
}
return result;
}
};

// Converts `value` to a list of elements with the same type, using
// `convert_functor` to convert each element.
template <typename T>
Safe_PyObjectPtr ConvertListAttr(PyObject* value, T convert_functor) {
// Copy the list.
Safe_PyObjectPtr result(PySequence_List(value));
if (!result) return nullptr;

// Check the type of each item in the list.
Py_ssize_t len = PySequence_Fast_GET_SIZE(result.get());
PyObject** items = PySequence_Fast_ITEMS(result.get());
for (Py_ssize_t i = 0; i < len; ++i) {
if (!PyFloat_Check(value)) {
Safe_PyObjectPtr item = convert_functor(items[i]);
if (!item) return nullptr;
PySequence_SetItem(result.get(), i, item.get());
}
}
return result;
}

// Returns the given `value` value, converted to the indicated type.
// Returns nullptr if `value` is not convertible.
Safe_PyObjectPtr ConvertAttrOrNull(PyObject* value, AttributeType attr_type) {
switch (attr_type) {
case AttributeType::ANY:
return ConvertAnyFunctor()(value);
case AttributeType::FLOAT:
return ConvertFloatFunctor()(value);
case AttributeType::INT:
return ConvertIntFunctor()(value);
case AttributeType::STRING:
return ConvertStringFunctor()(value);
case AttributeType::BOOL:
return ConvertBoolFunctor()(value);
case AttributeType::DTYPE:
return ConvertDTypeFunctor()(value);
case AttributeType::SHAPE:
return ConvertTensorShapeFunctor()(value);
case AttributeType::TENSOR:
return ConvertTensorProtoFunctor()(value);
case AttributeType::LIST_ANY:
return ConvertListAttr(value, ConvertAnyFunctor());
case AttributeType::LIST_FLOAT:
return ConvertListAttr(value, ConvertFloatFunctor());
case AttributeType::LIST_INT:
return ConvertListAttr(value, ConvertIntFunctor());
case AttributeType::LIST_STRING:
return ConvertListAttr(value, ConvertStringFunctor());
case AttributeType::LIST_BOOL:
return ConvertListAttr(value, ConvertBoolFunctor());
case AttributeType::LIST_DTYPE:
return ConvertListAttr(value, ConvertDTypeFunctor());
case AttributeType::LIST_SHAPE:
return ConvertListAttr(value, ConvertTensorShapeFunctor());
case AttributeType::LIST_TENSOR:
return ConvertListAttr(value, ConvertTensorProtoFunctor());
default:
return nullptr;
}
}

} // namespace

AttributeType AttributeTypeFromName(const std::string& type_name) {
const auto* type_map = AttributeTypeNameMap();
auto it = type_map->find(type_name);
return it != type_map->end() ? it->second : AttributeType::UNKNOWN;
}

std::string AttributeTypeToName(AttributeType attr_type) {
for (const auto& pair : *AttributeTypeNameMap()) {
if (pair.second == attr_type) {
return pair.first;
}
}
return "<unknown>";
}

Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
AttributeType type) {
Safe_PyObjectPtr result = ConvertAttrOrNull(value, type);
if (!result) {
auto err = absl::StrCat("Failed to convert value of type '",
value->ob_type->tp_name, "' to type '",
AttributeTypeToName(type), "'.");
PyErr_SetString(PyExc_TypeError, err.c_str());
}

return result;
}

} // namespace tensorflow

0 comments on commit 17cdd71

Please sign in to comment.