aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py174
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py301
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt69
-rw-r--r--tensorflow/core/kernels/data/BUILD15
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc14
-rw-r--r--tensorflow/core/kernels/data/captured_function.h11
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc422
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc2
-rw-r--r--tensorflow/core/ops/dataset_ops.cc20
10 files changed, 1028 insertions, 2 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 55a56b83a8..bd3e034211 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -35,6 +36,179 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
+class GroupByReducerTest(test.TestCase):
+
+ def checkResults(self, dataset, shapes, values):
+ self.assertEqual(shapes, dataset.output_shapes)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.test_session() as sess:
+ for expected in values:
+ got = sess.run(get_next)
+ self.assertEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSum(self):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(lambda x: x % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testAverage(self):
+
+ def reduce_fn(x, y):
+ return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+ x[1] + 1), x[1] + 1
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: (0.0, 0.0),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x: x[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(
+ lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+ def testConcat(self):
+ components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+ reducer = grouping.Reducer(
+ init_func=lambda x: "",
+ reduce_func=lambda x, y: x + y[0],
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(components),
+ dataset_ops.Dataset.range(2 * i))).apply(
+ grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+ self.checkResults(
+ dataset,
+ shapes=tensor_shape.scalar(),
+ values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
+
+ def testSparseSum(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1], dtype=np.int64)),
+ dense_shape=np.array([1, 1]))
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: _sparse(np.int64(0)),
+ reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+ finalize_func=lambda x: x.values[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+ grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testChangingStateShape(self):
+
+ def reduce_fn(x, _):
+ # Statically known rank, but dynamic length.
+ larger_dim = array_ops.concat([x[0], x[0]], 0)
+ # Statically unknown rank.
+ larger_rank = array_ops.expand_dims(x[1], 0)
+ return larger_dim, larger_rank
+
+ reducer = grouping.Reducer(
+ init_func=lambda x: ([0], 1),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x: x)
+
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+ grouping.group_by_reducer(lambda x: x, reducer))
+ self.assertEqual([None], dataset.output_shapes[0].as_list())
+ self.assertIs(None, dataset.output_shapes[1].ndims)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual([0] * (2**i), x)
+ self.assertAllEqual(np.array(1, ndmin=i), y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testTypeMismatch(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+ reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+ # TODO(b/78665031): Remove once non-scalar keys are supported.
+ def testInvalidKeyShape(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+ # TODO(b/78665031): Remove once non-int64 keys are supported.
+ def testInvalidKeyType(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+
+class GroupByReducerSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ return dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_reducer(lambda x: x % 5, reducer))
+
+ def testCoreGroupByReducer(self):
+ components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 5,
+ verify_exhausted=True)
+
+
class GroupByWindowTest(test.TestCase):
def testSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index f544b1caa6..eb2ceff893 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -168,7 +168,7 @@ class ScanDatasetTest(test.TestCase):
scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-class ScanDatasetSerialzationTest(
+class ScanDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_elements):
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 0531f9cbb9..ea229b5b27 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -33,6 +34,35 @@ from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
+def group_by_reducer(key_func, reducer):
+ """A transformation that groups elements and performs a reduction.
+
+ This transformation maps element of a dataset to a key using `key_func` and
+ groups the elements by key. The `reducer` is used to process each group; its
+ `init_func` is used to initialize state for each group when it is created, the
+ `reduce_func` is used to update the state every time an element is mapped to
+ the matching group, and the `finalize_func` is used to map the final state to
+ an output value.
+
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reducer: An instance of `Reducer`, which captures the reduction logic using
+ the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return GroupByReducerDataset(dataset, key_func, reducer)
+
+ return _apply_fn
+
+
def group_by_window(key_func,
reduce_func,
window_size=None,
@@ -227,6 +257,250 @@ class _VariantDataset(dataset_ops.Dataset):
return self._output_types
+class GroupByReducerDataset(dataset_ops.Dataset):
+ """A `Dataset` that groups its input and performs a reduction."""
+
+ def __init__(self, input_dataset, key_func, reducer):
+ """See `group_by_reducer()` for details."""
+ super(GroupByReducerDataset, self).__init__()
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_init_func(reducer.init_func)
+ self._make_reduce_func(reducer.reduce_func, input_dataset)
+ self._make_finalize_func(reducer.finalize_func)
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+
+ @function.Defun(*nest.flatten(
+ sparse.as_dense_types(input_dataset.output_types,
+ input_dataset.output_classes)))
+ def tf_key_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ # Pass in shape information from the input_dataset.
+ dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
+ input_dataset.output_classes)
+ for arg, shape in zip(args, nest.flatten(dense_shapes)):
+ arg.set_shape(shape)
+
+ nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
+ nested_args = sparse.deserialize_sparse_tensors(
+ nested_args, input_dataset.output_types, input_dataset.output_shapes,
+ input_dataset.output_classes)
+ # pylint: disable=protected-access
+ if dataset_ops._should_unpack_args(nested_args):
+ ret = key_func(*nested_args)
+ # pylint: enable=protected-access
+ else:
+ ret = key_func(nested_args)
+ ret = ops.convert_to_tensor(ret)
+ if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
+ return ret
+
+ self._key_func = tf_key_func
+ self._key_func.add_to_graph(ops.get_default_graph())
+
+ def _make_init_func(self, init_func):
+ """Make wrapping Defun for init_func."""
+
+ @function.Defun(dtypes.int64)
+ def tf_init_func(key):
+ """A wrapper for Defun that facilitates shape inference."""
+ key.set_shape([])
+ ret = init_func(key)
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ self._state_classes = sparse.get_classes(ret)
+ self._state_shapes = nest.pack_sequence_as(
+ ret, [t.get_shape() for t in nest.flatten(ret)])
+ self._state_types = nest.pack_sequence_as(
+ ret, [t.dtype for t in nest.flatten(ret)])
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ self._init_func = tf_init_func
+ self._init_func.add_to_graph(ops.get_default_graph())
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ # Create a list in which `tf_reduce_func` will store the new shapes.
+ flat_new_state_shapes = []
+
+ @function.Defun(*(nest.flatten(
+ sparse.as_dense_types(
+ self._state_types, self._state_classes)) + nest.flatten(
+ sparse.as_dense_types(input_dataset.output_types,
+ input_dataset.output_classes))))
+ def tf_reduce_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ for arg, shape in zip(
+ args,
+ nest.flatten(
+ sparse.as_dense_shapes(self._state_shapes, self._state_classes))
+ + nest.flatten(
+ sparse.as_dense_shapes(input_dataset.output_shapes,
+ input_dataset.output_classes))):
+ arg.set_shape(shape)
+
+ pivot = len(nest.flatten(self._state_shapes))
+ nested_state_args = nest.pack_sequence_as(self._state_types,
+ args[:pivot])
+ nested_state_args = sparse.deserialize_sparse_tensors(
+ nested_state_args, self._state_types, self._state_shapes,
+ self._state_classes)
+ nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
+ args[pivot:])
+ nested_input_args = sparse.deserialize_sparse_tensors(
+ nested_input_args, input_dataset.output_types,
+ input_dataset.output_shapes, input_dataset.output_classes)
+
+ ret = reduce_func(nested_state_args, nested_input_args)
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ # Extract shape information from the returned values.
+ flat_new_state = nest.flatten(ret)
+ flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])
+
+ # Extract and validate type information from the returned values.
+ for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
+ if t.dtype != dtype:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types,
+ nest.pack_sequence_as(self._state_types,
+ [t.dtype for t in flat_new_state])))
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret,
+ [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ # Use the private method that will execute `tf_reduce_func` but delay
+ # adding it to the graph in case we need to rerun the function.
+ tf_reduce_func._create_definition_if_needed() # pylint: disable=protected-access
+
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ weakened_state_shapes = [
+ old.most_specific_compatible_shape(new)
+ for old, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for old_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if old_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ old_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._reduce_func = tf_reduce_func
+ self._reduce_func.add_to_graph(ops.get_default_graph())
+
+ def _make_finalize_func(self, finalize_func):
+ """Make wrapping Defun for finalize_func."""
+
+ @function.Defun(*(nest.flatten(
+ sparse.as_dense_types(self._state_types, self._state_classes))))
+ def tf_finalize_func(*args):
+ """A wrapper for Defun that facilitates shape inference."""
+ for arg, shape in zip(
+ args,
+ nest.flatten(
+ sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
+ arg.set_shape(shape)
+
+ nested_args = nest.pack_sequence_as(self._state_types, args)
+ nested_args = sparse.deserialize_sparse_tensors(
+ nested_args, self._state_types, self._state_shapes,
+ self._state_classes)
+
+ ret = finalize_func(nested_args)
+
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ ret = nest.pack_sequence_as(ret, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
+ ])
+
+ self._output_classes = sparse.get_classes(ret)
+ self._output_shapes = nest.pack_sequence_as(
+ ret, [t.get_shape() for t in nest.flatten(ret)])
+ self._output_types = nest.pack_sequence_as(
+ ret, [t.dtype for t in nest.flatten(ret)])
+
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
+ return nest.flatten(ret)
+
+ self._finalize_func = tf_finalize_func
+ self._finalize_func.add_to_graph(ops.get_default_graph())
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_reducer_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._init_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._finalize_func.captured_inputs,
+ key_func=self._key_func,
+ init_func=self._init_func,
+ reduce_func=self._reduce_func,
+ finalize_func=self._finalize_func,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+
class GroupByWindowDataset(dataset_ops.Dataset):
"""A `Dataset` that groups its input and performs a windowed reduction."""
@@ -336,3 +610,30 @@ class GroupByWindowDataset(dataset_ops.Dataset):
sparse.as_dense_types(self.output_types, self.output_classes)),
output_shapes=nest.flatten(
sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+
+class Reducer(object):
+ """A reducer is used for reducing a set of elements.
+
+ A reducer is represented as a tuple of the three functions:
+ 1) initialization function: key => initial state
+ 2) reduce function: (old state, input) => new state
+ 3) finalization function: state => result
+ """
+
+ def __init__(self, init_func, reduce_func, finalize_func):
+ self._init_func = init_func
+ self._reduce_func = reduce_func
+ self._finalize_func = finalize_func
+
+ @property
+ def init_func(self):
+ return self._init_func
+
+ @property
+ def reduce_func(self):
+ return self._reduce_func
+
+ @property
+ def finalize_func(self):
+ return self._finalize_func
diff --git a/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt
new file mode 100644
index 0000000000..067ad4018b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GroupByReducerDataset.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "GroupByReducerDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "key_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `key_func`.
+END
+ }
+ attr {
+ name: "key_func"
+ description: <<END
+A function mapping an element of `input_dataset`, concatenated
+with `key_func_other_arguments` to a scalar value of type DT_INT64.
+END
+ }
+ in_arg {
+ name: "init_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `init_func`.
+END
+ }
+ attr {
+ name: "init_func"
+ description: <<END
+A function mapping a key of type DT_INT64, concatenated with
+`init_func_other_arguments` to the initial reducer state.
+END
+ }
+ in_arg {
+ name: "reduce_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `reduce_func`.
+END
+ }
+ attr {
+ name: "reduce_func"
+ description: <<END
+A function mapping the current reducer state and an element of `input_dataset`,
+concatenated with `reduce_func_other_arguments` to a new reducer state.
+END
+ }
+ in_arg {
+ name: "finalize_func_other_arguments"
+ description: <<END
+A list of tensors, typically values that were captured when
+building a closure for `finalize_func`.
+END
+ }
+ attr {
+ name: "finalize_func"
+ description: <<END
+A function mapping the final reducer state to an output element.
+END
+ }
+ summary: "Creates a dataset that computes a group-by on `input_dataset`."
+ description: <<END
+Creates a dataset that computes a group-by on `input_dataset`.
+END
+}
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index c78e0aff83..9ded2667eb 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -124,6 +124,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "group_by_reducer_dataset_op",
+ srcs = ["group_by_reducer_dataset_op.cc"],
+ deps = [
+ ":captured_function",
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "group_by_window_dataset_op",
srcs = ["group_by_window_dataset_op.cc"],
deps = [
@@ -550,6 +564,7 @@ tf_kernel_library(
":filter_dataset_op",
":flat_map_dataset_op",
":generator_dataset_op",
+ ":group_by_reducer_dataset_op",
":group_by_window_dataset_op",
":interleave_dataset_op",
":iterator_ops",
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index dd61b7daee..ee58341cfd 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -32,6 +32,20 @@ Status CapturedFunction::Create(
return Status::OK();
}
+/* static */
+Status CapturedFunction::Create(
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ std::unique_ptr<CapturedFunction>* out_function) {
+ OpInputList argument_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
+ std::vector<Tensor> arguments_t;
+ arguments_t.reserve(argument_inputs.size());
+ for (const Tensor& t : argument_inputs) {
+ arguments_t.push_back(t);
+ }
+ return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+}
+
CapturedFunction::~CapturedFunction() {
if (lib_ != nullptr && f_handle_ != kInvalidHandle) {
lib_->ReleaseHandle(f_handle_).IgnoreError();
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index 490f5cd1e3..e9ad3e381d 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -40,12 +40,20 @@ class ResourceMgr;
// context.
class CapturedFunction {
public:
+ // Creates a new instance from a list of named attributes and captured inputs.
+ //
// NOTE(mrry): The `captured_inputs` are passed by value. For
// efficiency, you are recommended to move this argument into the call.
static Status Create(const NameAttrList& func,
std::vector<Tensor> captured_inputs,
std::unique_ptr<CapturedFunction>* out_function);
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
+ std::unique_ptr<CapturedFunction>* out_function);
+
~CapturedFunction();
// Runs the "Captured function" using the given FLR and caches the lib and
@@ -87,6 +95,9 @@ class CapturedFunction {
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done);
+ // Returns the named list of function arguments.
+ const NameAttrList& func() { return func_; }
+
// Returns that additional captured inputs that will be passed to the function
// when `Run*()` is called.
const std::vector<Tensor>& captured_inputs() { return captured_inputs_; }
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
new file mode 100644
index 0000000000..c8aeaab9cb
--- /dev/null
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::unique_ptr<CapturedFunction> captured_key_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
+ std::unique_ptr<CapturedFunction> captured_init_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(init_func_, ctx,
+ "init_func_other_arguments",
+ &captured_init_func));
+ std::unique_ptr<CapturedFunction> captured_reduce_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
+ std::unique_ptr<CapturedFunction> captured_finalize_func;
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_arguments",
+ &captured_finalize_func));
+
+ *output = new Dataset(
+ ctx, input, std::move(captured_key_func), std::move(captured_init_func),
+ std::move(captured_reduce_func), std::move(captured_finalize_func),
+ output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ std::unique_ptr<CapturedFunction> captured_key_func,
+ std::unique_ptr<CapturedFunction> captured_init_func,
+ std::unique_ptr<CapturedFunction> captured_reduce_func,
+ std::unique_ptr<CapturedFunction> captured_finalize_func,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ captured_key_func_(std::move(captured_key_func)),
+ captured_init_func_(std::move(captured_init_func)),
+ captured_reduce_func_(std::move(captured_reduce_func)),
+ captured_finalize_func_(std::move(captured_finalize_func)),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() override { return "GroupByReducerDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+
+ std::vector<Node*> key_func_other_arguments_node;
+ DataTypeVector key_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_key_func_, &key_func_other_arguments_node,
+ &key_func_other_arguments_types));
+
+ std::vector<Node*> init_func_other_arguments_node;
+ DataTypeVector init_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_init_func_, &init_func_other_arguments_node,
+ &init_func_other_arguments_types));
+
+ std::vector<Node*> reduce_func_other_arguments_node;
+ DataTypeVector reduce_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_reduce_func_, &reduce_func_other_arguments_node,
+ &reduce_func_other_arguments_types));
+
+ std::vector<Node*> finalize_func_other_arguments_node;
+ DataTypeVector finalize_func_other_arguments_types;
+ TF_RETURN_IF_ERROR(OtherArgumentsNodeAndType(
+ b, captured_finalize_func_, &finalize_func_other_arguments_node,
+ &finalize_func_other_arguments_types));
+
+ AttrValue key_func;
+ b->BuildAttrValue(this->key_func(), &key_func);
+ AttrValue init_func;
+ b->BuildAttrValue(this->init_func(), &init_func);
+ AttrValue reduce_func;
+ b->BuildAttrValue(this->reduce_func(), &reduce_func);
+ AttrValue finalize_func;
+ b->BuildAttrValue(this->finalize_func(), &finalize_func);
+
+ AttrValue key_func_other_arguments_types_attr;
+ b->BuildAttrValue(key_func_other_arguments_types,
+ &key_func_other_arguments_types_attr);
+ AttrValue init_func_other_arguments_types_attr;
+ b->BuildAttrValue(init_func_other_arguments_types,
+ &init_func_other_arguments_types_attr);
+ AttrValue reduce_func_other_arguments_types_attr;
+ b->BuildAttrValue(reduce_func_other_arguments_types,
+ &reduce_func_other_arguments_types_attr);
+ AttrValue finalize_func_other_arguments_types_attr;
+ b->BuildAttrValue(finalize_func_other_arguments_types,
+ &finalize_func_other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {{0, input_graph_node}},
+ {{1, key_func_other_arguments_node},
+ {2, init_func_other_arguments_node},
+ {3, reduce_func_other_arguments_node},
+ {4, finalize_func_other_arguments_node}},
+ {{"key_func", key_func},
+ {"init_func", init_func},
+ {"reduce_func", reduce_func},
+ {"finalize_func", finalize_func},
+ {"Tkey_func_other_arguments", key_func_other_arguments_types_attr},
+ {"Tinit_func_other_arguments", init_func_other_arguments_types_attr},
+ {"Treduce_func_other_arguments",
+ reduce_func_other_arguments_types_attr},
+ {"Tfinalize_func_other_arguments",
+ finalize_func_other_arguments_types_attr}},
+ output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+
+ // Iterate through the input dataset, keying input elements to reducers.
+ while (!end_of_input_) {
+ std::vector<Tensor> next_input_element;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &next_input_element, &end_of_input_));
+
+ if (!end_of_input_) {
+ // Run the key function on the input element.
+ std::vector<Tensor> key_func_output;
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_key_func_->RunWithBorrowedArgs(
+ ctx, next_input_element, &key_func_output));
+
+ if (key_func_output.size() != 1 ||
+ key_func_output[0].dtype() != DT_INT64 ||
+ key_func_output[0].NumElements() != 1) {
+ // TODO(b/78665031): Support non-int64 keys.
+ return errors::InvalidArgument(
+ "`key_func` must return a scalar int64.");
+ }
+ const int64 key = key_func_output[0].scalar<int64>()();
+
+ if (states_.find(key) == states_.end()) {
+ // Run the init function to create the initial state.
+ std::vector<Tensor> init_func_output;
+ TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Run(
+ ctx, std::move(key_func_output), &init_func_output));
+ states_[key] = init_func_output;
+ }
+
+ // Run the reduce function to update the current state.
+ std::vector<Tensor> args;
+ args.reserve(states_[key].size() + next_input_element.size());
+ std::copy(states_[key].begin(), states_[key].end(),
+ std::back_inserter(args));
+ std::copy(next_input_element.begin(), next_input_element.end(),
+ std::back_inserter(args));
+
+ std::vector<Tensor> reduce_func_output;
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Run(
+ ctx, std::move(args), &reduce_func_output));
+ states_[key] = reduce_func_output;
+ } else {
+ keys_.resize(states_.size());
+ int idx = 0;
+ for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
+ keys_[idx] = it->first;
+ }
+ }
+ }
+
+ if (keys_index_ == keys_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_finalize_func_->RunWithBorrowedArgs(
+ ctx, states_[keys_[keys_index_++]], out_tensors));
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+
+ // Saving states_.
+ if (!states_.empty()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("states_size"), states_.size()));
+ int idx = 0;
+ for (auto it = states_.begin(); it != states_.end(); ++idx, ++it) {
+ int64 key = it->first;
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("states[", idx, "]->key")), key));
+ if (!it->second.empty()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("states[", idx, "]->state_size")),
+ it->second.size()));
+ for (int j = 0; j < it->second.size(); ++j) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("states[", idx, "]->state[", j, "]")),
+ it->second[j]));
+ }
+ }
+ }
+ }
+
+ // Saving keys_index_ and keys_.
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("keys_index"), keys_index_));
+ if (!keys_.empty()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("keys_size"), keys_.size()));
+ for (int idx = 0; idx < keys_.size(); ++idx) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("keys[", idx, "]")), keys_[idx]));
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+
+ // Restoring states_.
+ if (reader->Contains(full_name("states_size"))) {
+ int64 size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("states_size"), &size));
+ for (int idx = 0; idx < size; ++idx) {
+ int64 key;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("states[", idx, "]->key")), &key));
+ std::vector<Tensor> state;
+ if (reader->Contains(full_name(
+ strings::StrCat("states[", idx, "]->state_size")))) {
+ int64 state_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("states[", idx, "]->state_size")),
+ &state_size));
+ state.resize(state_size);
+ for (int j = 0; j < state_size; ++j) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(
+ strings::StrCat("states[", idx, "]->state[", j, "]")),
+ &state[j]));
+ }
+ }
+ states_[key] = state;
+ }
+ }
+
+ // Restoring keys_index_ and keys_.
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("keys_index"), &keys_index_));
+ if (reader->Contains(full_name("keys_size"))) {
+ int64 size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("keys_size"), &size));
+ keys_.resize(size);
+ for (int idx = 0; idx < size; ++idx) {
+ int64 key;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("keys[", idx, "]")), &key));
+ keys_[idx] = key;
+ }
+ }
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ bool end_of_input_ GUARDED_BY(mu_) = false;
+ std::map<int64, std::vector<Tensor>> states_ GUARDED_BY(mu_);
+ std::vector<int64> keys_ GUARDED_BY(mu_);
+ int64 keys_index_ GUARDED_BY(mu_) = 0;
+ };
+
+ const NameAttrList& key_func() const { return captured_key_func_->func(); }
+
+ const NameAttrList& init_func() const {
+ return captured_init_func_->func();
+ }
+
+ const NameAttrList& reduce_func() const {
+ return captured_reduce_func_->func();
+ }
+
+ const NameAttrList& finalize_func() const {
+ return captured_finalize_func_->func();
+ }
+
+ Status OtherArgumentsNodeAndType(
+ DatasetGraphDefBuilder* b,
+ const std::unique_ptr<CapturedFunction>& captured_func,
+ std::vector<Node*>* other_arguments_node,
+ DataTypeVector* other_arguments_types) const {
+ other_arguments_node->reserve(captured_func->captured_inputs().size());
+ other_arguments_types->reserve(captured_func->captured_inputs().size());
+ for (const Tensor& t : captured_func->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments_node->emplace_back(node);
+ other_arguments_types->emplace_back(t.dtype());
+ }
+ return Status::OK();
+ }
+
+ const DatasetBase* const input_;
+ const std::unique_ptr<CapturedFunction> captured_key_func_;
+ const std::unique_ptr<CapturedFunction> captured_init_func_;
+ const std::unique_ptr<CapturedFunction> captured_reduce_func_;
+ const std::unique_ptr<CapturedFunction> captured_finalize_func_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList key_func_;
+ NameAttrList init_func_;
+ NameAttrList reduce_func_;
+ NameAttrList finalize_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
+ GroupByReducerDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 46f43dd1b1..03f847ce9c 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -241,7 +241,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
if (key_func_output.size() != 1 ||
key_func_output[0].dtype() != DT_INT64 ||
key_func_output[0].NumElements() != 1) {
- // TODO(mrry): Support non-int64 keys.
+ // TODO(b/78665031): Support non-int64 keys.
return errors::InvalidArgument(
"`key_func` must return a scalar int64.");
}
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 4ba3f15ef0..5f10ad24b6 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -270,6 +270,26 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("GroupByReducerDataset")
+ .Input("input_dataset: variant")
+ .Input("key_func_other_arguments: Tkey_func_other_arguments")
+ .Input("init_func_other_arguments: Tinit_func_other_arguments")
+ .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+ .Input("finalize_func_other_arguments: Tfinalize_func_other_arguments")
+ .Output("handle: variant")
+ .Attr("key_func: func")
+ .Attr("init_func: func")
+ .Attr("reduce_func: func")
+ .Attr("finalize_func: func")
+ .Attr("Tkey_func_other_arguments: list(type) >= 0")
+ .Attr("Tinit_func_other_arguments: list(type) >= 0")
+ .Attr("Treduce_func_other_arguments: list(type) >= 0")
+ .Attr("Tfinalize_func_other_arguments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByWindowDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")