aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-08-17 16:18:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-17 16:25:56 -0700
commitf1ad54b58b7ce2e08b5f4e38a1631dc667e3e7af (patch)
tree7bd201b89817a8b516e6adb19ac7e5eccefc74f4 /tensorflow/python/util
parentfbdef63fe5849cde5423f8c3cc9c348ed4fe75c3 (diff)
Add a benchmark for forward+backward for defuns.
Also fix some simple issues that I saw when I benchmarked it (goes from ~3500 examples/sec -> ~4000 examples/sec) - (nest) Expose is_mapping check that caches to python. - (nest) Stop calling flatten when unnecessary in pack_sequence_as - (nest) Set some functions to their swig wrappers directly (instead of wrapping them in another function) - Directly call the gen_math_ops call in _aggregate_grads to skip any unnecessary python overhead. - Stop falling back to slow path in _fast_fill. PiperOrigin-RevId: 209223633
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/nest.py76
-rw-r--r--tensorflow/python/util/util.cc1
-rw-r--r--tensorflow/python/util/util.h9
-rw-r--r--tensorflow/python/util/util.i52
4 files changed, 85 insertions, 53 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index faae0d89c3..2968ca9c07 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -62,6 +62,10 @@ def _is_namedtuple(instance, strict=False):
return _pywrap_tensorflow.IsNamedtuple(instance, strict)
+# See the swig file (util.i) for documentation.
+_is_mapping = _pywrap_tensorflow.IsMapping
+
+
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@@ -73,7 +77,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
- if isinstance(instance, (dict, _collections.Mapping)):
+ if _is_mapping(instance):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -89,7 +93,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
- if isinstance(iterable, (dict, _collections.Mapping)):
+ if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -102,53 +106,16 @@ def _yield_value(iterable):
yield value
-def is_sequence(seq):
- """Returns a true if its input is a collections.Sequence (except strings).
+# See the swig file (util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequence
- Args:
- seq: an input sequence.
- Returns:
- True if the sequence is a not a string and is a collections.Sequence or a
- dict.
- """
- return _pywrap_tensorflow.IsSequence(seq)
+# See the swig file (util.i) for documentation.
+flatten = _pywrap_tensorflow.Flatten
-def flatten(nest):
- """Returns a flat list from a given nested structure.
-
- If `nest` is not a sequence, tuple, or dict, then returns a single-element
- list: `[nest]`.
-
- In the case of dict instances, the sequence consists of the values, sorted by
- key to ensure deterministic behavior. This is true also for `OrderedDict`
- instances: their sequence order is ignored, the sorting order of keys is
- used instead. The same convention is followed in `pack_sequence_as`. This
- correctly repacks dicts and `OrderedDict`s after they have been flattened,
- and also allows flattening an `OrderedDict` and then repacking it back using
- a corresponding plain dict, or vice-versa.
- Dictionaries with non-sortable keys cannot be flattened.
-
- Users must not modify any collections used in `nest` while this function is
- running.
-
- Args:
- nest: an arbitrarily nested structure or a scalar object. Note, numpy
- arrays are considered scalars.
-
- Returns:
- A Python list, the flattened version of the input.
-
- Raises:
- TypeError: The nest is or contains a dict with non-sortable keys.
- """
- return _pywrap_tensorflow.Flatten(nest)
-
-
-def _same_namedtuples(nest1, nest2):
- """Returns True if the two namedtuples have the same name and fields."""
- return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
+# See the swig file (util.i) for documentation.
+_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
def assert_same_structure(nest1, nest2, check_types=True):
@@ -311,14 +278,17 @@ def pack_sequence_as(structure, flat_sequence):
% len(flat_sequence))
return flat_sequence[0]
- flat_structure = flatten(structure)
- if len(flat_structure) != len(flat_sequence):
- raise ValueError(
- "Could not pack sequence. Structure had %d elements, but flat_sequence "
- "had %d elements. Structure: %s, flat_sequence: %s."
- % (len(flat_structure), len(flat_sequence), structure, flat_sequence))
-
- _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ try:
+ final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ if final_index < len(flat_sequence):
+ raise IndexError
+ except IndexError:
+ flat_structure = flatten(structure)
+ if len(flat_structure) != len(flat_sequence):
+ raise ValueError(
+ "Could not pack sequence. Structure had %d elements, but "
+ "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." %
+ (len(flat_structure), len(flat_sequence), structure, flat_sequence))
return _sequence_like(structure, packed)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index ebb72079ef..61249d664b 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -647,6 +647,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
}
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
+bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 41dcc969f8..f15ebb6efe 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -47,6 +47,15 @@ bool IsSequence(PyObject* o);
// True if `instance` is a `namedtuple`.
PyObject* IsNamedtuple(PyObject* o, bool strict);
+// Returns a true if its input is a collections.Mapping.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the sequence subclasses mapping.
+bool IsMapping(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6ad1484295..8d9f9615d7 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -37,18 +37,70 @@ limitations under the License.
%unignore tensorflow::swig::RegisterSparseTensorValueClass;
%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%feature("docstring") tensorflow::swig::IsSequence
+"""Returns a true if its input is a collections.Sequence (except strings).
+
+Args:
+ seq: an input sequence.
+
+Returns:
+ True if the sequence is a not a string and is a collections.Sequence or a
+ dict.
+"""
%unignore tensorflow::swig::IsSequence;
%noexception tensorflow::swig::IsSequence;
%unignore tensorflow::swig::IsNamedtuple;
%noexception tensorflow::swig::IsNamedtuple;
+%feature("docstring") tensorflow::swig::IsMapping
+"""Returns True iff `instance` is a `collections.Mapping`.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is a `collections.Mapping`.
+"""
+%unignore tensorflow::swig::IsMapping;
+%noexception tensorflow::swig::IsMapping;
+
+%feature("docstring") tensorflow::swig::SameNamedtuples
+"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;
%noexception tensorflow::swig::SameNamedtuples;
%unignore tensorflow::swig::AssertSameStructure;
%noexception tensorflow::swig::AssertSameStructure;
+%feature("docstring") tensorflow::swig::Flatten
+"""Returns a flat list from a given nested structure.
+
+If `nest` is not a sequence, tuple, or dict, then returns a single-element
+list: `[nest]`.
+
+In the case of dict instances, the sequence consists of the values, sorted by
+key to ensure deterministic behavior. This is true also for `OrderedDict`
+instances: their sequence order is ignored, the sorting order of keys is
+used instead. The same convention is followed in `pack_sequence_as`. This
+correctly repacks dicts and `OrderedDict`s after they have been flattened,
+and also allows flattening an `OrderedDict` and then repacking it back using
+a corresponding plain dict, or vice-versa.
+Dictionaries with non-sortable keys cannot be flattened.
+
+Users must not modify any collections used in `nest` while this function is
+running.
+
+Args:
+ nest: an arbitrarily nested structure or a scalar object. Note, numpy
+ arrays are considered scalars.
+
+Returns:
+ A Python list, the flattened version of the input.
+
+Raises:
+ TypeError: The nest is or contains a dict with non-sortable keys.
+"""
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;