aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/ops/get_single_element.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/get_single_element.py')
-rw-r--r--tensorflow/contrib/data/python/ops/get_single_element.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index 0f4cd8e20c..ef9284456e 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
@@ -68,3 +71,30 @@ def get_single_element(dataset):
return sparse.deserialize_sparse_tensors(
nested_ret, dataset.output_types, dataset.output_shapes,
dataset.output_classes)
+
+
+def reduce_dataset(dataset, reducer):
+ """Returns the result of reducing the `dataset` using `reducer`.
+
+ Args:
+ dataset: A @{tf.data.Dataset} object.
+ reducer: A @{tf.contrib.data.Reducer} object representing the reduce logic.
+
+ Returns:
+ A nested structure of @{tf.Tensor} objects, corresponding to the result
+ of reducing `dataset` using `reducer`.
+
+ Raises:
+ TypeError: if `dataset` is not a `tf.data.Dataset` object.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+ # The sentinel dataset is used in case the reduced dataset is empty.
+ sentinel_dataset = dataset_ops.Dataset.from_tensors(
+ reducer.finalize_func(reducer.init_func(np.int64(0))))
+ reduced_dataset = dataset.apply(
+ grouping.group_by_reducer(lambda x: np.int64(0), reducer))
+
+ return get_single_element(
+ reduced_dataset.concatenate(sentinel_dataset).take(1))