aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-01 13:46:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:57:32 -0700
commitec900f15e352e4b203b1f0678f7d2ff042df57d5 (patch)
tree2d7a7ffc0f17cb28801c7a9937b6f4e3777592c7
parent3039a4694e22674b502257ae34b0a5b614a631f3 (diff)
Minor speed improvements to defun.
- EncodeArg in C instead of python. - Also caches parsed device specs, and device spec hashes - Adds a common way to register python types in C. - Fastpath canonicalize function inputs when no kwargs are passed - Set the func name attr directly instead of creating an op to wrap it. - Rewrite IsAttrsHelper without caching Before: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 101.803263028 extras { key: "examples_per_sec" value { double_value: 9822.86785562 } } } After: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 47.2899993261 extras { key: "examples_per_sec" value { double_value: 21146.1199884 } } } PiperOrigin-RevId: 215272962
-rwxr-xr-xtensorflow/c/eager/c_api.cc8
-rwxr-xr-xtensorflow/c/eager/c_api.h3
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/function.py100
-rw-r--r--tensorflow/python/eager/function_test.py26
-rwxr-xr-xtensorflow/python/eager/pywrap_tfe.h4
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc223
-rw-r--r--tensorflow/python/framework/device.py12
-rw-r--r--tensorflow/python/framework/sparse_tensor.py2
-rwxr-xr-xtensorflow/python/pywrap_tfe.i1
-rw-r--r--tensorflow/python/util/nest.py4
-rw-r--r--tensorflow/python/util/util.cc223
-rw-r--r--tensorflow/python/util/util.h34
-rw-r--r--tensorflow/python/util/util.i10
14 files changed, 462 insertions, 189 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 0bf3d9542b..3554ec0bf3 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -578,6 +578,14 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
+void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
+ const char* data, size_t length) {
+ tensorflow::AttrValue attr_value;
+ tensorflow::NameAttrList* func = attr_value.mutable_func();
+ func->set_name(data, length);
+ op->operation.MutableAttrs()->Set(attr_name, attr_value);
+}
+
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 6323f8a053..b2454d8722 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -313,6 +313,9 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
const char* attr_name,
const TFE_Op* value);
+TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
+ const char* data, size_t length);
+
TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op,
const char* attr_name,
TF_Tensor* tensor,
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d3d997e6df..d0c1a93118 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 60a4f018cd..3b6f288fb9 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1005,52 +1005,8 @@ def func_graph_from_py_func(name,
return func_graph
-_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
-
-
-def _encode_arg(arg):
- """A canonical representation for this argument, for use in a cache key."""
-
- # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
- # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
- # are used for both performance reasons, as much TensorFlow code specializes
- # on known shapes to produce slimmer graphs, and correctness, as some
- # high-level APIs require shapes to be fully-known.
- #
- # TODO(akshayka): Add support for sparse tensors.
- #
- # pylint: disable=protected-access
- if isinstance(arg, ops.Tensor):
- return _TensorType(arg.dtype, arg._shape_tuple())
- elif isinstance(arg, ops.IndexedSlices):
- if arg.dense_shape is not None:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()),
- ])
- else:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- ])
- # pylint: enable=protected-access
- elif isinstance(arg, (list, tuple)):
- return tuple([_encode_arg(elem) for elem in arg])
- elif isinstance(arg, dict):
- return tuple(
- (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
- else:
- try:
- # If possible, keep only a weak reference to Python objects. Weak
- # references hash to the same value as the original object.
- # TODO(allenl): Clean up dead functions and their cache keys if the cache
- # gets large. Right now creating objects with a defunned method, calling
- # the method, and losing a reference to the object in a loop will leak
- # memory here.
- return weakref.ref(arg)
- except TypeError:
- return arg
+pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
+pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
def _deterministic_dict_values(dictionary):
@@ -1120,6 +1076,8 @@ class PolymorphicFunction(object):
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
}
+ self._default_values = fullargspec.defaults
+ self._default_values_start_index = offset
if input_signature is None:
self._input_signature = None
else:
@@ -1180,7 +1138,7 @@ class PolymorphicFunction(object):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
- cache_key = tuple(_encode_arg(arg) for arg in inputs)
+ cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
else:
del args, kwargs
cache_key = self._flat_input_signature
@@ -1203,7 +1161,7 @@ class PolymorphicFunction(object):
colocation_stack = (() if executing_eagerly else
tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
- return cache_key + (execution_context, device_functions, colocation_stack)
+ return (cache_key, execution_context, device_functions, colocation_stack)
def _canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.
@@ -1231,26 +1189,32 @@ class PolymorphicFunction(object):
# Maps from index of arg to its corresponding value, according to `args`
# and `kwargs`; seeded with the default values for the named args that
# aren't in `args`.
- arg_indices_to_values = {
- index: default
- for index, default in six.iteritems(self._arg_indices_to_default_values)
- if index >= len(args)
- }
- consumed_args = []
- for arg, value in six.iteritems(kwargs):
- index = self._args_to_indices.get(arg, None)
- if index is not None:
- arg_indices_to_values[index] = value
- consumed_args.append(arg)
- elif self._input_signature is not None:
- raise ValueError("Cannot define a TensorFlow function from a Python "
- "function with keyword arguments when "
- "input_signature is provided.")
- for arg in consumed_args:
- # After this loop, `kwargs` will only contain true keyword arguments, as
- # opposed to named arguments called in a keyword-like fashion.
- kwargs.pop(arg)
- inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ if not kwargs:
+ if self._default_values:
+ inputs = args + self._default_values[len(args) -
+ self._default_values_start_index:]
+ else:
+ inputs = args
+ else:
+ arg_indices_to_values = {
+ index: default for index, default in six.iteritems(
+ self._arg_indices_to_default_values) if index >= len(args)
+ }
+ consumed_args = []
+ for arg, value in six.iteritems(kwargs):
+ index = self._args_to_indices.get(arg, None)
+ if index is not None:
+ arg_indices_to_values[index] = value
+ consumed_args.append(arg)
+ elif self._input_signature is not None:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+ for arg in consumed_args:
+ # After this loop, `kwargs` will only contain true keyword arguments, as
+ # opposed to named arguments called in a keyword-like fashion.
+ kwargs.pop(arg)
+ inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
# Check for NumPy arrays in arguments and convert them to Tensors.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index afe3ba9893..9ce367a837 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1237,6 +1237,24 @@ class FunctionTest(test.TestCase):
x = constant_op.constant([1.0, 2.0])
self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
+ def testCacheObjectHashCollisions(self):
+
+ class Foo(object):
+
+ def __hash__(self):
+ return 42
+
+ def func(foo):
+ del foo
+ return
+
+ defined = function.defun(func)
+ defined(Foo())
+ self.assertEqual(len(defined._function_cache), 1)
+
+ defined(Foo())
+ self.assertEqual(len(defined._function_cache), 2)
+
def testPythonFunctionWithDefaultArgs(self):
def func(foo, bar=1, baz=2):
@@ -1250,20 +1268,20 @@ class FunctionTest(test.TestCase):
def cache_keys():
"""Sanitizes cache keys of non-input metadata."""
- return tuple(key[:3] for key in defined._function_cache)
+ return tuple(key[0] for key in defined._function_cache)
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn((0, 1, 20), cache_keys())
+ self.assertIn(('tRRR', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn((1, 1, 2), cache_keys())
+ self.assertIn(('tRRR', (1, 1, 2)), cache_keys())
# This matches the previous call.
defined(foo=1)
self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn((1, 2, 3), cache_keys())
+ self.assertIn(('tRRR', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index f1b4042ec9..decd635b58 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -224,4 +224,8 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
// The shape is represented as a Python tuple of integers.
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
+// Encodes the object as a tuple that is meant to be used as part of the key
+// for the defun function cache.
+PyObject* TFE_Py_EncodeArg(PyObject*);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 196e20e4d7..4b9f7f4100 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@@ -567,11 +568,8 @@ bool SetOpAttrScalar(
return false;
}
}
- TFE_Op* func = TFE_NewOp(
- ctx, string(func_name.data(), func_name.size()).c_str(), status);
- if (TF_GetCode(status) != TF_OK) return false;
- TFE_OpSetAttrFunction(op, key, func);
- TFE_DeleteOp(func);
+ TF_SetStatus(status, TF_OK, "");
+ TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
} else {
TF_SetStatus(
status, TF_UNIMPLEMENTED,
@@ -2748,3 +2746,218 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
return RecordGradient(op_name, inputs, attrs, results, name);
}
+
+namespace {
+
+tensorflow::int64 GetPyNoneHash() {
+ tensorflow::int64 py_none_hash = PyObject_Hash(Py_None);
+ return py_none_hash;
+}
+
+struct EncodeResult {
+ string str;
+ std::vector<PyObject*> objects;
+
+ PyObject* ToPyTuple() {
+ PyObject* result = PyTuple_New(2);
+
+ PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str()));
+
+ if (objects.empty()) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(result, 1, Py_None);
+ } else {
+ PyObject* objects_tuple = PyTuple_New(objects.size());
+
+ for (int i = 0; i < objects.size(); i++) {
+ PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
+ }
+
+ PyTuple_SET_ITEM(result, 1, objects_tuple);
+ }
+
+ return result;
+ }
+};
+
+tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
+ if (EagerTensor_CheckExact(arg)) {
+ TFE_TensorHandle* t = EagerTensor_Handle(arg);
+ tensorflow::TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
+ absl::StrAppend(&result->str, t->handle->dtype);
+
+ for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
+ absl::StrAppend(&result->str, dim_size);
+ }
+
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Safe_PyObjectPtr dtype_object(
+ PyObject_GetAttrString(arg, "dtype"));
+
+ if (dtype_object == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor object doesn't have dtype() attr.");
+ }
+
+ tensorflow::Safe_PyObjectPtr dtype_enum(
+ PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
+
+ if (dtype_enum == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor's dtype object doesn't have _type_enum() attr.");
+ }
+
+ tensorflow::DataType dtype =
+ static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
+
+ absl::StrAppend(&result->str, dtype);
+ static char _shape_tuple[] = "_shape_tuple";
+ tensorflow::Safe_PyObjectPtr shape_tuple(
+ PyObject_CallMethod(arg, _shape_tuple, nullptr));
+
+ if (shape_tuple == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor object doesn't have _shape_tuple() method.");
+ }
+
+ if (shape_tuple.get() == Py_None) {
+ // Unknown shape, encode that directly.
+ absl::StrAppend(&result->str, GetPyNoneHash());
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
+ shape_tuple.get(), "shape_tuple didn't return a sequence"));
+
+ int len = PySequence_Fast_GET_SIZE(shape_seq.get());
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
+ if (item == Py_None) {
+ absl::StrAppend(&result->str, GetPyNoneHash());
+ } else {
+ absl::StrAppend(&result->str, MakeInt(item));
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+const char kTensor[] = "T";
+const char kIndexedSlices[] = "I";
+const char kList[] = "L";
+const char kTuple[] = "t";
+const char kDict[] = "D";
+const char kRaw[] = "R";
+
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
+
+// This function doesn't set the type of sequence before
+tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
+ EncodeResult* result) {
+ tensorflow::Safe_PyObjectPtr arg_seq(
+ PySequence_Fast(arg, "unable to create seq from list/tuple"));
+
+ absl::StrAppend(&result->str, type);
+ int len = PySequence_Fast_GET_SIZE(arg_seq.get());
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
+ if (item == Py_None) {
+ absl::StrAppend(&result->str, GetPyNoneHash());
+ } else {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
+ if (tensorflow::swig::IsTensor(arg)) {
+ absl::StrAppend(&result->str, kTensor);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result));
+ } else if (tensorflow::swig::IsIndexedSlices(arg)) {
+ absl::StrAppend(&result->str, kIndexedSlices);
+ tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
+ if (values == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a values attr");
+ }
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result));
+
+ tensorflow::Safe_PyObjectPtr indices(
+ PyObject_GetAttrString(arg, "indices"));
+ if (indices == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a indices attr");
+ }
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result));
+
+ tensorflow::Safe_PyObjectPtr dense_shape(
+ PyObject_GetAttrString(arg, "dense_shape"));
+ if (dense_shape == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a dense_shape attr");
+ }
+ if (dense_shape.get() != Py_None) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result));
+ }
+ } else if (PyList_Check(arg)) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, result));
+ } else if (PyTuple_Check(arg)) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, result));
+ } else if (PyDict_Check(arg)) {
+ tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
+ if (PyList_Sort(keys.get()) == -1) {
+ return tensorflow::errors::Internal("Unable to sort keys");
+ }
+
+ absl::StrAppend(&result->str, kDict);
+ int len = PyList_Size(keys.get());
+
+ for (int i = 0; i < len; i++) {
+ PyObject* key = PyList_GetItem(keys.get(), i);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result));
+ PyObject* value = PyDict_GetItem(arg, key);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result));
+ }
+ } else {
+ PyObject* object = PyWeakref_NewRef(arg, nullptr);
+
+ if (object == nullptr) {
+ PyErr_Clear();
+
+ object = arg;
+ Py_INCREF(object);
+ }
+
+ absl::StrAppend(&result->str, kRaw);
+ result->objects.push_back(object);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
+// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
+// are used for both performance reasons, as much TensorFlow code specializes
+// on known shapes to produce slimmer graphs, and correctness, as some
+// high-level APIs require shapes to be fully-known.
+//
+// TODO(nareshmodi): Add support for sparse tensors.
+PyObject* TFE_Py_EncodeArg(PyObject* arg) {
+ EncodeResult result;
+ const auto status = TFE_Py_EncodeArgHelper(arg, &result);
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ return result.ToPyTuple();
+}
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 06c653097a..7f6e0a75a5 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -87,6 +87,7 @@ class DeviceSpec(object):
else:
self.device_type = device_type
self.device_index = device_index
+ self._hash = hash(self.to_string())
def _clear(self):
self._job = None
@@ -234,7 +235,7 @@ class DeviceSpec(object):
return self.to_string() == other.to_string()
def __hash__(self):
- return hash(self.to_string())
+ return self._hash
def check_valid(spec):
@@ -266,6 +267,7 @@ def canonical_name(device):
# possible to compare the device function stacks belonging to different
# graphs in a meaningful way.
_cached_device_functions = {}
+_cached_device_specs = {}
_cache_lock = threading.Lock()
@@ -297,7 +299,13 @@ def merge_device(spec):
"""
with _cache_lock:
if not isinstance(spec, DeviceSpec):
- spec = DeviceSpec.from_string(spec or "")
+ cached_device_spec = _cached_device_specs.get(spec, None)
+ if cached_device_spec is None:
+ device_spec = DeviceSpec.from_string(spec or "")
+ _cached_device_specs[spec] = device_spec
+ spec = device_spec
+ else:
+ spec = cached_device_spec
cached_function = _cached_device_functions.get(spec, None)
if cached_function is not None:
return cached_function
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index 41ef2e11d1..440e3a0968 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -245,7 +245,7 @@ class SparseTensor(_TensorLike):
SparseTensorValue = collections.namedtuple(
"SparseTensorValue", ["indices", "values", "dense_shape"])
tf_export("SparseTensorValue")(SparseTensorValue)
-pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue)
+pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
@tf_export("convert_to_tensor_or_sparse_tensor")
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index c411a58b70..61e0abbfcb 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -67,6 +67,7 @@ limitations under the License.
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
%rename("%s") TFE_Py_RegisterVSpace;
+%rename("%s") TFE_Py_EncodeArg;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 758cba7487..d67dbde304 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -819,5 +819,5 @@ def flatten_with_joined_string_paths(structure, separator="/"):
return list(zip(flat_string_paths, flatten(structure)))
-_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
-_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping)
+_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
+_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 38b8491c66..7b3e618e84 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -29,14 +29,51 @@ limitations under the License.
namespace tensorflow {
namespace swig {
-namespace {
+std::unordered_map<string, PyObject*>* PythonTypesMap() {
+ static auto* m = new std::unordered_map<string, PyObject*>();
+ return m;
+}
+
+PyObject* GetRegisteredType(const string& key) {
+ auto* m = PythonTypesMap();
+ auto it = m->find(key);
+ if (it == m->end()) return nullptr;
+ return it->second;
+}
+
+PyObject* RegisterType(PyObject* type_name, PyObject* type) {
+ if (!PyType_Check(type)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat("Expecting a type, got ",
+ Py_TYPE(type)->tp_name)
+ .c_str());
+ return nullptr;
+ }
-// Type object for collections.Sequence. This is set by RegisterSequenceClass.
-PyObject* CollectionsSequenceType = nullptr;
-// Type object for collections.Mapping, set by RegisterMappingClass.
-PyObject* CollectionsMappingType = nullptr;
-PyTypeObject* SparseTensorValueType = nullptr;
+ string key;
+ if (PyBytes_Check(type_name)) {
+ key = PyBytes_AsString(type_name);
+ }
+#if PY_MAJOR_VERSION >= 3
+ if (PyUnicode_Check(type_name)) {
+ key = PyUnicode_AsUTF8(type_name);
+ }
+#endif
+ if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
+ PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
+ "Type already registered for ", key)
+ .c_str());
+ return nullptr;
+ }
+
+ Py_INCREF(type);
+ PythonTypesMap()->emplace(key, type);
+
+ Py_RETURN_NONE;
+}
+
+namespace {
const int kMaxItemsInCache = 1024;
bool WarnedThatSetIsNotSequence = false;
@@ -177,46 +214,82 @@ class CachedTypeCheck {
// Returns -1 if an error occurred.
int IsMappingHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- return PyObject_IsInstance(to_check, CollectionsMappingType);
+ PyObject* collections_mapping_type = GetRegisteredType("Mapping");
+ if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Mapping type has not been set. "
+ "Please register the type with the identifier "
+ "\"Mapping\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, collections_mapping_type);
});
if (PyDict_Check(o)) return true;
- if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Mapping type has not been set. "
- "Please call RegisterMappingClass before using this module")
- .c_str());
- return -1;
- }
return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of attrs-decorated class.
// Returns 0 otherwise.
int IsAttrsHelper(PyObject* o) {
- Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
- if (cls) {
- return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
- } else {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ }
+
// PyObject_GetAttrString returns null on error
PyErr_Clear();
return 0;
- }
+ });
+ return check_cache->CachedLookup(o);
}
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 1 if `o` is an object of type IndexedSlices.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
+int IsIndexedSlicesHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
-
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
+ PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices");
+ if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "IndexedSlices type has not been set. "
+ "Please register the type with the identifier "
+ "\"IndexedSlices\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, indexed_slices_type);
+ });
+ return check_cache->CachedLookup(o);
+}
- return static_cast<int>(is_instance != 0 && !IsString(to_check));
+// Returns 1 if `o` is a Tensor.
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsTensorHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* tensor_type = GetRegisteredType("Tensor");
+ if (TF_PREDICT_FALSE(tensor_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "Tensor type has not been set. "
+ "Please register the type with the identifier "
+ "\"Tensor\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, tensor_type);
});
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsSequenceHelper(PyObject* o) {
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
if (IsAttrsHelper(o)) return true;
@@ -226,15 +299,24 @@ int IsSequenceHelper(PyObject* o) {
"so consider avoiding using them.";
WarnedThatSetIsNotSequence = true;
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ int is_instance = PyObject_IsInstance(to_check, collections_sequence_type);
+
+ // Don't cache a failed is_instance check.
+ if (is_instance == -1) return -1;
+
+ return static_cast<int>(is_instance != 0 && !IsString(to_check));
+ });
return check_cache->CachedLookup(o);
}
@@ -401,11 +483,13 @@ class AttrsValueIterator : public ValueIterator {
};
bool IsSparseTensorValueType(PyObject* o) {
- if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
+ PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
+ if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
return false;
}
- return PyObject_TypeCheck(o, SparseTensorValueType) == 1;
+ return PyObject_TypeCheck(
+ o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
}
int IsSequenceForDataHelper(PyObject* o) {
@@ -647,49 +731,11 @@ bool AssertSameStructureHelper(
} // namespace
-void RegisterSequenceClass(PyObject* sequence_class) {
- if (!PyType_Check(sequence_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Sequence`. Got ",
- Py_TYPE(sequence_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsSequenceType = sequence_class;
-}
-
-void RegisterMappingClass(PyObject* mapping_class) {
- if (!PyType_Check(mapping_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Mapping`. Got ",
- Py_TYPE(mapping_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsMappingType = mapping_class;
-}
-
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
- if (!PyType_Check(sparse_tensor_value_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `SparseTensorValue`. Got ",
- Py_TYPE(sparse_tensor_value_class)->tp_name)
- .c_str());
- return;
- }
- SparseTensorValueType =
- reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class);
-}
-
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
+bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
+bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
@@ -737,13 +783,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
return nullptr;
}
@@ -755,7 +803,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
- int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+ int is_instance =
+ PyObject_IsInstance(fields.get(), collections_sequence_type);
if (is_instance == 0) {
Py_RETURN_FALSE;
} else if (is_instance == -1) {
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 01f85ea1dc..f37cd527d8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -65,6 +65,24 @@ bool IsMapping(PyObject* o);
// True if the object is an instance of an attr.s decorated class.
bool IsAttrs(PyObject* o);
+// Returns a true if its input is an ops.Tensor.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is a tensor.
+bool IsTensor(PyObject* o);
+
+// Returns a true if its input is an ops.IndexesSlices.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is an ops.IndexedSlices.
+bool IsIndexedSlices(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
@@ -130,18 +148,6 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
// TypeError: The nest is or contains a dict with non-sortable keys.
PyObject* Flatten(PyObject* nested);
-// RegisterSequenceClass is used to pass PyTypeObject for collections.Sequence
-// (which is defined in python) into the C++ world.
-// Alternative approach could be to import the collections modules and retrieve
-// the type from the module. This approach also requires some trigger from
-// Python so that we know that Python interpreter had been initialzied.
-void RegisterSequenceClass(PyObject* sequence_class);
-// Like RegisterSequenceClass, but for collections.Mapping.
-void RegisterMappingClass(PyObject* mapping_class);
-// Similar to the above functions, except for the
-// sparse_tensor.SparseTensorValue class.
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
-
// The tensorflow.python.data package has its own nest utility that follows very
// slightly different semantics for its functions than the tensorflow.python
// nest utility. Returns a true if its input is a collections.Sequence (except
@@ -167,6 +173,10 @@ PyObject* FlattenForData(PyObject* nested);
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
bool check_types);
+// RegisterType is used to pass PyTypeObject (which is defined in python) for an
+// arbitrary identifier `type_name` into C++.
+PyObject* RegisterType(PyObject* type_name, PyObject* type);
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 32a6e684fa..3c0ec87fa4 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -28,14 +28,8 @@ limitations under the License.
// for functions in this module because they use python methods that need GIL.
// TODO(iga): Find a way not to leak such definitions across files.
-%unignore tensorflow::swig::RegisterSequenceClass;
-%noexception tensorflow::swig::RegisterSequenceClass;
-
-%unignore tensorflow::swig::RegisterMappingClass;
-%noexception tensorflow::swig::RegisterMappingClass;
-
-%unignore tensorflow::swig::RegisterSparseTensorValueClass;
-%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%unignore tensorflow::swig::RegisterType;
+%noexception tensorflow::swig::RegisterType;
%feature("docstring") tensorflow::swig::IsSequence
"""Returns a true if its input is a collections.Sequence (except strings).