author      2018-09-24 11:00:48 -0700
committer   2018-09-24 11:07:30 -0700
commit      f7017ef769bd603b61f25dfffc772e2153a9f076 (patch)
tree        328b23433a0abe79322d0a53abd1f704086024a0 /tensorflow/contrib/data
parent      770a81b1edcb923086b82252d2c1a0271b0c49c5 (diff)
[data-stats] Exposes `StatsAggregator` and `set_stats_aggregator` in tf.contrib.data.
PiperOrigin-RevId: 214294955
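As a usage sketch (not part of the commit itself), the snippet below shows how the three newly exposed symbols compose under the TF 1.x graph-mode API of this era. The tag name `"record_latency"` and the toy dataset are illustrative assumptions; the pattern follows the docstrings updated in this change.

```python
import tensorflow as tf

# Record per-element production latency under an illustrative tag.
dataset = tf.data.Dataset.range(100)
dataset = dataset.apply(tf.contrib.data.latency_stats("record_latency"))

# Route all stats recorded on this dataset to one aggregator.
stats_aggregator = tf.contrib.data.StatsAggregator()
dataset = dataset.apply(
    tf.contrib.data.set_stats_aggregator(stats_aggregator))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
stats_summary = stats_aggregator.get_summary()

with tf.Session() as sess:
  for _ in range(100):
    sess.run(next_element)
  # get_summary() yields a serialized Summary protocol buffer.
  print(sess.run(stats_summary))
```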
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/__init__.py              |  6
-rw-r--r--  tensorflow/contrib/data/python/ops/stats_ops.py  | 31
2 files changed, 19 insertions(+), 18 deletions(-)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index c378b1ce8d..3cb51279c3 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -44,6 +44,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
 @@group_by_reducer
 @@group_by_window
 @@ignore_errors
+@@latency_stats
 @@make_batched_features_dataset
 @@make_csv_dataset
 @@make_saveable_from_iterator
@@ -57,9 +58,11 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
 @@reduce_dataset
 @@sample_from_datasets
 @@scan
+@@set_stats_aggregator
 @@shuffle_and_repeat
 @@sliding_window_batch
 @@sloppy_interleave
+@@StatsAggregator
 @@unbatch
 @@unique
 
@@ -111,6 +114,9 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
 from tensorflow.contrib.data.python.ops.scan_ops import scan
 from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
 from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
+from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
+from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
+from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
 from tensorflow.contrib.data.python.ops.unique import unique
 from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
 from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 8426228992..7410ee8e05 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -23,34 +23,31 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_dataset_ops
 
 
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
 class StatsAggregator(object):
   """A stateful resource that aggregates statistics from one or more iterators.
 
   To record statistics, use one of the custom transformation functions defined
   in this module when defining your `tf.data.Dataset`. All statistics will be
   aggregated by the `StatsAggregator` that is associated with a particular
-  iterator (see below). For example, to record the total number of bytes
-  produced by iterating over a dataset:
+  iterator (see below). For example, to record the latency of producing each
+  element by iterating over a dataset:
 
   ```python
   dataset = ...
-  dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
+  dataset = dataset.apply(stats_ops.latency_stats("total_bytes"))
   ```
 
-  To associate a `StatsAggregator` with a `tf.data.Iterator` object, use
+  To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
   the following pattern:
 
   ```python
-  dataset = ...
-  iterator = dataset.make_one_shot_iterator()
   stats_aggregator = stats_ops.StatsAggregator()
-  set_op = stats_aggregator.subscribe(iterator)
+  dataset = ...
 
-  with tf.Session() as sess:
-    # Running `set_op` will associate `iterator` with `stats_aggregator`.
-    sess.run(set_op)
+  # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
+  dataset = dataset.apply(
+      tf.contrib.data.set_stats_aggregator(stats_aggregator))
+  iterator = dataset.make_one_shot_iterator()
   ```
 
   To get a protocol buffer summary of the currently aggregated statistics,
@@ -60,6 +57,7 @@ class StatsAggregator(object):
 
   ```python
   stats_aggregator = stats_ops.StatsAggregator()
+  # ...
   stats_summary = stats_aggregator.get_summary()
   tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
   ```
@@ -73,6 +71,7 @@ class StatsAggregator(object):
     """Creates a `StatsAggregator`."""
     self._resource = gen_dataset_ops.stats_aggregator_handle()
 
+  # TODO(b/116314787): Update this/add support for V2 summary API.
   def get_summary(self):
     """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
 
@@ -112,13 +111,11 @@ class _SetStatsAggregatorDataset(dataset_ops.Dataset):
     return self._input_dataset.output_classes
 
 
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
 def set_stats_aggregator(stats_aggregator):
-  """Set the given stats_aggregator for aggregating the input dataset stats.
+  """Set the given `stats_aggregator` for aggregating the input dataset stats.
 
   Args:
-    stats_aggregator: A `StatsAggregator` object.
+    stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
 
   Returns:
     A `Dataset` transformation function, which can be passed to
@@ -155,8 +152,6 @@ def bytes_produced_stats(tag):
   return _apply_fn
 
 
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
 def latency_stats(tag):
   """Records the latency of producing each element of the input dataset.
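Building on the `get_summary` example in the docstring above, here is a hedged sketch of feeding the aggregated statistics to a summary writer so they appear in TensorBoard. The log directory and the write cadence are illustrative assumptions, not part of this commit.

```python
import tensorflow as tf

stats_aggregator = tf.contrib.data.StatsAggregator()
dataset = tf.data.Dataset.range(1000).apply(
    tf.contrib.data.latency_stats("record_latency"))
dataset = dataset.apply(
    tf.contrib.data.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

# Register the aggregated stats alongside ordinary summaries so that
# existing summary-collection code picks them up (per the docstring above).
stats_summary = stats_aggregator.get_summary()
tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)

writer = tf.summary.FileWriter("/tmp/stats_demo")  # illustrative logdir
with tf.Session() as sess:
  for step in range(1000):
    sess.run(next_element)
    if step % 100 == 0:
      # `add_summary` accepts the serialized Summary proto string.
      writer.add_summary(sess.run(stats_summary), global_step=step)
  writer.close()
```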