diff options
10 files changed, 1028 insertions, 2 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 55a56b83a8..bd3e034211 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -35,6 +36,179 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test +class GroupByReducerTest(test.TestCase): + + def checkResults(self, dataset, shapes, values): + self.assertEqual(shapes, dataset.output_shapes) + get_next = dataset.make_one_shot_iterator().get_next() + with self.test_session() as sess: + for expected in values: + got = sess.run(get_next) + self.assertEqual(got, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSum(self): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer(lambda x: x % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testAverage(self): + + def reduce_fn(x, y): + return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / ( + x[1] + 1), x[1] + 1 + + reducer = grouping.Reducer( + init_func=lambda _: (0.0, 0.0), + reduce_func=reduce_fn, + finalize_func=lambda x: x[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).apply( + grouping.group_by_reducer( + lambda x: math_ops.cast(x, 
dtypes.int64) % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[i - 1, i]) + + def testConcat(self): + components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray) + reducer = grouping.Reducer( + init_func=lambda x: "", + reduce_func=lambda x, y: x + y[0], + finalize_func=lambda x: x) + for i in range(1, 11): + dataset = dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensor_slices(components), + dataset_ops.Dataset.range(2 * i))).apply( + grouping.group_by_reducer(lambda x, y: y % 2, reducer)) + self.checkResults( + dataset, + shapes=tensor_shape.scalar(), + values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]]) + + def testSparseSum(self): + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1], dtype=np.int64)), + dense_shape=np.array([1, 1])) + + reducer = grouping.Reducer( + init_func=lambda _: _sparse(np.int64(0)), + reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]), + finalize_func=lambda x: x.values[0]) + for i in range(1, 11): + dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply( + grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer)) + self.checkResults( + dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i]) + + def testChangingStateShape(self): + + def reduce_fn(x, _): + # Statically known rank, but dynamic length. + larger_dim = array_ops.concat([x[0], x[0]], 0) + # Statically unknown rank. 
+ larger_rank = array_ops.expand_dims(x[1], 0) + return larger_dim, larger_rank + + reducer = grouping.Reducer( + init_func=lambda x: ([0], 1), + reduce_func=reduce_fn, + finalize_func=lambda x: x) + + for i in range(1, 11): + dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply( + grouping.group_by_reducer(lambda x: x, reducer)) + self.assertEqual([None], dataset.output_shapes[0].as_list()) + self.assertIs(None, dataset.output_shapes[1].ndims) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + with self.test_session() as sess: + x, y = sess.run(get_next) + self.assertAllEqual([0] * (2**i), x) + self.assertAllEqual(np.array(1, ndmin=i), y) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testTypeMismatch(self): + reducer = grouping.Reducer( + init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32), + reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64), + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64(0), reducer)) + + # TODO(b/78665031): Remove once non-scalar keys are supported. + def testInvalidKeyShape(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer)) + + # TODO(b/78665031): Remove once non-int64 keys are supported. 
+ def testInvalidKeyType(self): + reducer = grouping.Reducer( + init_func=lambda x: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + ValueError, "`key_func` must return a single tf.int64 tensor."): + dataset.apply( + grouping.group_by_reducer(lambda _: "wrong", reducer)) + + +class GroupByReducerSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + reducer = grouping.Reducer( + init_func=lambda _: np.int64(0), + reduce_func=lambda x, y: x + y, + finalize_func=lambda x: x) + + return dataset_ops.Dataset.from_tensor_slices(components).apply( + grouping.group_by_reducer(lambda x: x % 5, reducer)) + + def testCoreGroupByReducer(self): + components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_init_before_restore( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 5, verify_exhausted=True) + diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 5, + verify_exhausted=True) + + class GroupByWindowTest(test.TestCase): def testSimple(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index f544b1caa6..eb2ceff893 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -168,7 +168,7 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) -class ScanDatasetSerialzationTest( +class ScanDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_dataset(self, num_elements): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 0531f9cbb9..ea229b5b27 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -33,6 +34,35 @@ from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops +def group_by_reducer(key_func, reducer): + """A transformation that groups elements and performs a reduction. + + This transformation maps each dataset element to a key using `key_func` and + groups the elements by key. The `reducer` is used to process each group; its + `init_func` is used to initialize state for each group when it is created, the + `reduce_func` is used to update the state every time an element is mapped to + the matching group, and the `finalize_func` is used to map the final state to + an output value. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by the input dataset's
+ `output_shapes` and `output_types`) to a scalar `tf.int64` tensor. + reducer: An instance of `Reducer`, which captures the reduction logic using + the `init_func`, `reduce_func`, and `finalize_func` functions. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return GroupByReducerDataset(dataset, key_func, reducer) + + return _apply_fn + + def group_by_window(key_func, reduce_func, window_size=None, @@ -227,6 +257,250 @@ class _VariantDataset(dataset_ops.Dataset): return self._output_types +class GroupByReducerDataset(dataset_ops.Dataset): + """A `Dataset` that groups its input and performs a reduction.""" + + def __init__(self, input_dataset, key_func, reducer): + """See `group_by_reducer()` for details.""" + super(GroupByReducerDataset, self).__init__() + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_init_func(reducer.init_func) + self._make_reduce_func(reducer.reduce_func, input_dataset) + self._make_finalize_func(reducer.finalize_func) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) + def tf_key_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. 
+ dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + # pylint: disable=protected-access + if dataset_ops._should_unpack_args(nested_args): + ret = key_func(*nested_args) + # pylint: enable=protected-access + else: + ret = key_func(nested_args) + ret = ops.convert_to_tensor(ret) + if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar(): + raise ValueError( + "`key_func` must return a single tf.int64 tensor. " + "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape())) + return ret + + self._key_func = tf_key_func + self._key_func.add_to_graph(ops.get_default_graph()) + + def _make_init_func(self, init_func): + """Make wrapping Defun for init_func.""" + + @function.Defun(dtypes.int64) + def tf_init_func(key): + """A wrapper for Defun that facilitates shape inference.""" + key.set_shape([]) + ret = init_func(key) + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._state_classes = sparse.get_classes(ret) + self._state_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._state_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors. 
+ ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._init_func = tf_init_func + self._init_func.add_to_graph(ops.get_default_graph()) + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + # Iteratively rerun the reduce function until reaching a fixed point on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + # Create a list in which `tf_reduce_func` will store the new shapes. + flat_new_state_shapes = [] + + @function.Defun(*(nest.flatten( + sparse.as_dense_types( + self._state_types, self._state_classes)) + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) + def tf_reduce_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes)) + + nest.flatten( + sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes))): + arg.set_shape(shape) + + pivot = len(nest.flatten(self._state_shapes)) + nested_state_args = nest.pack_sequence_as(self._state_types, + args[:pivot]) + nested_state_args = sparse.deserialize_sparse_tensors( + nested_state_args, self._state_types, self._state_shapes, + self._state_classes) + nested_input_args = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + nested_input_args = sparse.deserialize_sparse_tensors( + nested_input_args, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) + + ret = reduce_func(nested_state_args, nested_input_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. 
+ ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + # Extract shape information from the returned values. + flat_new_state = nest.flatten(ret) + flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, + nest.pack_sequence_as(self._state_types, + [t.dtype for t in flat_new_state]))) + + # Serialize any sparse tensors. + ret = nest.pack_sequence_as( + ret, + [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + # Use the private method that will execute `tf_reduce_func` but delay + # adding it to the graph in case we need to rerun the function. 
+ tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access + + flat_state_shapes = nest.flatten(self._state_shapes) + weakened_state_shapes = [ + old.most_specific_compatible_shape(new) + for old, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for old_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if old_shape.ndims is not None and ( + weakened_shape.ndims is None or + old_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._reduce_func = tf_reduce_func + self._reduce_func.add_to_graph(ops.get_default_graph()) + + def _make_finalize_func(self, finalize_func): + """Make wrapping Defun for finalize_func.""" + + @function.Defun(*(nest.flatten( + sparse.as_dense_types(self._state_types, self._state_classes)))) + def tf_finalize_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + for arg, shape in zip( + args, + nest.flatten( + sparse.as_dense_shapes(self._state_shapes, self._state_classes))): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(self._state_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, self._state_types, self._state_shapes, + self._state_classes) + + ret = finalize_func(nested_args) + + # Convert any `SparseTensorValue`s to `SparseTensor`s and all other + # values to tensors. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor.SparseTensor.from_value(t) + if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t) + for t in nest.flatten(ret) + ]) + + self._output_classes = sparse.get_classes(ret) + self._output_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._output_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors. 
+ ret = nest.pack_sequence_as( + ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))]) + return nest.flatten(ret) + + self._finalize_func = tf_finalize_func + self._finalize_func.add_to_graph(ops.get_default_graph()) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_reducer_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._init_func.captured_inputs, + self._reduce_func.captured_inputs, + self._finalize_func.captured_inputs, + key_func=self._key_func, + init_func=self._init_func, + reduce_func=self._reduce_func, + finalize_func=self._finalize_func, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + class GroupByWindowDataset(dataset_ops.Dataset): """A `Dataset` that groups its input and performs a windowed reduction.""" @@ -336,3 +610,30 @@ class GroupByWindowDataset(dataset_ops.Dataset): sparse.as_dense_types(self.output_types, self.output_classes)), output_shapes=nest.flatten( sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + +class Reducer(object): + """A reducer is used for reducing a set of elements. 
+ + A reducer is represented as a tuple of the three functions: + 1) initialization function: key => initial state + 2) reduce function: (old state, input) => new state + 3) finalization function: state => result + """ + + def __init__(self, init_func, reduce_func, finalize_func): + self._init_func = init_func + self._reduce_func = reduce_func + self._finalize_func = finalize_func + + @property + def init_func(self): + return self._init_func + + @property + def reduce_func(self): + return self._reduce_func + + @property + def finalize_func(self): + return self._finalize_func diff --git a/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt new file mode 100644 index 0000000000..067ad4018b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt @@ -0,0 +1,69 @@ +op { + graph_op_name: "GroupByReducerDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + in_arg { + name: "key_func_other_arguments" + description: <<END +A list of tensors, typically values that were captured when +building a closure for `key_func`. +END + } + attr { + name: "key_func" + description: <<END +A function mapping an element of `input_dataset`, concatenated +with `key_func_other_arguments` to a scalar value of type DT_INT64. +END + } + in_arg { + name: "init_func_other_arguments" + description: <<END +A list of tensors, typically values that were captured when +building a closure for `init_func`. +END + } + attr { + name: "init_func" + description: <<END +A function mapping a key of type DT_INT64, concatenated with +`init_func_other_arguments` to the initial reducer state. +END + } + in_arg { + name: "reduce_func_other_arguments" + description: <<END +A list of tensors, typically values that were captured when +building a closure for `reduce_func`. 
+END + } + attr { + name: "reduce_func" + description: <<END +A function mapping the current reducer state and an element of `input_dataset`, +concatenated with `reduce_func_other_arguments` to a new reducer state. +END + } + in_arg { + name: "finalize_func_other_arguments" + description: <<END +A list of tensors, typically values that were captured when +building a closure for `finalize_func`. +END + } + attr { + name: "finalize_func" + description: <<END +A function mapping the final reducer state to an output element. +END + } + summary: "Creates a dataset that computes a group-by on `input_dataset`." + description: <<END +Creates a dataset that computes a group-by on `input_dataset`. +END +} diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index c78e0aff83..9ded2667eb 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -124,6 +124,20 @@ tf_kernel_library( ) tf_kernel_library( + name = "group_by_reducer_dataset_op", + srcs = ["group_by_reducer_dataset_op.cc"], + deps = [ + ":captured_function", + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( name = "group_by_window_dataset_op", srcs = ["group_by_window_dataset_op.cc"], deps = [ @@ -550,6 +564,7 @@ tf_kernel_library( ":filter_dataset_op", ":flat_map_dataset_op", ":generator_dataset_op", + ":group_by_reducer_dataset_op", ":group_by_window_dataset_op", ":interleave_dataset_op", ":iterator_ops", diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index dd61b7daee..ee58341cfd 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -32,6 +32,20 @@ Status CapturedFunction::Create( return Status::OK(); } +/* static */ +Status 
CapturedFunction::Create( + const NameAttrList& func, OpKernelContext* ctx, const string& argument, + std::unique_ptr<CapturedFunction>* out_function) { + OpInputList argument_inputs; + TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs)); + std::vector<Tensor> arguments_t; + arguments_t.reserve(argument_inputs.size()); + for (const Tensor& t : argument_inputs) { + arguments_t.push_back(t); + } + return CapturedFunction::Create(func, std::move(arguments_t), out_function); +} + CapturedFunction::~CapturedFunction() { if (lib_ != nullptr && f_handle_ != kInvalidHandle) { lib_->ReleaseHandle(f_handle_).IgnoreError(); diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 490f5cd1e3..e9ad3e381d 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -40,12 +40,20 @@ class ResourceMgr; // context. class CapturedFunction { public: + // Creates a new instance from a list of named attributes and captured inputs. + // // NOTE(mrry): The `captured_inputs` are passed by value. For // efficiency, you are recommended to move this argument into the call. static Status Create(const NameAttrList& func, std::vector<Tensor> captured_inputs, std::unique_ptr<CapturedFunction>* out_function); + // Creates a new instance using a list of named attributes, fetching captured + // inputs from a context argument. + static Status Create(const NameAttrList& func, OpKernelContext* ctx, + const string& argument, + std::unique_ptr<CapturedFunction>* out_function); + ~CapturedFunction(); // Runs the "Captured function" using the given FLR and caches the lib and @@ -87,6 +95,9 @@ class CapturedFunction { std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done); + // Returns the named list of function arguments. 
+ const NameAttrList& func() { return func_; } + // Returns that additional captured inputs that will be passed to the function // when `Run*()` is called. const std::vector<Tensor>& captured_inputs() { return captured_inputs_; } diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc new file mode 100644 index 0000000000..c8aeaab9cb --- /dev/null +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -0,0 +1,422 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <map> + +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. 
+class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { + public: + explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx), + graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + std::unique_ptr<CapturedFunction> captured_key_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, + "key_func_other_arguments", + &captured_key_func)); + std::unique_ptr<CapturedFunction> captured_init_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx, + "init_func_other_arguments", + &captured_init_func)); + std::unique_ptr<CapturedFunction> captured_reduce_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, + "reduce_func_other_arguments", + &captured_reduce_func)); + std::unique_ptr<CapturedFunction> captured_finalize_func; + OP_REQUIRES_OK(ctx, + CapturedFunction::Create(finalize_func_, ctx, + "finalize_func_other_arguments", + &captured_finalize_func)); + + *output = new Dataset( + ctx, input, std::move(captured_key_func), std::move(captured_init_func), + std::move(captured_reduce_func), std::move(captured_finalize_func), + output_types_, output_shapes_); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + std::unique_ptr<CapturedFunction> captured_key_func, + std::unique_ptr<CapturedFunction> captured_init_func, + std::unique_ptr<CapturedFunction> captured_reduce_func, + std::unique_ptr<CapturedFunction> 
captured_finalize_func, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : GraphDatasetBase(ctx), + input_(input), + captured_key_func_(std::move(captured_key_func)), + captured_init_func_(std::move(captured_init_func)), + captured_reduce_func_(std::move(captured_reduce_func)), + captured_finalize_func_(std::move(captured_finalize_func)), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIterator( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { return "GroupByReducerDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name())); + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node)); + + std::vector<Node*> key_func_other_arguments_node; + DataTypeVector key_func_other_arguments_types; + TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( + b, captured_key_func_, &key_func_other_arguments_node, + &key_func_other_arguments_types)); + + std::vector<Node*> init_func_other_arguments_node; + DataTypeVector init_func_other_arguments_types; + TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( + b, captured_init_func_, &init_func_other_arguments_node, + &init_func_other_arguments_types)); + + 
std::vector<Node*> reduce_func_other_arguments_node; + DataTypeVector reduce_func_other_arguments_types; + TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( + b, captured_reduce_func_, &reduce_func_other_arguments_node, + &reduce_func_other_arguments_types)); + + std::vector<Node*> finalize_func_other_arguments_node; + DataTypeVector finalize_func_other_arguments_types; + TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType( + b, captured_finalize_func_, &finalize_func_other_arguments_node, + &finalize_func_other_arguments_types)); + + AttrValue key_func; + b->BuildAttrValue(this->key_func(), &key_func); + AttrValue init_func; + b->BuildAttrValue(this->init_func(), &init_func); + AttrValue reduce_func; + b->BuildAttrValue(this->reduce_func(), &reduce_func); + AttrValue finalize_func; + b->BuildAttrValue(this->finalize_func(), &finalize_func); + + AttrValue key_func_other_arguments_types_attr; + b->BuildAttrValue(key_func_other_arguments_types, + &key_func_other_arguments_types_attr); + AttrValue init_func_other_arguments_types_attr; + b->BuildAttrValue(init_func_other_arguments_types, + &init_func_other_arguments_types_attr); + AttrValue reduce_func_other_arguments_types_attr; + b->BuildAttrValue(reduce_func_other_arguments_types, + &reduce_func_other_arguments_types_attr); + AttrValue finalize_func_other_arguments_types_attr; + b->BuildAttrValue(finalize_func_other_arguments_types, + &finalize_func_other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, {{0, input_graph_node}}, + {{1, key_func_other_arguments_node}, + {2, init_func_other_arguments_node}, + {3, reduce_func_other_arguments_node}, + {4, finalize_func_other_arguments_node}}, + {{"key_func", key_func}, + {"init_func", init_func}, + {"reduce_func", reduce_func}, + {"finalize_func", finalize_func}, + {"Tkey_func_other_arguments", key_func_other_arguments_types_attr}, + {"Tinit_func_other_arguments", init_func_other_arguments_types_attr}, + {"Treduce_func_other_arguments", + 
reduce_func_other_arguments_types_attr}, + {"Tfinalize_func_other_arguments", + finalize_func_other_arguments_types_attr}}, + output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + + // Iterate through the input dataset, keying input elements to reducers. + while (!end_of_input_) { + std::vector<Tensor> next_input_element; + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, &next_input_element, &end_of_input_)); + + if (!end_of_input_) { + // Run the key function on the input element. + std::vector<Tensor> key_func_output; + TF_RETURN_IF_ERROR( + dataset()->captured_key_func_->RunWithBorrowedArgs( + ctx, next_input_element, &key_func_output)); + + if (key_func_output.size() != 1 || + key_func_output[0].dtype() != DT_INT64 || + key_func_output[0].NumElements() != 1) { + // TODO(b/78665031): Support non-int64 keys. + return errors::InvalidArgument( + "`key_func` must return a scalar int64."); + } + const int64 key = key_func_output[0].scalar<int64>()(); + + if (states_.find(key) == states_.end()) { + // Run the init function to create the initial state. + std::vector<Tensor> init_func_output; + TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Run( + ctx, std::move(key_func_output), &init_func_output)); + states_[key] = init_func_output; + } + + // Run the reduce function to update the current state. 
+ std::vector<Tensor> args; + args.reserve(states_[key].size() + next_input_element.size()); + std::copy(states_[key].begin(), states_[key].end(), + std::back_inserter(args)); + std::copy(next_input_element.begin(), next_input_element.end(), + std::back_inserter(args)); + + std::vector<Tensor> reduce_func_output; + TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run( + ctx, std::move(args), &reduce_func_output)); + states_[key] = reduce_func_output; + } else { + keys_.resize(states_.size()); + int idx = 0; + for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) { + keys_[idx] = it->first; + } + } + } + + if (keys_index_ == keys_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR( + dataset()->captured_finalize_func_->RunWithBorrowedArgs( + ctx, states_[keys_[keys_index_++]], out_tensors)); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_)); + + if (end_of_input_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("end_of_input"), "")); + } + + // Saving states_. + if (!states_.empty()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("states_size"), states_.size())); + int idx = 0; + for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) { + int64 key = it->first; + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("states[", idx, "]->key")), key)); + if (!it->second.empty()) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("states[", idx, "]->state_size")), + it->second.size())); + for (int j = 0; j < it->second.size(); ++j) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat("states[", idx, "]->state[", j, "]")), + it->second[j])); + } + } + } + } + + // Saving keys_index_ and keys_. 
+ if (end_of_input_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("keys_index"), keys_index_)); + if (!keys_.empty()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("keys_size"), keys_.size())); + for (int idx = 0; idx < keys_.size(); ++idx) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("keys[", idx, "]")), keys_[idx])); + } + } + } + + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_)); + + if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; + + // Restoring states_. + if (reader->Contains(full_name("states_size"))) { + int64 size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("states_size"), &size)); + for (int idx = 0; idx < size; ++idx) { + int64 key; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("states[", idx, "]->key")), &key)); + std::vector<Tensor> state; + if (reader->Contains(full_name( + strings::StrCat("states[", idx, "]->state_size")))) { + int64 state_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("states[", idx, "]->state_size")), + &state_size)); + state.resize(state_size); + for (int j = 0; j < state_size; ++j) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name( + strings::StrCat("states[", idx, "]->state[", j, "]")), + &state[j])); + } + } + states_[key] = state; + } + } + + // Restoring keys_index_ and keys_. 
+ if (end_of_input_) { + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("keys_index"), &keys_index_)); + if (reader->Contains(full_name("keys_size"))) { + int64 size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("keys_size"), &size)); + keys_.resize(size); + for (int idx = 0; idx < size; ++idx) { + int64 key; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("keys[", idx, "]")), &key)); + keys_[idx] = key; + } + } + } + + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + bool end_of_input_ GUARDED_BY(mu_) = false; + std::map<int64, std::vector<Tensor>> states_ GUARDED_BY(mu_); + std::vector<int64> keys_ GUARDED_BY(mu_); + int64 keys_index_ GUARDED_BY(mu_) = 0; + }; + + const NameAttrList& key_func() const { return captured_key_func_->func(); } + + const NameAttrList& init_func() const { + return captured_init_func_->func(); + } + + const NameAttrList& reduce_func() const { + return captured_reduce_func_->func(); + } + + const NameAttrList& finalize_func() const { + return captured_finalize_func_->func(); + } + + Status OtherArgumentsNodeAndType( + DatasetGraphDefBuilder* b, + const std::unique_ptr<CapturedFunction>& captured_func, + std::vector<Node*>* other_arguments_node, + DataTypeVector* other_arguments_types) const { + other_arguments_node->reserve(captured_func->captured_inputs().size()); + other_arguments_types->reserve(captured_func->captured_inputs().size()); + for (const Tensor& t : captured_func->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments_node->emplace_back(node); + other_arguments_types->emplace_back(t.dtype()); + } + return Status::OK(); + } + + const DatasetBase* const input_; + const std::unique_ptr<CapturedFunction> captured_key_func_; + const std::unique_ptr<CapturedFunction> captured_init_func_; + const std::unique_ptr<CapturedFunction> captured_reduce_func_; + const 
std::unique_ptr<CapturedFunction> captured_finalize_func_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + const int graph_def_version_; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + NameAttrList key_func_; + NameAttrList init_func_; + NameAttrList reduce_func_; + NameAttrList finalize_func_; +}; + +REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU), + GroupByReducerDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 46f43dd1b1..03f847ce9c 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -241,7 +241,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { if (key_func_output.size() != 1 || key_func_output[0].dtype() != DT_INT64 || key_func_output[0].NumElements() != 1) { - // TODO(mrry): Support non-int64 keys. + // TODO(b/78665031): Support non-int64 keys. 
return errors::InvalidArgument( "`key_func` must return a scalar int64."); } diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 4ba3f15ef0..5f10ad24b6 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -270,6 +270,26 @@ REGISTER_OP("ParallelInterleaveDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("GroupByReducerDataset") + .Input("input_dataset: variant") + .Input("key_func_other_arguments: Tkey_func_other_arguments") + .Input("init_func_other_arguments: Tinit_func_other_arguments") + .Input("reduce_func_other_arguments: Treduce_func_other_arguments") + .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments") + .Output("handle: variant") + .Attr("key_func: func") + .Attr("init_func: func") + .Attr("reduce_func: func") + .Attr("finalize_func: func") + .Attr("Tkey_func_other_arguments: list(type) >= 0") + .Attr("Tinit_func_other_arguments: list(type) >= 0") + .Attr("Treduce_func_other_arguments: list(type) >= 0") + .Attr("Tfinalize_func_other_arguments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("GroupByWindowDataset") .Input("input_dataset: variant") .Input("key_func_other_arguments: Tkey_func_other_arguments") |