path: root/tensorflow/python/data
author	Shivani Agrawal <shivaniagrawal@google.com>	2018-10-03 12:44:47 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2018-10-03 12:51:54 -0700
commit	808b1dcb318b1feb5a8c9fed5558f95cd05728e4 (patch)
tree	a2286241b6a0c8cba24b1da629fa6e7db475d7dc /tensorflow/python/data
parent	19833284cc8fa555115aacde350ad66652b250dc (diff)
[data-stats] Sets user-given `tag` and `counter_prefix` with `set_stats_aggregator`. `tag` is prepended to the names of all statistics recorded as summaries, and `counter_prefix` sets the prefix for statistics recorded as counters.
Note: `counter_prefix` defaults to `/tensorflow`, and `tag` and `prefix` get associated with the dataset (not the stats_aggregator). PiperOrigin-RevId: 215609159
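
For reference, a minimal usage sketch of the new parameters, mirroring the tests added in this change (TF 1.x graph mode; the tag "dataset1" and the counter prefix value are illustrative):

# A StatsAggregator shared by the pipeline; `tag` scopes its summaries.
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops

stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
    stats_ops.latency_stats("record_latency")).apply(
        # Summaries are recorded as "dataset1_record_latency"; counters get
        # the given prefix (illustrative value; defaults to "/tensorflow").
        stats_ops.set_stats_aggregator(
            stats_aggregator, tag="dataset1", counter_prefix="/my_counters"))
summary_t = stats_aggregator.get_summary()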
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r--	tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py	69
-rw-r--r--	tensorflow/python/data/experimental/ops/stats_ops.py	17
2 files changed, 82 insertions(+), 4 deletions(-)
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index 6761fbd16b..19f5a62d45 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
@@ -248,6 +249,74 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+ def testMultipleDatasetWithTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "dataset1"))
+ dataset2 = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "dataset2"))
+ iterator_0 = dataset.make_initializable_iterator()
+ iterator_1 = dataset2.make_initializable_iterator()
+ next_element = iterator_0.get_next() + iterator_1.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run([iterator_0.initializer, iterator_1.initializer])
+ for i in range(100):
+ self.assertEqual(i * 2, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset1_record_latency", float(i + 1))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset2_record_latency", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset1_record_latency", 100.0)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset2_record_latency", 100.0)
+
+
+class FeatureStatsDatasetTest(
+ stats_dataset_test_base.StatsDatasetTestBase,
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
+
+ def testFeaturesStats(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ batch_size = 2
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5,
+ drop_final_batch=False).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "record_stats"))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(total_records // batch_size + 1 if total_records %
+ batch_size else total_records // batch_size):
+ sess.run(next_element)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats_features", total_records)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats_feature-values", total_records)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats_features", total_records * 4)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats_feature-values",
+ self._sum_keywords(1) * num_epochs + 3 * total_records)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
index c918d223e8..54ef6fc3e8 100644
--- a/tensorflow/python/data/experimental/ops/stats_ops.py
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -89,15 +89,19 @@ class StatsAggregator(object):
class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
"""A `Dataset` that acts as an identity, and sets given stats_aggregator."""
- def __init__(self, input_dataset, stats_aggregator):
+ def __init__(self, input_dataset, stats_aggregator, tag, prefix):
super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._stats_aggregator = stats_aggregator
+ self._tag = tag
+ self._prefix = prefix
def _as_variant_tensor(self):
return gen_dataset_ops.set_stats_aggregator_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._stats_aggregator._resource, # pylint: disable=protected-access
+ self._tag,
+ self._prefix,
**dataset_ops.flat_structure(self))
@property
@@ -114,11 +118,15 @@ class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
@tf_export("data.experimental.set_stats_aggregator")
-def set_stats_aggregator(stats_aggregator):
+def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""):
"""Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `tf.data.experimental.StatsAggregator` object.
+ stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
+ tag: (Optional) String, all statistics recorded for the input `dataset`
+ will have the given `tag` prepended to their names.
+ counter_prefix: (Optional) String, all statistics recorded as `counters`
+ will have the given `prefix` for the counter. Defaults to "/tensorflow".
Returns:
A `Dataset` transformation function, which can be passed to
@@ -126,7 +134,8 @@ def set_stats_aggregator(stats_aggregator):
"""
def _apply_fn(dataset):
- return _SetStatsAggregatorDataset(dataset, stats_aggregator)
+ return _SetStatsAggregatorDataset(dataset, stats_aggregator, tag,
+ counter_prefix)
return _apply_fn
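
Because `tag` is associated with the dataset rather than the aggregator, two pipelines can share one StatsAggregator without their summaries colliding, as testMultipleDatasetWithTags above exercises. A minimal sketch (tag names illustrative):

# Two pipelines, one aggregator: distinct tags keep summaries separate.
stats_aggregator = stats_ops.StatsAggregator()
ds1 = dataset_ops.Dataset.range(100).apply(
    stats_ops.latency_stats("record_latency")).apply(
        stats_ops.set_stats_aggregator(stats_aggregator, "dataset1"))
ds2 = dataset_ops.Dataset.range(100).apply(
    stats_ops.latency_stats("record_latency")).apply(
        stats_ops.set_stats_aggregator(stats_aggregator, "dataset2"))
# The aggregator's summary now reports "dataset1_record_latency" and
# "dataset2_record_latency" as separate statistics.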