diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-08-17 16:18:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-17 16:25:56 -0700 |
commit | f1ad54b58b7ce2e08b5f4e38a1631dc667e3e7af (patch) | |
tree | 7bd201b89817a8b516e6adb19ac7e5eccefc74f4 /tensorflow/python/util | |
parent | fbdef63fe5849cde5423f8c3cc9c348ed4fe75c3 (diff) |
Add a benchmark for forward+backward for defuns.
Also fix some simple issues that I saw when I benchmarked it (goes from ~3500 examples/sec -> ~4000 examples/sec)
- (nest) Expose is_mapping check that caches to python.
- (nest) Stop calling flatten when unnecessary in pack_sequence_as
- (nest) Set some functions to their swig wrappers directly (instead of wrapping them in another function)
- Directly call the gen_math_ops call in _aggregate_grads to skip any unnecessary python overhead.
- Stop falling back to slow path in _fast_fill.
PiperOrigin-RevId: 209223633
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/nest.py | 76 | ||||
-rw-r--r-- | tensorflow/python/util/util.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/util/util.h | 9 | ||||
-rw-r--r-- | tensorflow/python/util/util.i | 52 |
4 files changed, 85 insertions, 53 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index faae0d89c3..2968ca9c07 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -62,6 +62,10 @@ def _is_namedtuple(instance, strict=False): return _pywrap_tensorflow.IsNamedtuple(instance, strict) +# See the swig file (util.i) for documentation. +_is_mapping = _pywrap_tensorflow.IsMapping + + def _sequence_like(instance, args): """Converts the sequence `args` to the same type as `instance`. @@ -73,7 +77,7 @@ def _sequence_like(instance, args): Returns: `args` with the type of `instance`. """ - if isinstance(instance, (dict, _collections.Mapping)): + if _is_mapping(instance): # Pack dictionaries in a deterministic order by sorting the keys. # Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing @@ -89,7 +93,7 @@ def _sequence_like(instance, args): def _yield_value(iterable): - if isinstance(iterable, (dict, _collections.Mapping)): + if _is_mapping(iterable): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing @@ -102,53 +106,16 @@ def _yield_value(iterable): yield value -def is_sequence(seq): - """Returns a true if its input is a collections.Sequence (except strings). +# See the swig file (util.i) for documentation. +is_sequence = _pywrap_tensorflow.IsSequence - Args: - seq: an input sequence. - Returns: - True if the sequence is a not a string and is a collections.Sequence or a - dict. - """ - return _pywrap_tensorflow.IsSequence(seq) +# See the swig file (util.i) for documentation. +flatten = _pywrap_tensorflow.Flatten -def flatten(nest): - """Returns a flat list from a given nested structure. - - If `nest` is not a sequence, tuple, or dict, then returns a single-element - list: `[nest]`. - - In the case of dict instances, the sequence consists of the values, sorted by - key to ensure deterministic behavior. This is true also for `OrderedDict` - instances: their sequence order is ignored, the sorting order of keys is - used instead. The same convention is followed in `pack_sequence_as`. This - correctly repacks dicts and `OrderedDict`s after they have been flattened, - and also allows flattening an `OrderedDict` and then repacking it back using - a corresponding plain dict, or vice-versa. - Dictionaries with non-sortable keys cannot be flattened. - - Users must not modify any collections used in `nest` while this function is - running. - - Args: - nest: an arbitrarily nested structure or a scalar object. Note, numpy - arrays are considered scalars. - - Returns: - A Python list, the flattened version of the input. - - Raises: - TypeError: The nest is or contains a dict with non-sortable keys. - """ - return _pywrap_tensorflow.Flatten(nest) - - -def _same_namedtuples(nest1, nest2): - """Returns True if the two namedtuples have the same name and fields.""" - return _pywrap_tensorflow.SameNamedtuples(nest1, nest2) +# See the swig file (util.i) for documentation. +_same_namedtuples = _pywrap_tensorflow.SameNamedtuples def assert_same_structure(nest1, nest2, check_types=True): @@ -311,14 +278,17 @@ def pack_sequence_as(structure, flat_sequence): % len(flat_sequence)) return flat_sequence[0] - flat_structure = flatten(structure) - if len(flat_structure) != len(flat_sequence): - raise ValueError( - "Could not pack sequence. Structure had %d elements, but flat_sequence " - "had %d elements. Structure: %s, flat_sequence: %s." - % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) - - _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) + try: + final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0) + if final_index < len(flat_sequence): + raise IndexError + except IndexError: + flat_structure = flatten(structure) + if len(flat_structure) != len(flat_sequence): + raise ValueError( + "Could not pack sequence. Structure had %d elements, but " + "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." % + (len(flat_structure), len(flat_sequence), structure, flat_sequence)) return _sequence_like(structure, packed) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index ebb72079ef..61249d664b 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -647,6 +647,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { } bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } +bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 41dcc969f8..f15ebb6efe 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -47,6 +47,15 @@ bool IsSequence(PyObject* o); // True if `instance` is a `namedtuple`. PyObject* IsNamedtuple(PyObject* o, bool strict); +// Returns a true if its input is a collections.Mapping. +// +// Args: +// seq: the input to be checked. +// +// Returns: +// True if the sequence subclasses mapping. +bool IsMapping(PyObject* o); + // Implements the same interface as tensorflow.util.nest._same_namedtuples // Returns Py_True iff the two namedtuples have the same name and fields. // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 6ad1484295..8d9f9615d7 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -37,18 +37,70 @@ limitations under the License. %unignore tensorflow::swig::RegisterSparseTensorValueClass; %noexception tensorflow::swig::RegisterSparseTensorValueClass; +%feature("docstring") tensorflow::swig::IsSequence +"""Returns a true if its input is a collections.Sequence (except strings). + +Args: + seq: an input sequence. + +Returns: + True if the sequence is a not a string and is a collections.Sequence or a + dict. +""" %unignore tensorflow::swig::IsSequence; %noexception tensorflow::swig::IsSequence; %unignore tensorflow::swig::IsNamedtuple; %noexception tensorflow::swig::IsNamedtuple; +%feature("docstring") tensorflow::swig::IsMapping +"""Returns True iff `instance` is a `collections.Mapping`. + +Args: + instance: An instance of a Python object. + +Returns: + True if `instance` is a `collections.Mapping`. +""" +%unignore tensorflow::swig::IsMapping; +%noexception tensorflow::swig::IsMapping; + +%feature("docstring") tensorflow::swig::SameNamedtuples +"Returns True if the two namedtuples have the same name and fields." %unignore tensorflow::swig::SameNamedtuples; %noexception tensorflow::swig::SameNamedtuples; %unignore tensorflow::swig::AssertSameStructure; %noexception tensorflow::swig::AssertSameStructure; +%feature("docstring") tensorflow::swig::Flatten +"""Returns a flat list from a given nested structure. + +If `nest` is not a sequence, tuple, or dict, then returns a single-element +list: `[nest]`. + +In the case of dict instances, the sequence consists of the values, sorted by +key to ensure deterministic behavior. This is true also for `OrderedDict` +instances: their sequence order is ignored, the sorting order of keys is +used instead. The same convention is followed in `pack_sequence_as`. This +correctly repacks dicts and `OrderedDict`s after they have been flattened, +and also allows flattening an `OrderedDict` and then repacking it back using +a corresponding plain dict, or vice-versa. +Dictionaries with non-sortable keys cannot be flattened. + +Users must not modify any collections used in `nest` while this function is +running. + +Args: + nest: an arbitrarily nested structure or a scalar object. Note, numpy + arrays are considered scalars. + +Returns: + A Python list, the flattened version of the input. + +Raises: + TypeError: The nest is or contains a dict with non-sortable keys. +""" %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; |