diff options
author | 2018-09-27 16:05:51 -0700 | |
---|---|---|
committer | 2018-09-27 16:09:18 -0700 | |
commit | ece50dd9992ac17e3094c7f6d1914febd7a036b5 (patch) | |
tree | 18de739f4a7e33abbc9631b46b3992ac53ff446b /tensorflow/python/data | |
parent | b8c86c3bbd8271ed968087f24e7fb704103bc733 (diff) |
[tf.data Introducing tf.data.Dataset.reduce() which reduces elements of a (finite) dataset to a single element.
PiperOrigin-RevId: 214852364
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/kernel_tests/BUILD | 18 | ||||
-rw-r--r-- | tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py | 124 | ||||
-rw-r--r-- | tensorflow/python/data/ops/dataset_ops.py | 120 |
3 files changed, 262 insertions, 0 deletions
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index fdcbfc3684..5f9818566f 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -405,6 +405,24 @@ tf_py_test( ) tf_py_test( + name = "reduce_dataset_op_test", + size = "small", + srcs = ["reduce_dataset_op_test.py"], + additional_deps = [ + ":test_base", + "@absl_py//absl/testing:parameterized", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + +tf_py_test( name = "sequence_dataset_op_test", size = "small", srcs = ["sequence_dataset_op_test.py"], diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py new file mode 100644 index 0000000000..11e07300b9 --- /dev/null +++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py @@ -0,0 +1,124 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + + def testSum(self): + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + result = ds.reduce(np.int64(0), lambda x, y: x + y) + with self.cached_session() as sess: + self.assertEqual(((i + 1) * i) // 2, sess.run(result)) + + def testSumTuple(self): + + def reduce_fn(state, value): + v1, v2 = value + return state + v1 + v2 + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + ds = dataset_ops.Dataset.zip((ds, ds)) + result = ds.reduce(np.int64(0), reduce_fn) + with self.cached_session() as sess: + self.assertEqual(((i + 1) * i), sess.run(result)) + + def testSumAndCount(self): + + def reduce_fn(state, value): + s, c = state + return s + value, c + 1 + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn) + with self.cached_session() as sess: + s, c = sess.run(result) + self.assertEqual(((i + 1) * i) // 2, s) + self.assertEqual(i, c) + + def testSquareUsingPlaceholder(self): + delta = array_ops.placeholder(dtype=dtypes.int64) + + def reduce_fn(state, _): + return state + delta + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1) + result = ds.reduce(np.int64(0), reduce_fn) + with self.cached_session() as sess: + square = sess.run(result, feed_dict={delta: i}) + self.assertEqual(i * i, square) + + def testSparse(self): + + def reduce_fn(_, value): + return value + + def make_sparse_fn(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + for i in range(10): + ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1)) + result = ds.reduce(make_sparse_fn(0), reduce_fn) + with self.cached_session() as sess: + self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result)) + + def testNested(self): + + def reduce_fn(state, value): + state["dense"] += value["dense"] + state["sparse"] = value["sparse"] + return state + + def make_sparse_fn(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def map_fn(i): + return {"dense": math_ops.cast(i, dtype=dtypes.int64), + "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))} + + for i in range(10): + ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn) + result = ds.reduce(map_fn(0), reduce_fn) + with self.cached_session() as sess: + result = sess.run(result) + self.assertEqual(((i + 1) * i) // 2, result["dense"]) + self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"]) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index ac87a451b1..6bba72a8e9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1205,6 +1205,126 @@ class Dataset(object): shift = size return WindowDataset(self, size, shift, stride, drop_remainder) + def reduce(self, initial_state, reduce_func): + """Reduces the input dataset to a single element. + + The transformation calls `reduce_func` successively on every element of + the input dataset until the dataset is exhausted, aggregating information in + its internal state. The `initial_state` argument is used for the initial + state and the final state is returned as the result. + + For example: + - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)` + produces `5` + - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)` + produces `10` + + Args: + initial_state: A nested structure of tensors, representing the initial + state of the transformation. + reduce_func: A function that maps `(old_state, input_element)` to + `new_state`. It must take two arguments and return a nested structure + of tensors. The structure of `new_state` must match the structure of + `initial_state`. + + Returns: + A nested structure of `tf.Tensor` objects, corresponding to the final + state of the transformation. + + """ + + with ops.name_scope("initial_state"): + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + initial_state = nest.pack_sequence_as(initial_state, [ + sparse_tensor_lib.SparseTensor.from_value(t) + if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( + t, name="component_%d" % i) + for i, t in enumerate(nest.flatten(initial_state)) + ]) + + # Compute initial values for the state classes, shapes and types based on + # the initial state. + state_classes = sparse.get_classes(initial_state) + state_shapes = nest.pack_sequence_as( + initial_state, [t.get_shape() for t in nest.flatten(initial_state)]) + state_types = nest.pack_sequence_as( + initial_state, [t.dtype for t in nest.flatten(initial_state)]) + + # Iteratively rerun the reduce function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + wrapped_func = StructuredFunctionWrapper( + reduce_func, + "reduce()", + input_classes=(state_classes, self.output_classes), + input_shapes=(state_shapes, self.output_shapes), + input_types=(state_types, self.output_types), + add_to_graph=False) + + # Extract and validate class information from the returned values. + output_classes = wrapped_func.output_classes + for new_state_class, state_class in zip( + nest.flatten(output_classes), nest.flatten(state_classes)): + if not issubclass(new_state_class, state_class): + raise TypeError( + "The element classes for the new state must match the initial " + "state. Expected %s; got %s." % (state_classes, + wrapped_func.output_classes)) + + # Extract and validate type information from the returned values. + output_types = wrapped_func.output_types + for new_state_type, state_type in zip( + nest.flatten(output_types), nest.flatten(state_types)): + if new_state_type != state_type: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % (state_types, + wrapped_func.output_types)) + + # Extract shape information from the returned values. + output_shapes = wrapped_func.output_shapes + flat_state_shapes = nest.flatten(state_shapes) + flat_new_state_shapes = nest.flatten(output_shapes) + weakened_state_shapes = [ + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( + weakened_shape.ndims is None or + original_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + state_shapes = nest.pack_sequence_as(state_shapes, + weakened_state_shapes) + + reduce_func = wrapped_func.function + reduce_func.add_to_graph(ops.get_default_graph()) + + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as( + output_types, + gen_dataset_ops.reduce_dataset( + self._as_variant_tensor(), # pylint: disable=protected-access + nest.flatten(sparse.serialize_sparse_tensors(initial_state)), + reduce_func.captured_inputs, + f=reduce_func, + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)))), + output_types, + output_shapes, + output_classes) + class DatasetSource(Dataset): """Abstract class representing a dataset with no inputs.""" |