aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-10-01 13:46:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-01 13:57:32 -0700
commitec900f15e352e4b203b1f0678f7d2ff042df57d5 (patch)
tree2d7a7ffc0f17cb28801c7a9937b6f4e3777592c7 /tensorflow/python/util
parent3039a4694e22674b502257ae34b0a5b614a631f3 (diff)
Minor speed improvements to defun.
- EncodeArg in C instead of python. - Also caches parsed device specs, and device spec hashes - Adds a common way to register python types in C. - Fastpath canonicalize function inputs when no kwargs are passed - Set the func name attr directly instead of creating an op to wrap it. - Rewrite IsAttrsHelper without caching Before: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 101.803263028 extras { key: "examples_per_sec" value { double_value: 9822.86785562 } } } After: entry { name: "MicroBenchmarks.benchmark_defun_matmul_2_by_2_CPU" iters: 30000 wall_time: 47.2899993261 extras { key: "examples_per_sec" value { double_value: 21146.1199884 } } } PiperOrigin-RevId: 215272962
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/nest.py4
-rw-r--r--tensorflow/python/util/util.cc223
-rw-r--r--tensorflow/python/util/util.h34
-rw-r--r--tensorflow/python/util/util.i10
4 files changed, 162 insertions, 109 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 758cba7487..d67dbde304 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -819,5 +819,5 @@ def flatten_with_joined_string_paths(structure, separator="/"):
return list(zip(flat_string_paths, flatten(structure)))
-_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
-_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping)
+_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
+_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 38b8491c66..7b3e618e84 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -29,14 +29,51 @@ limitations under the License.
namespace tensorflow {
namespace swig {
-namespace {
+std::unordered_map<string, PyObject*>* PythonTypesMap() {
+ static auto* m = new std::unordered_map<string, PyObject*>();
+ return m;
+}
+
+PyObject* GetRegisteredType(const string& key) {
+ auto* m = PythonTypesMap();
+ auto it = m->find(key);
+ if (it == m->end()) return nullptr;
+ return it->second;
+}
+
+PyObject* RegisterType(PyObject* type_name, PyObject* type) {
+ if (!PyType_Check(type)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat("Expecting a type, got ",
+ Py_TYPE(type)->tp_name)
+ .c_str());
+ return nullptr;
+ }
-// Type object for collections.Sequence. This is set by RegisterSequenceClass.
-PyObject* CollectionsSequenceType = nullptr;
-// Type object for collections.Mapping, set by RegisterMappingClass.
-PyObject* CollectionsMappingType = nullptr;
-PyTypeObject* SparseTensorValueType = nullptr;
+ string key;
+ if (PyBytes_Check(type_name)) {
+ key = PyBytes_AsString(type_name);
+ }
+#if PY_MAJOR_VERSION >= 3
+ if (PyUnicode_Check(type_name)) {
+ key = PyUnicode_AsUTF8(type_name);
+ }
+#endif
+ if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
+ PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
+ "Type already registered for ", key)
+ .c_str());
+ return nullptr;
+ }
+
+ Py_INCREF(type);
+ PythonTypesMap()->emplace(key, type);
+
+ Py_RETURN_NONE;
+}
+
+namespace {
const int kMaxItemsInCache = 1024;
bool WarnedThatSetIsNotSequence = false;
@@ -177,46 +214,82 @@ class CachedTypeCheck {
// Returns -1 if an error occurred.
int IsMappingHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- return PyObject_IsInstance(to_check, CollectionsMappingType);
+ PyObject* collections_mapping_type = GetRegisteredType("Mapping");
+ if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Mapping type has not been set. "
+ "Please register the type with the identifier "
+ "\"Mapping\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, collections_mapping_type);
});
if (PyDict_Check(o)) return true;
- if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Mapping type has not been set. "
- "Please call RegisterMappingClass before using this module")
- .c_str());
- return -1;
- }
return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of attrs-decorated class.
// Returns 0 otherwise.
int IsAttrsHelper(PyObject* o) {
- Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
- if (cls) {
- return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
- } else {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ }
+
// PyObject_GetAttrString returns null on error
PyErr_Clear();
return 0;
- }
+ });
+ return check_cache->CachedLookup(o);
}
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 1 if `o` is an object of type IndexedSlices.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
+int IsIndexedSlicesHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
-
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
+ PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices");
+ if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "IndexedSlices type has not been set. "
+ "Please register the type with the identifier "
+ "\"IndexedSlices\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, indexed_slices_type);
+ });
+ return check_cache->CachedLookup(o);
+}
- return static_cast<int>(is_instance != 0 && !IsString(to_check));
+// Returns 1 if `o` is a Tensor.
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsTensorHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* tensor_type = GetRegisteredType("Tensor");
+ if (TF_PREDICT_FALSE(tensor_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "Tensor type has not been set. "
+ "Please register the type with the identifier "
+ "\"Tensor\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, tensor_type);
});
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsSequenceHelper(PyObject* o) {
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
if (IsAttrsHelper(o)) return true;
@@ -226,15 +299,24 @@ int IsSequenceHelper(PyObject* o) {
"so consider avoiding using them.";
WarnedThatSetIsNotSequence = true;
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ int is_instance = PyObject_IsInstance(to_check, collections_sequence_type);
+
+ // Don't cache a failed is_instance check.
+ if (is_instance == -1) return -1;
+
+ return static_cast<int>(is_instance != 0 && !IsString(to_check));
+ });
return check_cache->CachedLookup(o);
}
@@ -401,11 +483,13 @@ class AttrsValueIterator : public ValueIterator {
};
bool IsSparseTensorValueType(PyObject* o) {
- if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
+ PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
+ if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
return false;
}
- return PyObject_TypeCheck(o, SparseTensorValueType) == 1;
+ return PyObject_TypeCheck(
+ o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
}
int IsSequenceForDataHelper(PyObject* o) {
@@ -647,49 +731,11 @@ bool AssertSameStructureHelper(
} // namespace
-void RegisterSequenceClass(PyObject* sequence_class) {
- if (!PyType_Check(sequence_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Sequence`. Got ",
- Py_TYPE(sequence_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsSequenceType = sequence_class;
-}
-
-void RegisterMappingClass(PyObject* mapping_class) {
- if (!PyType_Check(mapping_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Mapping`. Got ",
- Py_TYPE(mapping_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsMappingType = mapping_class;
-}
-
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
- if (!PyType_Check(sparse_tensor_value_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `SparseTensorValue`. Got ",
- Py_TYPE(sparse_tensor_value_class)->tp_name)
- .c_str());
- return;
- }
- SparseTensorValueType =
- reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class);
-}
-
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
+bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
+bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
@@ -737,13 +783,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
return nullptr;
}
@@ -755,7 +803,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
- int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+ int is_instance =
+ PyObject_IsInstance(fields.get(), collections_sequence_type);
if (is_instance == 0) {
Py_RETURN_FALSE;
} else if (is_instance == -1) {
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 01f85ea1dc..f37cd527d8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -65,6 +65,24 @@ bool IsMapping(PyObject* o);
// True if the object is an instance of an attr.s decorated class.
bool IsAttrs(PyObject* o);
+// Returns a true if its input is an ops.Tensor.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is a tensor.
+bool IsTensor(PyObject* o);
+
+// Returns a true if its input is an ops.IndexesSlices.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is an ops.IndexedSlices.
+bool IsIndexedSlices(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
@@ -130,18 +148,6 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
// TypeError: The nest is or contains a dict with non-sortable keys.
PyObject* Flatten(PyObject* nested);
-// RegisterSequenceClass is used to pass PyTypeObject for collections.Sequence
-// (which is defined in python) into the C++ world.
-// Alternative approach could be to import the collections modules and retrieve
-// the type from the module. This approach also requires some trigger from
-// Python so that we know that Python interpreter had been initialzied.
-void RegisterSequenceClass(PyObject* sequence_class);
-// Like RegisterSequenceClass, but for collections.Mapping.
-void RegisterMappingClass(PyObject* mapping_class);
-// Similar to the above functions, except for the
-// sparse_tensor.SparseTensorValue class.
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
-
// The tensorflow.python.data package has its own nest utility that follows very
// slightly different semantics for its functions than the tensorflow.python
// nest utility. Returns a true if its input is a collections.Sequence (except
@@ -167,6 +173,10 @@ PyObject* FlattenForData(PyObject* nested);
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
bool check_types);
+// RegisterType is used to pass PyTypeObject (which is defined in python) for an
+// arbitrary identifier `type_name` into C++.
+PyObject* RegisterType(PyObject* type_name, PyObject* type);
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 32a6e684fa..3c0ec87fa4 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -28,14 +28,8 @@ limitations under the License.
// for functions in this module because they use python methods that need GIL.
// TODO(iga): Find a way not to leak such definitions across files.
-%unignore tensorflow::swig::RegisterSequenceClass;
-%noexception tensorflow::swig::RegisterSequenceClass;
-
-%unignore tensorflow::swig::RegisterMappingClass;
-%noexception tensorflow::swig::RegisterMappingClass;
-
-%unignore tensorflow::swig::RegisterSparseTensorValueClass;
-%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%unignore tensorflow::swig::RegisterType;
+%noexception tensorflow::swig::RegisterType;
%feature("docstring") tensorflow::swig::IsSequence
"""Returns a true if its input is a collections.Sequence (except strings).