diff options
author | 2018-10-01 13:46:31 -0700 | |
---|---|---|
committer | 2018-10-01 13:57:32 -0700 | |
commit | ec900f15e352e4b203b1f0678f7d2ff042df57d5 (patch) | |
tree | 2d7a7ffc0f17cb28801c7a9937b6f4e3777592c7 /tensorflow/python/util | |
parent | 3039a4694e22674b502257ae34b0a5b614a631f3 (diff) |
Minor speed improvements to defun.
- EncodeArg in C instead of python.
- Also caches parsed device specs, and device spec hashes
- Adds a common way to register python types in C.
- Fastpath canonicalize function inputs when no kwargs are passed
- Set the func name attr directly instead of creating an op to wrap it.
- Rewrite IsAttrsHelper without caching
Before:
entry {
name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU"
iters: 30000
wall_time: 101.803263028
extras {
key: "examples_per_sec"
value {
double_value: 9822.86785562
}
}
}
After:
entry {
name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU"
iters: 30000
wall_time: 47.2899993261
extras {
key: "examples_per_sec"
value {
double_value: 21146.1199884
}
}
}
PiperOrigin-RevId: 215272962
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/nest.py | 4 | ||||
-rw-r--r-- | tensorflow/python/util/util.cc | 223 | ||||
-rw-r--r-- | tensorflow/python/util/util.h | 34 | ||||
-rw-r--r-- | tensorflow/python/util/util.i | 10 |
4 files changed, 162 insertions, 109 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index 758cba7487..d67dbde304 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -819,5 +819,5 @@ def flatten_with_joined_string_paths(structure, separator="/"): return list(zip(flat_string_paths, flatten(structure))) -_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence) -_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping) +_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping) +_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 38b8491c66..7b3e618e84 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -29,14 +29,51 @@ limitations under the License. namespace tensorflow { namespace swig { -namespace { +std::unordered_map<string, PyObject*>* PythonTypesMap() { + static auto* m = new std::unordered_map<string, PyObject*>(); + return m; +} + +PyObject* GetRegisteredType(const string& key) { + auto* m = PythonTypesMap(); + auto it = m->find(key); + if (it == m->end()) return nullptr; + return it->second; +} + +PyObject* RegisterType(PyObject* type_name, PyObject* type) { + if (!PyType_Check(type)) { + PyErr_SetString(PyExc_TypeError, + tensorflow::strings::StrCat("Expecting a type, got ", + Py_TYPE(type)->tp_name) + .c_str()); + return nullptr; + } -// Type object for collections.Sequence. This is set by RegisterSequenceClass. -PyObject* CollectionsSequenceType = nullptr; -// Type object for collections.Mapping, set by RegisterMappingClass. -PyObject* CollectionsMappingType = nullptr; -PyTypeObject* SparseTensorValueType = nullptr; + string key; + if (PyBytes_Check(type_name)) { + key = PyBytes_AsString(type_name); + } +#if PY_MAJOR_VERSION >= 3 + if (PyUnicode_Check(type_name)) { + key = PyUnicode_AsUTF8(type_name); + } +#endif + if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) { + PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat( + "Type already registered for ", key) + .c_str()); + return nullptr; + } + + Py_INCREF(type); + PythonTypesMap()->emplace(key, type); + + Py_RETURN_NONE; +} + +namespace { const int kMaxItemsInCache = 1024; bool WarnedThatSetIsNotSequence = false; @@ -177,46 +214,82 @@ class CachedTypeCheck { // Returns -1 if an error occurred. int IsMappingHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - return PyObject_IsInstance(to_check, CollectionsMappingType); + PyObject* collections_mapping_type = GetRegisteredType("Mapping"); + if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Mapping type has not been set. " + "Please register the type with the identifier " + "\"Mapping\" using RegisterType.") + .c_str()); + return -1; + } + return PyObject_IsInstance(to_check, collections_mapping_type); }); if (PyDict_Check(o)) return true; - if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Mapping type has not been set. " - "Please call RegisterMappingClass before using this module") - .c_str()); - return -1; - } return check_cache->CachedLookup(o); } // Returns 1 if `o` is an instance of attrs-decorated class. // Returns 0 otherwise. int IsAttrsHelper(PyObject* o) { - Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__")); - if (cls) { - return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); - } else { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__")); + if (cls) { + return PyObject_HasAttrString(cls.get(), "__attrs_attrs__"); + } + // PyObject_GetAttrString returns null on error PyErr_Clear(); return 0; - } + }); + return check_cache->CachedLookup(o); } -// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). +// Returns 1 if `o` is an object of type IndexedSlices. // Returns 0 otherwise. // Returns -1 if an error occurred. -int IsSequenceHelper(PyObject* o) { +int IsIndexedSlicesHelper(PyObject* o) { static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { - int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType); - - // Don't cache a failed is_instance check. - if (is_instance == -1) return -1; + PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices"); + if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "IndexedSlices type has not been set. " + "Please register the type with the identifier " + "\"IndexedSlices\" using RegisterType.") + .c_str()); + return -1; + } + return PyObject_IsInstance(to_check, indexed_slices_type); + }); + return check_cache->CachedLookup(o); +} - return static_cast<int>(is_instance != 0 && !IsString(to_check)); +// Returns 1 if `o` is a Tensor. +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsTensorHelper(PyObject* o) { + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + PyObject* tensor_type = GetRegisteredType("Tensor"); + if (TF_PREDICT_FALSE(tensor_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "Tensor type has not been set. " + "Please register the type with the identifier " + "\"Tensor\" using RegisterType.") + .c_str()); + return -1; + } + return PyObject_IsInstance(to_check, tensor_type); }); + return check_cache->CachedLookup(o); +} + +// Returns 1 if `o` is considered a sequence for the purposes of Flatten(). +// Returns 0 otherwise. +// Returns -1 if an error occurred. +int IsSequenceHelper(PyObject* o) { // We treat dicts and other mappings as special cases of sequences. if (IsMappingHelper(o)) return true; if (IsAttrsHelper(o)) return true; @@ -226,15 +299,24 @@ int IsSequenceHelper(PyObject* o) { "so consider avoiding using them."; WarnedThatSetIsNotSequence = true; } - if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Sequence type has not been set. " - "Please call RegisterSequenceClass before using this module") - .c_str()); - return -1; - } + static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) { + PyObject* collections_sequence_type = GetRegisteredType("Sequence"); + if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Sequence type has not been set. " + "Please register the type with the identifier " + "\"Sequence\" using RegisterType.") + .c_str()); + return -1; + } + int is_instance = PyObject_IsInstance(to_check, collections_sequence_type); + + // Don't cache a failed is_instance check. + if (is_instance == -1) return -1; + + return static_cast<int>(is_instance != 0 && !IsString(to_check)); + }); return check_cache->CachedLookup(o); } @@ -401,11 +483,13 @@ class AttrsValueIterator : public ValueIterator { }; bool IsSparseTensorValueType(PyObject* o) { - if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { + PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue"); + if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) { return false; } - return PyObject_TypeCheck(o, SparseTensorValueType) == 1; + return PyObject_TypeCheck( + o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1; } int IsSequenceForDataHelper(PyObject* o) { @@ -647,49 +731,11 @@ bool AssertSameStructureHelper( } // namespace -void RegisterSequenceClass(PyObject* sequence_class) { - if (!PyType_Check(sequence_class)) { - PyErr_SetString( - PyExc_TypeError, - tensorflow::strings::StrCat( - "Expecting a class definition for `collections.Sequence`. Got ", - Py_TYPE(sequence_class)->tp_name) - .c_str()); - return; - } - CollectionsSequenceType = sequence_class; -} - -void RegisterMappingClass(PyObject* mapping_class) { - if (!PyType_Check(mapping_class)) { - PyErr_SetString( - PyExc_TypeError, - tensorflow::strings::StrCat( - "Expecting a class definition for `collections.Mapping`. Got ", - Py_TYPE(mapping_class)->tp_name) - .c_str()); - return; - } - CollectionsMappingType = mapping_class; -} - -void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { - if (!PyType_Check(sparse_tensor_value_class)) { - PyErr_SetString( - PyExc_TypeError, - tensorflow::strings::StrCat( - "Expecting a class definition for `SparseTensorValue`. Got ", - Py_TYPE(sparse_tensor_value_class)->tp_name) - .c_str()); - return; - } - SparseTensorValueType = - reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class); -} - bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; } +bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; } +bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); @@ -737,13 +783,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) { } } - if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) { - PyErr_SetString( - PyExc_RuntimeError, - tensorflow::strings::StrCat( - "collections.Sequence type has not been set. " - "Please call RegisterSequenceClass before using this module") - .c_str()); + PyObject* collections_sequence_type = GetRegisteredType("Sequence"); + + if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) { + PyErr_SetString(PyExc_RuntimeError, + tensorflow::strings::StrCat( + "collections.Sequence type has not been set. " + "Please register the type with the identifier " + "\"Sequence\" using RegisterType.") + .c_str()); return nullptr; } @@ -755,7 +803,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) { } Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields")); - int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType); + int is_instance = + PyObject_IsInstance(fields.get(), collections_sequence_type); if (is_instance == 0) { Py_RETURN_FALSE; } else if (is_instance == -1) { diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 01f85ea1dc..f37cd527d8 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -65,6 +65,24 @@ bool IsMapping(PyObject* o); // True if the object is an instance of an attr.s decorated class. bool IsAttrs(PyObject* o); +// Returns a true if its input is an ops.Tensor. +// +// Args: +// seq: the input to be checked. +// +// Returns: +// True if the object is a tensor. +bool IsTensor(PyObject* o); + +// Returns a true if its input is an ops.IndexesSlices. +// +// Args: +// seq: the input to be checked. +// +// Returns: +// True if the object is an ops.IndexedSlices. +bool IsIndexedSlices(PyObject* o); + // Implements the same interface as tensorflow.util.nest._same_namedtuples // Returns Py_True iff the two namedtuples have the same name and fields. // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have @@ -130,18 +148,6 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types); // TypeError: The nest is or contains a dict with non-sortable keys. PyObject* Flatten(PyObject* nested); -// RegisterSequenceClass is used to pass PyTypeObject for collections.Sequence -// (which is defined in python) into the C++ world. -// Alternative approach could be to import the collections modules and retrieve -// the type from the module. This approach also requires some trigger from -// Python so that we know that Python interpreter had been initialzied. -void RegisterSequenceClass(PyObject* sequence_class); -// Like RegisterSequenceClass, but for collections.Mapping. -void RegisterMappingClass(PyObject* mapping_class); -// Similar to the above functions, except for the -// sparse_tensor.SparseTensorValue class. -void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class); - // The tensorflow.python.data package has its own nest utility that follows very // slightly different semantics for its functions than the tensorflow.python // nest utility. Returns a true if its input is a collections.Sequence (except @@ -167,6 +173,10 @@ PyObject* FlattenForData(PyObject* nested); PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2, bool check_types); +// RegisterType is used to pass PyTypeObject (which is defined in python) for an +// arbitrary identifier `type_name` into C++. +PyObject* RegisterType(PyObject* type_name, PyObject* type); + } // namespace swig } // namespace tensorflow diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 32a6e684fa..3c0ec87fa4 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -28,14 +28,8 @@ limitations under the License. // for functions in this module because they use python methods that need GIL. // TODO(iga): Find a way not to leak such definitions across files. -%unignore tensorflow::swig::RegisterSequenceClass; -%noexception tensorflow::swig::RegisterSequenceClass; - -%unignore tensorflow::swig::RegisterMappingClass; -%noexception tensorflow::swig::RegisterMappingClass; - -%unignore tensorflow::swig::RegisterSparseTensorValueClass; -%noexception tensorflow::swig::RegisterSparseTensorValueClass; +%unignore tensorflow::swig::RegisterType; +%noexception tensorflow::swig::RegisterType; %feature("docstring") tensorflow::swig::IsSequence """Returns a true if its input is a collections.Sequence (except strings). |