author    Jiri Simsa <jsimsa@google.com>                  2018-09-27 16:05:51 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-09-27 16:09:18 -0700
commit    ece50dd9992ac17e3094c7f6d1914febd7a036b5 (patch)
tree      18de739f4a7e33abbc9631b46b3992ac53ff446b /tensorflow/python/data
parent    b8c86c3bbd8271ed968087f24e7fb704103bc733 (diff)
[tf.data] Introducing tf.data.Dataset.reduce(), which reduces the elements of a (finite) dataset to a single element.
PiperOrigin-RevId: 214852364
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--  tensorflow/python/data/kernel_tests/BUILD                      |  18
-rw-r--r--  tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py  | 124
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py                      | 120
3 files changed, 262 insertions(+), 0 deletions(-)
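A minimal usage sketch of the new API, assuming the TF 1.x graph-mode session workflow used by the tests in this change (illustrative, not part of the commit):

    import numpy as np
    import tensorflow as tf

    # Reduce a finite dataset to a single element: sum the values 0..4.
    ds = tf.data.Dataset.range(5)
    result = ds.reduce(np.int64(0), lambda state, value: state + value)

    with tf.Session() as sess:
      print(sess.run(result))  # 10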
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index fdcbfc3684..5f9818566f 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -405,6 +405,24 @@ tf_py_test(
)

tf_py_test(
+ name = "reduce_dataset_op_test",
+ size = "small",
+ srcs = ["reduce_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+tf_py_test(
name = "sequence_dataset_op_test",
size = "small",
srcs = ["sequence_dataset_op_test.py"],
diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
new file mode 100644
index 0000000000..11e07300b9
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testSum(self):
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), lambda x, y: x + y)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i) // 2, sess.run(result))
+
+ def testSumTuple(self):
+
+ def reduce_fn(state, value):
+ v1, v2 = value
+ return state + v1 + v2
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ ds = dataset_ops.Dataset.zip((ds, ds))
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i), sess.run(result))
+
+ def testSumAndCount(self):
+
+ def reduce_fn(state, value):
+ s, c = state
+ return s + value, c + 1
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
+ with self.cached_session() as sess:
+ s, c = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, s)
+ self.assertEqual(i, c)
+
+ def testSquareUsingPlaceholder(self):
+ delta = array_ops.placeholder(dtype=dtypes.int64)
+
+ def reduce_fn(state, _):
+ return state + delta
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ square = sess.run(result, feed_dict={delta: i})
+ self.assertEqual(i * i, square)
+
+ def testSparse(self):
+
+ def reduce_fn(_, value):
+ return value
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i + 1))
+ result = ds.reduce(make_sparse_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertSparseValuesEqual(make_sparse_fn(i + 1), sess.run(result))
+
+ def testNested(self):
+
+ def reduce_fn(state, value):
+ state["dense"] += value["dense"]
+ state["sparse"] = value["sparse"]
+ return state
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def map_fn(i):
+ return {"dense": math_ops.cast(i, dtype=dtypes.int64),
+ "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))}
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
+ result = ds.reduce(map_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ result = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, result["dense"])
+ self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+
+if __name__ == "__main__":
+ test.main()
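The sum-and-count pattern in testSumAndCount extends naturally to a streaming mean; a short sketch under the same assumptions (TF 1.x graph mode, illustrative code that is not part of the change):

    import numpy as np
    import tensorflow as tf

    def reduce_fn(state, value):
      total, count = state
      return total + value, count + 1

    ds = tf.data.Dataset.range(1, 11)
    total, count = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
    mean = total / count  # evaluates to 5.5 in a session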
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index ac87a451b1..6bba72a8e9 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1205,6 +1205,126 @@ class Dataset(object):
shift = size
return WindowDataset(self, size, shift, stride, drop_remainder)

+ def reduce(self, initial_state, reduce_func):
+ """Reduces the input dataset to a single element.
+
+ The transformation calls `reduce_func` successively on each element of
+ the input dataset until the dataset is exhausted, aggregating information
+ in its internal state. `initial_state` provides the initial state, and
+ the final state is returned as the result.
+
+ For example:
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
+ produces `5`
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
+ produces `10`
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial
+ state of the transformation.
+ reduce_func: A function that maps `(old_state, input_element)` to
+ `new_state`. It must take two arguments and return a nested structure
+ of tensors. The structure of `new_state` must match the structure of
+ `initial_state`.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the final
+ state of the transformation.
+ """
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state.
+ state_classes = sparse.get_classes(initial_state)
+ state_shapes = nest.pack_sequence_as(
+ initial_state, [t.get_shape() for t in nest.flatten(initial_state)])
+ state_types = nest.pack_sequence_as(
+ initial_state, [t.dtype for t in nest.flatten(initial_state)])
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = StructuredFunctionWrapper(
+ reduce_func,
+ "reduce()",
+ input_classes=(state_classes, self.output_classes),
+ input_shapes=(state_shapes, self.output_shapes),
+ input_types=(state_types, self.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ output_classes = wrapped_func.output_classes
+ for new_state_class, state_class in zip(
+ nest.flatten(output_classes), nest.flatten(state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_classes,
+ wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(output_types), nest.flatten(state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_types,
+ wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ output_shapes = wrapped_func.output_shapes
+ flat_state_shapes = nest.flatten(state_shapes)
+ flat_new_state_shapes = nest.flatten(output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ state_shapes = nest.pack_sequence_as(state_shapes,
+ weakened_state_shapes)
+
+ reduce_func = wrapped_func.function
+ reduce_func.add_to_graph(ops.get_default_graph())
+
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(
+ output_types,
+ gen_dataset_ops.reduce_dataset(
+ self._as_variant_tensor(), # pylint: disable=protected-access
+ nest.flatten(sparse.serialize_sparse_tensors(initial_state)),
+ reduce_func.captured_inputs,
+ f=reduce_func,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)))),
+ output_types,
+ output_shapes,
+ output_classes)
+
class DatasetSource(Dataset):
"""Abstract class representing a dataset with no inputs."""