diff options
-rw-r--r-- | tensorflow/python/eager/backprop.py | 7 | ||||
-rw-r--r-- | tensorflow/python/eager/benchmarks_test.py | 30 | ||||
-rw-r--r-- | tensorflow/python/util/nest.py | 76 | ||||
-rw-r--r-- | tensorflow/python/util/util.cc | 1 | ||||
-rw-r--r-- | tensorflow/python/util/util.h | 9 | ||||
-rw-r--r-- | tensorflow/python/util/util.i | 52 |
6 files changed, 120 insertions, 55 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 553f761a14..e4e99078b9 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import tf_logging as logging @@ -557,7 +558,7 @@ def _aggregate_grads(gradients): if len(gradients) == 1: return gradients[0] if all([isinstance(g, ops.Tensor) for g in gradients]): - return math_ops.add_n(gradients) + return gen_math_ops.add_n(gradients) else: assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices)) for g in gradients]) @@ -592,7 +593,9 @@ def _num_elements(grad): def _fast_fill(value, shape, dtype): - return array_ops.fill(shape, constant_op.constant(value, dtype=dtype)) + return array_ops.fill( + constant_op.constant(shape, dtype=dtypes.int32), + constant_op.constant(value, dtype=dtype)) def _zeros(shape, dtype): diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py index e2b1890c2f..a2e8422671 100644 --- a/tensorflow/python/eager/benchmarks_test.py +++ b/tensorflow/python/eager/benchmarks_test.py @@ -350,6 +350,21 @@ class MicroBenchmarks(test.Benchmark): func = lambda: f(m, m, transpose_b) self._run(func, num_iters, execution_mode=execution_mode) + def _benchmark_defun_matmul_forward_backward(self, + m, + transpose_b, + num_iters, + execution_mode=None): + f = function.defun(math_ops.matmul) + + def func(): + with backprop.GradientTape() as gt: + gt.watch(m) + y = f(m, m, transpose_b) + _ = gt.gradient(y, m) + + self._run(func, num_iters, execution_mode=execution_mode) + def _benchmark_read_variable(self, m, num_iters): self._run(m.value, num_iters) @@ -421,6 +436,21 @@ class MicroBenchmarks(test.Benchmark): num_iters=self._num_iters_2_by_2, execution_mode=context.ASYNC) + def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_defun_matmul_forward_backward( + m, transpose_b=False, num_iters=self._num_iters_2_by_2) + + def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self): + with context.device(CPU): + m = self._m_2_by_2.cpu() + self._benchmark_defun_matmul_forward_backward( + m, + transpose_b=False, + num_iters=self._num_iters_2_by_2, + execution_mode=context.ASYNC) + def benchmark_tf_matmul_2_by_2_GPU(self): if not context.num_gpus(): return diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py index faae0d89c3..2968ca9c07 100644 --- a/tensorflow/python/util/nest.py +++ b/tensorflow/python/util/nest.py @@ -62,6 +62,10 @@ def _is_namedtuple(instance, strict=False): return _pywrap_tensorflow.IsNamedtuple(instance, strict) +# See the swig file (util.i) for documentation. +_is_mapping = _pywrap_tensorflow.IsMapping + + def _sequence_like(instance, args): """Converts the sequence `args` to the same type as `instance`. @@ -73,7 +77,7 @@ def _sequence_like(instance, args): Returns: `args` with the type of `instance`. """ - if isinstance(instance, (dict, _collections.Mapping)): + if _is_mapping(instance): # Pack dictionaries in a deterministic order by sorting the keys. # Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing @@ -89,7 +93,7 @@ def _sequence_like(instance, args): def _yield_value(iterable): - if isinstance(iterable, (dict, _collections.Mapping)): + if _is_mapping(iterable): # Iterate through dictionaries in a deterministic order by sorting the # keys. Notice this means that we ignore the original order of `OrderedDict` # instances. This is intentional, to avoid potential bugs caused by mixing @@ -102,53 +106,16 @@ def _yield_value(iterable): yield value -def is_sequence(seq): - """Returns a true if its input is a collections.Sequence (except strings). +# See the swig file (util.i) for documentation. +is_sequence = _pywrap_tensorflow.IsSequence - Args: - seq: an input sequence. - Returns: - True if the sequence is a not a string and is a collections.Sequence or a - dict. - """ - return _pywrap_tensorflow.IsSequence(seq) +# See the swig file (util.i) for documentation. +flatten = _pywrap_tensorflow.Flatten -def flatten(nest): - """Returns a flat list from a given nested structure. - - If `nest` is not a sequence, tuple, or dict, then returns a single-element - list: `[nest]`. - - In the case of dict instances, the sequence consists of the values, sorted by - key to ensure deterministic behavior. This is true also for `OrderedDict` - instances: their sequence order is ignored, the sorting order of keys is - used instead. The same convention is followed in `pack_sequence_as`. This - correctly repacks dicts and `OrderedDict`s after they have been flattened, - and also allows flattening an `OrderedDict` and then repacking it back using - a corresponding plain dict, or vice-versa. - Dictionaries with non-sortable keys cannot be flattened. - - Users must not modify any collections used in `nest` while this function is - running. - - Args: - nest: an arbitrarily nested structure or a scalar object. Note, numpy - arrays are considered scalars. - - Returns: - A Python list, the flattened version of the input. - - Raises: - TypeError: The nest is or contains a dict with non-sortable keys. - """ - return _pywrap_tensorflow.Flatten(nest) - - -def _same_namedtuples(nest1, nest2): - """Returns True if the two namedtuples have the same name and fields.""" - return _pywrap_tensorflow.SameNamedtuples(nest1, nest2) +# See the swig file (util.i) for documentation. +_same_namedtuples = _pywrap_tensorflow.SameNamedtuples def assert_same_structure(nest1, nest2, check_types=True): @@ -311,14 +278,17 @@ def pack_sequence_as(structure, flat_sequence): % len(flat_sequence)) return flat_sequence[0] - flat_structure = flatten(structure) - if len(flat_structure) != len(flat_sequence): - raise ValueError( - "Could not pack sequence. Structure had %d elements, but flat_sequence " - "had %d elements. Structure: %s, flat_sequence: %s." - % (len(flat_structure), len(flat_sequence), structure, flat_sequence)) - - _, packed = _packed_nest_with_indices(structure, flat_sequence, 0) + try: + final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0) + if final_index < len(flat_sequence): + raise IndexError + except IndexError: + flat_structure = flatten(structure) + if len(flat_structure) != len(flat_sequence): + raise ValueError( + "Could not pack sequence. Structure had %d elements, but " + "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." % + (len(flat_structure), len(flat_sequence), structure, flat_sequence)) return _sequence_like(structure, packed) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index ebb72079ef..61249d664b 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -647,6 +647,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { } bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } +bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index 41dcc969f8..f15ebb6efe 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -47,6 +47,15 @@ bool IsSequence(PyObject* o); // True if `instance` is a `namedtuple`. PyObject* IsNamedtuple(PyObject* o, bool strict); +// Returns a true if its input is a collections.Mapping. +// +// Args: +// seq: the input to be checked. +// +// Returns: +// True if the sequence subclasses mapping. +bool IsMapping(PyObject* o); + // Implements the same interface as tensorflow.util.nest._same_namedtuples // Returns Py_True iff the two namedtuples have the same name and fields. // Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index 6ad1484295..8d9f9615d7 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -37,18 +37,70 @@ limitations under the License. %unignore tensorflow::swig::RegisterSparseTensorValueClass; %noexception tensorflow::swig::RegisterSparseTensorValueClass; +%feature("docstring") tensorflow::swig::IsSequence +"""Returns a true if its input is a collections.Sequence (except strings). + +Args: + seq: an input sequence. + +Returns: + True if the sequence is a not a string and is a collections.Sequence or a + dict. +""" %unignore tensorflow::swig::IsSequence; %noexception tensorflow::swig::IsSequence; %unignore tensorflow::swig::IsNamedtuple; %noexception tensorflow::swig::IsNamedtuple; +%feature("docstring") tensorflow::swig::IsMapping +"""Returns True iff `instance` is a `collections.Mapping`. + +Args: + instance: An instance of a Python object. + +Returns: + True if `instance` is a `collections.Mapping`. +""" +%unignore tensorflow::swig::IsMapping; +%noexception tensorflow::swig::IsMapping; + +%feature("docstring") tensorflow::swig::SameNamedtuples +"Returns True if the two namedtuples have the same name and fields." %unignore tensorflow::swig::SameNamedtuples; %noexception tensorflow::swig::SameNamedtuples; %unignore tensorflow::swig::AssertSameStructure; %noexception tensorflow::swig::AssertSameStructure; +%feature("docstring") tensorflow::swig::Flatten +"""Returns a flat list from a given nested structure. + +If `nest` is not a sequence, tuple, or dict, then returns a single-element +list: `[nest]`. + +In the case of dict instances, the sequence consists of the values, sorted by +key to ensure deterministic behavior. This is true also for `OrderedDict` +instances: their sequence order is ignored, the sorting order of keys is +used instead. The same convention is followed in `pack_sequence_as`. This +correctly repacks dicts and `OrderedDict`s after they have been flattened, +and also allows flattening an `OrderedDict` and then repacking it back using +a corresponding plain dict, or vice-versa. +Dictionaries with non-sortable keys cannot be flattened. + +Users must not modify any collections used in `nest` while this function is +running. + +Args: + nest: an arbitrarily nested structure or a scalar object. Note, numpy + arrays are considered scalars. + +Returns: + A Python list, the flattened version of the input. + +Raises: + TypeError: The nest is or contains a dict with non-sortable keys. +""" %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; |