aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/eager/backprop.py7
-rw-r--r--tensorflow/python/eager/benchmarks_test.py30
-rw-r--r--tensorflow/python/util/nest.py76
-rw-r--r--tensorflow/python/util/util.cc1
-rw-r--r--tensorflow/python/util/util.h9
-rw-r--r--tensorflow/python/util/util.i52
6 files changed, 120 insertions, 55 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 553f761a14..e4e99078b9 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
@@ -557,7 +558,7 @@ def _aggregate_grads(gradients):
if len(gradients) == 1:
return gradients[0]
if all([isinstance(g, ops.Tensor) for g in gradients]):
- return math_ops.add_n(gradients)
+ return gen_math_ops.add_n(gradients)
else:
assert all([isinstance(g, (ops.Tensor, ops.IndexedSlices))
for g in gradients])
@@ -592,7 +593,9 @@ def _num_elements(grad):
def _fast_fill(value, shape, dtype):
- return array_ops.fill(shape, constant_op.constant(value, dtype=dtype))
+ return array_ops.fill(
+ constant_op.constant(shape, dtype=dtypes.int32),
+ constant_op.constant(value, dtype=dtype))
def _zeros(shape, dtype):
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index e2b1890c2f..a2e8422671 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -350,6 +350,21 @@ class MicroBenchmarks(test.Benchmark):
func = lambda: f(m, m, transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
+ def _benchmark_defun_matmul_forward_backward(self,
+ m,
+ transpose_b,
+ num_iters,
+ execution_mode=None):
+ f = function.defun(math_ops.matmul)
+
+ def func():
+ with backprop.GradientTape() as gt:
+ gt.watch(m)
+ y = f(m, m, transpose_b)
+ _ = gt.gradient(y, m)
+
+ self._run(func, num_iters, execution_mode=execution_mode)
+
def _benchmark_read_variable(self, m, num_iters):
self._run(m.value, num_iters)
@@ -421,6 +436,21 @@ class MicroBenchmarks(test.Benchmark):
num_iters=self._num_iters_2_by_2,
execution_mode=context.ASYNC)
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m, transpose_b=False, num_iters=self._num_iters_2_by_2)
+
+ def benchmark_defun_matmul_forward_backward_2_by_2_CPU_async(self):
+ with context.device(CPU):
+ m = self._m_2_by_2.cpu()
+ self._benchmark_defun_matmul_forward_backward(
+ m,
+ transpose_b=False,
+ num_iters=self._num_iters_2_by_2,
+ execution_mode=context.ASYNC)
+
def benchmark_tf_matmul_2_by_2_GPU(self):
if not context.num_gpus():
return
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index faae0d89c3..2968ca9c07 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -62,6 +62,10 @@ def _is_namedtuple(instance, strict=False):
return _pywrap_tensorflow.IsNamedtuple(instance, strict)
+# See the swig file (util.i) for documentation.
+_is_mapping = _pywrap_tensorflow.IsMapping
+
+
def _sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
@@ -73,7 +77,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
- if isinstance(instance, (dict, _collections.Mapping)):
+ if _is_mapping(instance):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -89,7 +93,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
- if isinstance(iterable, (dict, _collections.Mapping)):
+ if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -102,53 +106,16 @@ def _yield_value(iterable):
yield value
-def is_sequence(seq):
- """Returns a true if its input is a collections.Sequence (except strings).
+# See the swig file (util.i) for documentation.
+is_sequence = _pywrap_tensorflow.IsSequence
- Args:
- seq: an input sequence.
- Returns:
- True if the sequence is a not a string and is a collections.Sequence or a
- dict.
- """
- return _pywrap_tensorflow.IsSequence(seq)
+# See the swig file (util.i) for documentation.
+flatten = _pywrap_tensorflow.Flatten
-def flatten(nest):
- """Returns a flat list from a given nested structure.
-
- If `nest` is not a sequence, tuple, or dict, then returns a single-element
- list: `[nest]`.
-
- In the case of dict instances, the sequence consists of the values, sorted by
- key to ensure deterministic behavior. This is true also for `OrderedDict`
- instances: their sequence order is ignored, the sorting order of keys is
- used instead. The same convention is followed in `pack_sequence_as`. This
- correctly repacks dicts and `OrderedDict`s after they have been flattened,
- and also allows flattening an `OrderedDict` and then repacking it back using
- a corresponding plain dict, or vice-versa.
- Dictionaries with non-sortable keys cannot be flattened.
-
- Users must not modify any collections used in `nest` while this function is
- running.
-
- Args:
- nest: an arbitrarily nested structure or a scalar object. Note, numpy
- arrays are considered scalars.
-
- Returns:
- A Python list, the flattened version of the input.
-
- Raises:
- TypeError: The nest is or contains a dict with non-sortable keys.
- """
- return _pywrap_tensorflow.Flatten(nest)
-
-
-def _same_namedtuples(nest1, nest2):
- """Returns True if the two namedtuples have the same name and fields."""
- return _pywrap_tensorflow.SameNamedtuples(nest1, nest2)
+# See the swig file (util.i) for documentation.
+_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
def assert_same_structure(nest1, nest2, check_types=True):
@@ -311,14 +278,17 @@ def pack_sequence_as(structure, flat_sequence):
% len(flat_sequence))
return flat_sequence[0]
- flat_structure = flatten(structure)
- if len(flat_structure) != len(flat_sequence):
- raise ValueError(
- "Could not pack sequence. Structure had %d elements, but flat_sequence "
- "had %d elements. Structure: %s, flat_sequence: %s."
- % (len(flat_structure), len(flat_sequence), structure, flat_sequence))
-
- _, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ try:
+ final_index, packed = _packed_nest_with_indices(structure, flat_sequence, 0)
+ if final_index < len(flat_sequence):
+ raise IndexError
+ except IndexError:
+ flat_structure = flatten(structure)
+ if len(flat_structure) != len(flat_sequence):
+ raise ValueError(
+ "Could not pack sequence. Structure had %d elements, but "
+ "flat_sequence had %d elements. Structure: %s, flat_sequence: %s." %
+ (len(flat_structure), len(flat_sequence), structure, flat_sequence))
return _sequence_like(structure, packed)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index ebb72079ef..61249d664b 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -647,6 +647,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
}
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
+bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 41dcc969f8..f15ebb6efe 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -47,6 +47,15 @@ bool IsSequence(PyObject* o);
// True if `instance` is a `namedtuple`.
PyObject* IsNamedtuple(PyObject* o, bool strict);
+// Returns a true if its input is a collections.Mapping.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the sequence subclasses mapping.
+bool IsMapping(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 6ad1484295..8d9f9615d7 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -37,18 +37,70 @@ limitations under the License.
%unignore tensorflow::swig::RegisterSparseTensorValueClass;
%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%feature("docstring") tensorflow::swig::IsSequence
+"""Returns a true if its input is a collections.Sequence (except strings).
+
+Args:
+ seq: an input sequence.
+
+Returns:
+ True if the sequence is a not a string and is a collections.Sequence or a
+ dict.
+"""
%unignore tensorflow::swig::IsSequence;
%noexception tensorflow::swig::IsSequence;
%unignore tensorflow::swig::IsNamedtuple;
%noexception tensorflow::swig::IsNamedtuple;
+%feature("docstring") tensorflow::swig::IsMapping
+"""Returns True iff `instance` is a `collections.Mapping`.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is a `collections.Mapping`.
+"""
+%unignore tensorflow::swig::IsMapping;
+%noexception tensorflow::swig::IsMapping;
+
+%feature("docstring") tensorflow::swig::SameNamedtuples
+"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;
%noexception tensorflow::swig::SameNamedtuples;
%unignore tensorflow::swig::AssertSameStructure;
%noexception tensorflow::swig::AssertSameStructure;
+%feature("docstring") tensorflow::swig::Flatten
+"""Returns a flat list from a given nested structure.
+
+If `nest` is not a sequence, tuple, or dict, then returns a single-element
+list: `[nest]`.
+
+In the case of dict instances, the sequence consists of the values, sorted by
+key to ensure deterministic behavior. This is true also for `OrderedDict`
+instances: their sequence order is ignored, the sorting order of keys is
+used instead. The same convention is followed in `pack_sequence_as`. This
+correctly repacks dicts and `OrderedDict`s after they have been flattened,
+and also allows flattening an `OrderedDict` and then repacking it back using
+a corresponding plain dict, or vice-versa.
+Dictionaries with non-sortable keys cannot be flattened.
+
+Users must not modify any collections used in `nest` while this function is
+running.
+
+Args:
+ nest: an arbitrarily nested structure or a scalar object. Note, numpy
+ arrays are considered scalars.
+
+Returns:
+ A Python list, the flattened version of the input.
+
+Raises:
+ TypeError: The nest is or contains a dict with non-sortable keys.
+"""
%unignore tensorflow::swig::Flatten;
%noexception tensorflow::swig::Flatten;