author    Shivani Agrawal <shivaniagrawal@google.com>    2018-09-24 11:00:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-24 11:07:30 -0700
commit    f7017ef769bd603b61f25dfffc772e2153a9f076 (patch)
tree      328b23433a0abe79322d0a53abd1f704086024a0 /tensorflow/contrib/data
parent    770a81b1edcb923086b82252d2c1a0271b0c49c5 (diff)
[data-stats] Exposes `StatsAggregator` and `set_stats_aggregator` in tf.contrib.data.
PiperOrigin-RevId: 214294955
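For context, a minimal sketch of how the three newly exposed symbols fit together under the TF 1.x graph-mode API of this era; the range dataset and the tag string "record_latency" are illustrative, not part of the commit:

```python
import tensorflow as tf

# Build a pipeline that records per-element production latency and ties
# the collected statistics to a StatsAggregator.
stats_aggregator = tf.contrib.data.StatsAggregator()

dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(tf.contrib.data.latency_stats("record_latency"))
dataset = dataset.apply(
    tf.contrib.data.set_stats_aggregator(stats_aggregator))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# The aggregated statistics surface as a serialized Summary proto tensor.
stats_summary = stats_aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
```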
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/__init__.py               6
-rw-r--r--  tensorflow/contrib/data/python/ops/stats_ops.py  31
2 files changed, 19 insertions, 18 deletions
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index c378b1ce8d..3cb51279c3 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -44,6 +44,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@group_by_reducer
@@group_by_window
@@ignore_errors
+@@latency_stats
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
@@ -57,9 +58,11 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@reduce_dataset
@@sample_from_datasets
@@scan
+@@set_stats_aggregator
@@shuffle_and_repeat
@@sliding_window_batch
@@sloppy_interleave
+@@StatsAggregator
@@unbatch
@@unique
@@ -111,6 +114,9 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
+from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
+from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
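With the three imports above in place, the symbols resolve from the public contrib namespace rather than only from the private `stats_ops` module. A quick smoke check, as a sketch (the assertions are illustrative):

```python
import tensorflow as tf

# After this change the symbols are part of the documented contrib API.
assert callable(tf.contrib.data.latency_stats)
assert callable(tf.contrib.data.set_stats_aggregator)
assert isinstance(tf.contrib.data.StatsAggregator, type)
```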
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 8426228992..7410ee8e05 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,34 +23,31 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
To record statistics, use one of the custom transformation functions defined
in this module when defining your `tf.data.Dataset`. All statistics will be
aggregated by the `StatsAggregator` that is associated with a particular
- iterator (see below). For example, to record the total number of bytes
- produced by iterating over a dataset:
+ iterator (see below). For example, to record the latency of producing each
+ element by iterating over a dataset:
```python
dataset = ...
- dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
+ dataset = dataset.apply(stats_ops.latency_stats("record_latency"))
```
- To associate a `StatsAggregator` with a `tf.data.Iterator` object, use
+ To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
the following pattern:
```python
- dataset = ...
- iterator = dataset.make_one_shot_iterator()
stats_aggregator = stats_ops.StatsAggregator()
- set_op = stats_aggregator.subscribe(iterator)
+ dataset = ...

- with tf.Session() as sess:
- # Running `set_op` will associate `iterator` with `stats_aggregator`.
- sess.run(set_op)
+ # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
+ dataset = dataset.apply(
+ tf.contrib.data.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_one_shot_iterator()
```
To get a protocol buffer summary of the currently aggregated statistics,
@@ -60,6 +57,7 @@ class StatsAggregator(object):
```python
stats_aggregator = stats_ops.StatsAggregator()
+ # ...
stats_summary = stats_aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
```
@@ -73,6 +71,7 @@ class StatsAggregator(object):
"""Creates a `StatsAggregator`."""
self._resource = gen_dataset_ops.stats_aggregator_handle()
+ # TODO(b/116314787): Update this/add support for V2 summary API.
def get_summary(self):
"""Returns a string `tf.Tensor` that summarizes the aggregated statistics.
@@ -112,13 +111,11 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
return self._input_dataset.output_classes
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
def set_stats_aggregator(stats_aggregator):
- """Set the given stats_aggregator for aggregating the input dataset stats.
+ """Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `StatsAggregator` object.
+ stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -155,8 +152,6 @@ def bytes_produced_stats(tag):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.