diff options
-rwxr-xr-x | tensorflow/c/eager/c_api.cc | 8 | ||||
-rwxr-xr-x | tensorflow/c/eager/c_api.h | 3 | ||||
-rw-r--r-- | tensorflow/python/eager/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 100 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 26 | ||||
-rwxr-xr-x | tensorflow/python/eager/pywrap_tfe.h | 4 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 223 | ||||
-rw-r--r-- | tensorflow/python/framework/device.py | 12 | ||||
-rw-r--r-- | tensorflow/python/framework/sparse_tensor.py | 2 | ||||
-rwxr-xr-x | tensorflow/python/pywrap_tfe.i | 1 | ||||
-rw-r--r-- | tensorflow/python/util/nest.py | 4 | ||||
-rw-r--r-- | tensorflow/python/util/util.cc | 223 | ||||
-rw-r--r-- | tensorflow/python/util/util.h | 34 | ||||
-rw-r--r-- | tensorflow/python/util/util.i | 10 |
14 files changed, 462 insertions, 189 deletions
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 0bf3d9542b..3554ec0bf3 100755 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -578,6 +578,14 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, op->operation.MutableAttrs()->Set(attr_name, attr_value); } +void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length) { + tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(data, length); + op->operation.MutableAttrs()->Set(attr_name, attr_value); +} + void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, TF_Status* status) { tensorflow::Tensor t; diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 6323f8a053..b2454d8722 100755 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -313,6 +313,9 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); +TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name, + const char* data, size_t length); + TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor, diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index d3d997e6df..d0c1a93118 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -37,6 +37,7 @@ cc_library( "//tensorflow/python:safe_ptr", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:variant", ], ) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 60a4f018cd..3b6f288fb9 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1005,52 +1005,8 @@ def func_graph_from_py_func(name, return func_graph -_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"]) - - -def _encode_arg(arg): - """A canonical representation for this argument, for use in a cache key.""" - - # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes - # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes - # are used for both performance reasons, as much TensorFlow code specializes - # on known shapes to produce slimmer graphs, and correctness, as some - # high-level APIs require shapes to be fully-known. - # - # TODO(akshayka): Add support for sparse tensors. - # - # pylint: disable=protected-access - if isinstance(arg, ops.Tensor): - return _TensorType(arg.dtype, arg._shape_tuple()) - elif isinstance(arg, ops.IndexedSlices): - if arg.dense_shape is not None: - return tuple([ - _TensorType(arg.values.dtype, arg.values._shape_tuple()), - _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), - _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()), - ]) - else: - return tuple([ - _TensorType(arg.values.dtype, arg.values._shape_tuple()), - _TensorType(arg.indices.dtype, arg.indices._shape_tuple()), - ]) - # pylint: enable=protected-access - elif isinstance(arg, (list, tuple)): - return tuple([_encode_arg(elem) for elem in arg]) - elif isinstance(arg, dict): - return tuple( - (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg)) - else: - try: - # If possible, keep only a weak reference to Python objects. Weak - # references hash to the same value as the original object. - # TODO(allenl): Clean up dead functions and their cache keys if the cache - # gets large. Right now creating objects with a defunned method, calling - # the method, and losing a reference to the object in a loop will leak - # memory here. - return weakref.ref(arg) - except TypeError: - return arg +pywrap_tensorflow.RegisterType("Tensor", ops.Tensor) +pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices) def _deterministic_dict_values(dictionary): @@ -1120,6 +1076,8 @@ class PolymorphicFunction(object): offset + index: default for index, default in enumerate(fullargspec.defaults or []) } + self._default_values = fullargspec.defaults + self._default_values_start_index = offset if input_signature is None: self._input_signature = None else: @@ -1180,7 +1138,7 @@ class PolymorphicFunction(object): """Computes the cache key given inputs and execution context.""" if self._input_signature is None: inputs = (args, kwargs) if kwargs else args - cache_key = tuple(_encode_arg(arg) for arg in inputs) + cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs) else: del args, kwargs cache_key = self._flat_input_signature @@ -1203,7 +1161,7 @@ class PolymorphicFunction(object): colocation_stack = (() if executing_eagerly else tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access - return cache_key + (execution_context, device_functions, colocation_stack) + return (cache_key, execution_context, device_functions, colocation_stack) def _canonicalize_function_inputs(self, *args, **kwargs): """Canonicalizes `args` and `kwargs`. @@ -1231,26 +1189,32 @@ class PolymorphicFunction(object): # Maps from index of arg to its corresponding value, according to `args` # and `kwargs`; seeded with the default values for the named args that # aren't in `args`. - arg_indices_to_values = { - index: default - for index, default in six.iteritems(self._arg_indices_to_default_values) - if index >= len(args) - } - consumed_args = [] - for arg, value in six.iteritems(kwargs): - index = self._args_to_indices.get(arg, None) - if index is not None: - arg_indices_to_values[index] = value - consumed_args.append(arg) - elif self._input_signature is not None: - raise ValueError("Cannot define a TensorFlow function from a Python " - "function with keyword arguments when " - "input_signature is provided.") - for arg in consumed_args: - # After this loop, `kwargs` will only contain true keyword arguments, as - # opposed to named arguments called in a keyword-like fashion. - kwargs.pop(arg) - inputs = args + _deterministic_dict_values(arg_indices_to_values) + if not kwargs: + if self._default_values: + inputs = args + self._default_values[len(args) - + self._default_values_start_index:] + else: + inputs = args + else: + arg_indices_to_values = { + index: default for index, default in six.iteritems( + self._arg_indices_to_default_values) if index >= len(args) + } + consumed_args = [] + for arg, value in six.iteritems(kwargs): + index = self._args_to_indices.get(arg, None) + if index is not None: + arg_indices_to_values[index] = value + consumed_args.append(arg) + elif self._input_signature is not None: + raise ValueError("Cannot define a TensorFlow function from a Python " + "function with keyword arguments when " + "input_signature is provided.") + for arg in consumed_args: + # After this loop, `kwargs` will only contain true keyword arguments, as + # opposed to named arguments called in a keyword-like fashion. + kwargs.pop(arg) + inputs = args + _deterministic_dict_values(arg_indices_to_values) flat_inputs = nest.flatten(inputs) # Check for NumPy arrays in arguments and convert them to Tensors. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index afe3ba9893..9ce367a837 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1237,6 +1237,24 @@ class FunctionTest(test.TestCase): x = constant_op.constant([1.0, 2.0]) self.assertAllEqual([2., 4.], self.evaluate(defined(x))) + def testCacheObjectHashCollisions(self): + + class Foo(object): + + def __hash__(self): + return 42 + + def func(foo): + del foo + return + + defined = function.defun(func) + defined(Foo()) + self.assertEqual(len(defined._function_cache), 1) + + defined(Foo()) + self.assertEqual(len(defined._function_cache), 2) + def testPythonFunctionWithDefaultArgs(self): def func(foo, bar=1, baz=2): @@ -1250,20 +1268,20 @@ class FunctionTest(test.TestCase): def cache_keys(): """Sanitizes cache keys of non-input metadata.""" - return tuple(key[:3] for key in defined._function_cache) + return tuple(key[0] for key in defined._function_cache) # `True` corresponds to the fact that we're executing eagerly - self.assertIn((0, 1, 20), cache_keys()) + self.assertIn(('tRRR', (0, 1, 20)), cache_keys()) defined(1) # bar=1, baz=2 - self.assertIn((1, 1, 2), cache_keys()) + self.assertIn(('tRRR', (1, 1, 2)), cache_keys()) # This matches the previous call. defined(foo=1) self.assertEqual(len(defined._function_cache), 2) defined(1, 2, 3) - self.assertIn((1, 2, 3), cache_keys()) + self.assertIn(('tRRR', (1, 2, 3)), cache_keys()) # This matches the previous call. defined(1, bar=2, baz=3) diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index f1b4042ec9..decd635b58 100755 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -224,4 +224,8 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim); // The shape is represented as a Python tuple of integers. PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor); +// Encodes the object as a tuple that is meant to be used as part of the key +// for the defun function cache. +PyObject* TFE_Py_EncodeArg(PyObject*); + #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_ diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 196e20e4d7..4b9f7f4100 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/python/eager/pywrap_tfe.h" +#include "absl/strings/str_cat.h" #include "absl/types/variant.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" @@ -567,11 +568,8 @@ bool SetOpAttrScalar( return false; } } - TFE_Op* func = TFE_NewOp( - ctx, string(func_name.data(), func_name.size()).c_str(), status); - if (TF_GetCode(status) != TF_OK) return false; - TFE_OpSetAttrFunction(op, key, func); - TFE_DeleteOp(func); + TF_SetStatus(status, TF_OK, ""); + TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size()); } else { TF_SetStatus( status, TF_UNIMPLEMENTED, @@ -2748,3 +2746,218 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, return RecordGradient(op_name, inputs, attrs, results, name); } + +namespace { + +tensorflow::int64 GetPyNoneHash() { + tensorflow::int64 py_none_hash = PyObject_Hash(Py_None); + return py_none_hash; +} + +struct EncodeResult { + string str; + std::vector<PyObject*> objects; + + PyObject* ToPyTuple() { + PyObject* result = PyTuple_New(2); + + PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str())); + + if (objects.empty()) { + Py_INCREF(Py_None); + PyTuple_SET_ITEM(result, 1, Py_None); + } else { + PyObject* objects_tuple = PyTuple_New(objects.size()); + + for (int i = 0; i < objects.size(); i++) { + PyTuple_SET_ITEM(objects_tuple, i, objects[i]); + } + + PyTuple_SET_ITEM(result, 1, objects_tuple); + } + + return result; + } +}; + +tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) { + if (EagerTensor_CheckExact(arg)) { + TFE_TensorHandle* t = EagerTensor_Handle(arg); + tensorflow::TensorShape tensor_shape; + TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape)); + absl::StrAppend(&result->str, t->handle->dtype); + + for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) { + absl::StrAppend(&result->str, dim_size); + } + + return tensorflow::Status::OK(); + } + + tensorflow::Safe_PyObjectPtr dtype_object( + PyObject_GetAttrString(arg, "dtype")); + + if (dtype_object == nullptr) { + return tensorflow::errors::InvalidArgument( + "ops.Tensor object doesn't have dtype() attr."); + } + + tensorflow::Safe_PyObjectPtr dtype_enum( + PyObject_GetAttrString(dtype_object.get(), "_type_enum")); + + if (dtype_enum == nullptr) { + return tensorflow::errors::InvalidArgument( + "ops.Tensor's dtype object doesn't have _type_enum() attr."); + } + + tensorflow::DataType dtype = + static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get())); + + absl::StrAppend(&result->str, dtype); + static char _shape_tuple[] = "_shape_tuple"; + tensorflow::Safe_PyObjectPtr shape_tuple( + PyObject_CallMethod(arg, _shape_tuple, nullptr)); + + if (shape_tuple == nullptr) { + return tensorflow::errors::InvalidArgument( + "ops.Tensor object doesn't have _shape_tuple() method."); + } + + if (shape_tuple.get() == Py_None) { + // Unknown shape, encode that directly. + absl::StrAppend(&result->str, GetPyNoneHash()); + return tensorflow::Status::OK(); + } + + tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast( + shape_tuple.get(), "shape_tuple didn't return a sequence")); + + int len = PySequence_Fast_GET_SIZE(shape_seq.get()); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i); + if (item == Py_None) { + absl::StrAppend(&result->str, GetPyNoneHash()); + } else { + absl::StrAppend(&result->str, MakeInt(item)); + } + } + + return tensorflow::Status::OK(); +} + +const char kTensor[] = "T"; +const char kIndexedSlices[] = "I"; +const char kList[] = "L"; +const char kTuple[] = "t"; +const char kDict[] = "D"; +const char kRaw[] = "R"; + +tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result); + +// This function doesn't set the type of sequence before +tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type, + EncodeResult* result) { + tensorflow::Safe_PyObjectPtr arg_seq( + PySequence_Fast(arg, "unable to create seq from list/tuple")); + + absl::StrAppend(&result->str, type); + int len = PySequence_Fast_GET_SIZE(arg_seq.get()); + for (int i = 0; i < len; ++i) { + PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i); + if (item == Py_None) { + absl::StrAppend(&result->str, GetPyNoneHash()); + } else { + TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result)); + } + } + + return tensorflow::Status::OK(); +} + +tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) { + if (tensorflow::swig::IsTensor(arg)) { + absl::StrAppend(&result->str, kTensor); + TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result)); + } else if (tensorflow::swig::IsIndexedSlices(arg)) { + absl::StrAppend(&result->str, kIndexedSlices); + tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values")); + if (values == nullptr) { + PyErr_Clear(); + return tensorflow::errors::InvalidArgument( + "IndexedSlices does not have a values attr"); + } + TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result)); + + tensorflow::Safe_PyObjectPtr indices( + PyObject_GetAttrString(arg, "indices")); + if (indices == nullptr) { + PyErr_Clear(); + return tensorflow::errors::InvalidArgument( + "IndexedSlices does not have a indices attr"); + } + TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result)); + + tensorflow::Safe_PyObjectPtr dense_shape( + PyObject_GetAttrString(arg, "dense_shape")); + if (dense_shape == nullptr) { + PyErr_Clear(); + return tensorflow::errors::InvalidArgument( + "IndexedSlices does not have a dense_shape attr"); + } + if (dense_shape.get() != Py_None) { + TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result)); + } + } else if (PyList_Check(arg)) { + TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, result)); + } else if (PyTuple_Check(arg)) { + TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, result)); + } else if (PyDict_Check(arg)) { + tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg)); + if (PyList_Sort(keys.get()) == -1) { + return tensorflow::errors::Internal("Unable to sort keys"); + } + + absl::StrAppend(&result->str, kDict); + int len = PyList_Size(keys.get()); + + for (int i = 0; i < len; i++) { + PyObject* key = PyList_GetItem(keys.get(), i); + TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result)); + PyObject* value = PyDict_GetItem(arg, key); + TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result)); + } + } else { + PyObject* object = PyWeakref_NewRef(arg, nullptr); + + if (object == nullptr) { + PyErr_Clear(); + + object = arg; + Py_INCREF(object); + } + + absl::StrAppend(&result->str, kRaw); + result->objects.push_back(object); + } + + return tensorflow::Status::OK(); +} + +} // namespace + +// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes +// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes +// are used for both performance reasons, as much TensorFlow code specializes +// on known shapes to produce slimmer graphs, and correctness, as some +// high-level APIs require shapes to be fully-known. +// +// TODO(nareshmodi): Add support for sparse tensors. +PyObject* TFE_Py_EncodeArg(PyObject* arg) { + EncodeResult result; + const auto status = TFE_Py_EncodeArgHelper(arg, &result); + if (MaybeRaiseExceptionFromStatus(status, nullptr)) { + return nullptr; + } + + return result.ToPyTuple(); +} diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py index 06c653097a..7f6e0a75a5 100644 --- a/tensorflow/python/framework/device.py +++ b/tensorflow/python/framework/device.py @@ -87,6 +87,7 @@ class DeviceSpec(object): else: self.device_type = device_type self.device_index = device_index + self._hash = hash(self.to_string()) def _clear(self): self._job = None @@ -234,7 +235,7 @@ class DeviceSpec(object): return self.to_string() == other.to_string() def __hash__(self): - return hash(self.to_string()) + return self._hash def check_valid(spec): @@ -266,6 +267,7 @@ def canonical_name(device): # possible to compare the device function stacks belonging to different # graphs in a meaningful way. _cached_device_functions = {} +_cached_device_specs = {} _cache_lock = threading.Lock() @@ -297,7 +299,13 @@ def merge_device(spec): """ with _cache_lock: if not isinstance(spec, DeviceSpec): - spec = DeviceSpec.from_string(spec or "") + cached_device_spec = _cached_device_specs.get(spec, None) + if cached_device_spec is None: + device_spec = DeviceSpec.from_string(spec or "") + _cached_device_specs[spec] = device_spec + spec = device_spec + else: + spec = cached_device_spec cached_function = _cached_device_functions.get(spec, None) if cached_function is not None: return cached_function diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 41ef2e11d1..440e3a0968 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -245,7 +245,7 @@ class SparseTensor(_TensorLike): SparseTensorValue = collections.namedtuple( "SparseTensorValue", ["indices", "values", "dense_shape"]) tf_export("SparseTensorValue")(SparseTensorValue) -pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue) +pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue) @tf_export("convert_to_tensor_or_sparse_tensor") diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index c411a58b70..61e0abbfcb 100755 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -67,6 +67,7 @@ limitations under the License. %rename("%s") TFE_ContextStartStep; %rename("%s") TFE_ContextEndStep; %rename("%s") TFE_Py_RegisterVSpace; +%rename("%s") TFE_Py_EncodeArg; %{ #include "tensorflow/python/eager/pywrap_tfe.h" diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 758cba7487..d67dbde304 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -819,5 +819,5 @@ def flatten_with_joined_string_paths(structure, separator="/"): return list(zip(flat_string_paths, flatten(structure))) -_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence) -_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping) +_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping) +_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 38b8491c66..7b3e618e84 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -29,14 +29,51 @@ limitations under the License. namespace tensorflow { namespace swig { -namespace { +std::unordered_map<string, PyObject*>* PythonTypesMap() { + static auto* m = new std::unordered_map<string, PyObject*>(); + return m; +} + +PyObject* GetRegisteredType(const string& key) { + auto* m = PythonTypesMap(); + auto it = m->find(key); + if (it == m->end()) return nullptr; + return it->second; +} + +PyObject* RegisterType(PyObject* type_name, PyObject* type) { + if (!PyType_Check(type)) { + PyErr_SetString(PyExc_TypeError, + tensorflow::strings::StrCat("Expecting a type, got ", + Py_TYPE(type)->tp_name) + .c_str()); + return nullptr; + } -// Type object for collections.Sequence. This is set by RegisterSequenceClass. -PyObject* CollectionsSequenceType = nullptr; -// Type object for collections.Mapping, set by RegisterMappingClass. -PyObject* CollectionsMappingType = nullptr; -PyTypeObject* SparseTensorValueType = nullptr; + string key; + if (PyBytes_Check(type_name)) { + key = PyBytes_AsString(type_name); + } +#if PY_MAJOR_VERSION >= 3 + if (PyUnicode_Check(type_name)) { + key = PyUnicode_AsUTF8(type_name); + } +#endif + if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) { + PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( + "Type already registered for ", key) + .c_str()); + return nullptr; + } + + Py_INCREF(type); + PythonTypesMap()->emplace(key, type); + + Py_RETURN_NONE; +} + +namespace { const int kMaxItemsInCache = 1024; bool WarnedThatSetIsNotSequence = false; @@ -177,46 +214,82 @@ class CachedTypeCheck { // Returns -1 if an error occurred. int IsMappingHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return PyObject_IsInstance(to_check, CollectionsMappingType); + PyObject* collections_mapping_type = GetRegisteredType("Mapping"); + if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Mapping type has not been set. " + "Please register the type with the identifier " + "\"Mapping\" using RegisterType.") + .c_str()); + return -1; + } + return PyObject_IsInstance(to_check, collections_mapping_type); }); if (PyDict_Check(o)) return true; - if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Mapping type has not been set. " - "Please call RegisterMappingClass before using this module") - .c_str()); - return -1; - } return check_cache->CachedLookup(o); } // Returns 1 if `o` is an instance of attrs-decorated class. // Returns 0 otherwise. int IsAttrsHelper(PyObject* o) { - Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__")); - if (cls) { - return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); - } else { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__")); + if (cls) { + return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); + } + // PyObject_GetAttrString returns null on error PyErr_Clear(); return 0; - } + }); + return check_cache->CachedLookup(o); } -// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). +// Returns 1 if `o` is an object of type IndexedSlices. // Returns 0 otherwise. // Returns -1 if an error occurred. -int IsSequenceHelper(PyObject* o) { +int IsIndexedSlicesHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType); - - // Don't cache a failed is_instance check. - if (is_instance == -1) return -1; + PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices"); + if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "IndexedSlices type has not been set. " + "Please register the type with the identifier " + "\"IndexedSlices\" using RegisterType.") + .c_str()); + return -1; + } + return PyObject_IsInstance(to_check, indexed_slices_type); + }); + return check_cache->CachedLookup(o); +} - return static_cast<int>(is_instance != 0 && !IsString(to_check)); +// Returns 1 if `o` is a Tensor. +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsTensorHelper(PyObject* o) { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + PyObject* tensor_type = GetRegisteredType("Tensor"); + if (TF_PREDICT_FALSE(tensor_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "Tensor type has not been set. " + "Please register the type with the identifier " + "\"Tensor\" using RegisterType.") + .c_str()); + return -1; + } + return PyObject_IsInstance(to_check, tensor_type); }); + return check_cache->CachedLookup(o); +} + +// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsSequenceHelper(PyObject* o) { // We treat dicts and other mappings as special cases of sequences. if (IsMappingHelper(o)) return true; if (IsAttrsHelper(o)) return true; @@ -226,15 +299,24 @@ int IsSequenceHelper(PyObject* o) { "so consider avoiding using them."; WarnedThatSetIsNotSequence = true; } - if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Sequence type has not been set. " - "Please call RegisterSequenceClass before using this module") - .c_str()); - return -1; - } + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + PyObject* collections_sequence_type = GetRegisteredType("Sequence"); + if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Sequence type has not been set. " + "Please register the type with the identifier " + "\"Sequence\" using RegisterType.") + .c_str()); + return -1; + } + int is_instance = PyObject_IsInstance(to_check, collections_sequence_type); + + // Don't cache a failed is_instance check. + if (is_instance == -1) return -1; + + return static_cast<int>(is_instance != 0 && !IsString(to_check)); + }); return check_cache->CachedLookup(o); } @@ -401,11 +483,13 @@ class AttrsValueIterator : public ValueIterator { }; bool IsSparseTensorValueType(PyObject* o) { - if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { + PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue"); + if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) { return false; } - return PyObject_TypeCheck(o, SparseTensorValueType) == 1; + return PyObject_TypeCheck( + o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1; } int IsSequenceForDataHelper(PyObject* o) { @@ -647,49 +731,11 @@ bool AssertSameStructureHelper( } // namespace -void RegisterSequenceClass(PyObject* sequence_class) { - if (!PyType_Check(sequence_class)) { - PyErr_SetString( - PyExc_TypeError, - tensorflow::strings::StrCat( - "Expecting a class definition for `collections.Sequence`. Got ", - Py_TYPE(sequence_class)->tp_name) - .c_str()); - return; - } - CollectionsSequenceType = sequence_class; -} - -void RegisterMappingClass(PyObject* mapping_class) { - if (!PyType_Check(mapping_class)) { - PyErr_SetString( - PyExc_TypeError, - tensorflow::strings::StrCat( - "Expecting a class definition for `collections.Mapping`. Got ", - Py_TYPE(mapping_class)->tp_name) - .c_str()); - return; - } - CollectionsMappingType = mapping_class; -} - -void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { - if (!PyType_Check(sparse_tensor_value_class)) { - PyErr_SetString( - PyExc_TypeError, - tensorflow::strings::StrCat( - "Expecting a class definition for `SparseTensorValue`. Got ", - Py_TYPE(sparse_tensor_value_class)->tp_name) - .c_str()); - return; - } - SparseTensorValueType = - reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class); -} - bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } +bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; } +bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); @@ -737,13 +783,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) { } } - if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Sequence type has not been set. " - "Please call RegisterSequenceClass before using this module") - .c_str()); + PyObject* collections_sequence_type = GetRegisteredType("Sequence"); + + if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Sequence type has not been set. " + "Please register the type with the identifier " + "\"Sequence\" using RegisterType.") + .c_str()); return nullptr; } @@ -755,7 +803,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) { } Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields")); - int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType); + int is_instance = + PyObject_IsInstance(fields.get(), collections_sequence_type); if (is_instance == 0) { Py_RETURN_FALSE; } else if (is_instance == -1) { diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 01f85ea1dc..f37cd527d8 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -65,6 +65,24 @@ bool IsMapping(PyObject* o); // True if the object is an instance of an attr.s decorated class. bool IsAttrs(PyObject* o); +// Returns a true if its input is an ops.Tensor. +// +// Args: +// seq: the input to be checked. +// +// Returns: +// True if the object is a tensor. +bool IsTensor(PyObject* o); + +// Returns a true if its input is an ops.IndexesSlices. +// +// Args: +// seq: the input to be checked. +// +// Returns: +// True if the object is an ops.IndexedSlices. +bool IsIndexedSlices(PyObject* o); + // Implements the same interface as tensorflow.util.nest._same_namedtuples // Returns Py_True iff the two namedtuples have the same name and fields. // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have @@ -130,18 +148,6 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types); // TypeError: The nest is or contains a dict with non-sortable keys. PyObject* Flatten(PyObject* nested); -// RegisterSequenceClass is used to pass PyTypeObject for collections.Sequence -// (which is defined in python) into the C++ world. -// Alternative approach could be to import the collections modules and retrieve -// the type from the module. This approach also requires some trigger from -// Python so that we know that Python interpreter had been initialzied. -void RegisterSequenceClass(PyObject* sequence_class); -// Like RegisterSequenceClass, but for collections.Mapping. -void RegisterMappingClass(PyObject* mapping_class); -// Similar to the above functions, except for the -// sparse_tensor.SparseTensorValue class. -void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class); - // The tensorflow.python.data package has its own nest utility that follows very // slightly different semantics for its functions than the tensorflow.python // nest utility. Returns a true if its input is a collections.Sequence (except @@ -167,6 +173,10 @@ PyObject* FlattenForData(PyObject* nested); PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2, bool check_types); +// RegisterType is used to pass PyTypeObject (which is defined in python) for an +// arbitrary identifier `type_name` into C++. +PyObject* RegisterType(PyObject* type_name, PyObject* type); + } // namespace swig } // namespace tensorflow diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 32a6e684fa..3c0ec87fa4 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -28,14 +28,8 @@ limitations under the License. // for functions in this module because they use python methods that need GIL. // TODO(iga): Find a way not to leak such definitions across files. -%unignore tensorflow::swig::RegisterSequenceClass; -%noexception tensorflow::swig::RegisterSequenceClass; - -%unignore tensorflow::swig::RegisterMappingClass; -%noexception tensorflow::swig::RegisterMappingClass; - -%unignore tensorflow::swig::RegisterSparseTensorValueClass; -%noexception tensorflow::swig::RegisterSparseTensorValueClass; +%unignore tensorflow::swig::RegisterType; +%noexception tensorflow::swig::RegisterType; %feature("docstring") tensorflow::swig::IsSequence """Returns a true if its input is a collections.Sequence (except strings). |