diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-10-03 12:44:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 12:51:54 -0700 |
commit | 808b1dcb318b1feb5a8c9fed5558f95cd05728e4 (patch) | |
tree | a2286241b6a0c8cba24b1da629fa6e7db475d7dc /tensorflow/python/data | |
parent | 19833284cc8fa555115aacde350ad66652b250dc (diff) |
[data-stats] Sets user-given `tag` and `counter_prefix` with `set_stats_aggregator`. `tag` would get prepended to all the statistics recorded as summaries, and `counter_prefix` would set the prefix for the statistics recorded as counters.
Note: `counter_prefix` defaults to `/tensorflow`, and `tag` and `prefix` get associated with the dataset (not the stats_aggregator).
PiperOrigin-RevId: 215609159
Diffstat (limited to 'tensorflow/python/data')
-rw-r--r-- | tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py | 69 | ||||
-rw-r--r-- | tensorflow/python/data/experimental/ops/stats_ops.py | 17 |
2 files changed, 82 insertions, 4 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py index 6761fbd16b..19f5a62d45 100644 --- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base from tensorflow.python.data.experimental.ops import stats_ops from tensorflow.python.data.ops import dataset_ops @@ -248,6 +249,74 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + def testMultipleDatasetWithTags(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator, "dataset1")) + dataset2 = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.set_stats_aggregator(stats_aggregator, "dataset2")) + iterator_0 = dataset.make_initializable_iterator() + iterator_1 = dataset2.make_initializable_iterator() + next_element = iterator_0.get_next() + iterator_1.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator_0.initializer, iterator_1.initializer]) + for i in range(100): + self.assertEqual(i * 2, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset1_record_latency", float(i + 1)) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset2_record_latency", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + 
self._assertSummaryHasCount( + sess.run(summary_t), "dataset1_record_latency", 100.0) + self._assertSummaryHasCount( + sess.run(summary_t), "dataset2_record_latency", 100.0) + + +class FeatureStatsDatasetTest( + stats_dataset_test_base.StatsDatasetTestBase, + reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): + + def testFeaturesStats(self): + num_epochs = 5 + total_records = num_epochs * self._num_records + batch_size = 2 + stats_aggregator = stats_ops.StatsAggregator() + dataset = self.make_batch_feature( + filenames=self.test_filenames[0], + num_epochs=num_epochs, + batch_size=batch_size, + shuffle=True, + shuffle_seed=5, + drop_final_batch=False).apply( + stats_ops.set_stats_aggregator(stats_aggregator, "record_stats")) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for _ in range(total_records // batch_size + 1 if total_records % + batch_size else total_records // batch_size): + sess.run(next_element) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats_features", total_records) + self._assertSummaryHasCount( + sess.run(summary_t), "record_stats_feature-values", total_records) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats_features", total_records * 4) + self._assertSummaryHasSum( + sess.run(summary_t), "record_stats_feature-values", + self._sum_keywords(1) * num_epochs + 3 * total_records) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py index c918d223e8..54ef6fc3e8 100644 --- a/tensorflow/python/data/experimental/ops/stats_ops.py +++ b/tensorflow/python/data/experimental/ops/stats_ops.py @@ -89,15 +89,19 @@ class StatsAggregator(object): class 
_SetStatsAggregatorDataset(dataset_ops.UnaryDataset): """A `Dataset` that acts as an identity, and sets given stats_aggregator.""" - def __init__(self, input_dataset, stats_aggregator): + def __init__(self, input_dataset, stats_aggregator, tag, prefix): super(_SetStatsAggregatorDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._stats_aggregator = stats_aggregator + self._tag = tag + self._prefix = prefix def _as_variant_tensor(self): return gen_dataset_ops.set_stats_aggregator_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._stats_aggregator._resource, # pylint: disable=protected-access + self._tag, + self._prefix, **dataset_ops.flat_structure(self)) @property @@ -114,11 +118,15 @@ class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset): @tf_export("data.experimental.set_stats_aggregator") -def set_stats_aggregator(stats_aggregator): +def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""): """Set the given `stats_aggregator` for aggregating the input dataset stats. Args: - stats_aggregator: A `tf.data.experimental.StatsAggregator` object. + stats_aggregator: A `tf.contrib.data.StatsAggregator` object. + tag: (Optional) String, all statistics recorded for the input `dataset` + will have given `tag` prepend with the name. + counter_prefix: (Optional) String, all statistics recorded as `counters` + will have the given `prefix` for the counter. Defaults to "/tesorflow". Returns: A `Dataset` transformation function, which can be passed to @@ -126,7 +134,8 @@ def set_stats_aggregator(stats_aggregator): """ def _apply_fn(dataset): - return _SetStatsAggregatorDataset(dataset, stats_aggregator) + return _SetStatsAggregatorDataset(dataset, stats_aggregator, tag, + counter_prefix) return _apply_fn |