diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-09-21 08:33:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 08:37:28 -0700 |
commit | bf574689f1aa631bbbb19801050152690772108d (patch) | |
tree | 74e103849722252fb8979bd5e33f19f22fae09a2 /tensorflow/contrib/data | |
parent | df31f7a0d6570055eebf0a19449aecbecdb748fa (diff) |
[data-stats] Collects prefetch `buffer_size` and `buffer_capacity` as scalar, if stats_aggregator is associated with dataset.
PiperOrigin-RevId: 213989745
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py | 24 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py | 8 |
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 14c5cffdf4..be8ae5e955 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -93,6 +93,8 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): summary_str = sess.run(summary_t) self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", float(i + 1)) + self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity") + self._assertSummaryContains(summary_str, "Prefetch::buffer_size") self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization", 0, 1) with self.assertRaises(errors.OutOfRangeError): @@ -101,6 +103,28 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", 100) + def testPrefetchBufferScalars(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(10).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + 0).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.cached_session() as sess: + sess.run(iterator.initializer) + for i in range(10): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasScalarValue(summary_str, + "Prefetch::buffer_capacity", 0) + self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size", + 0) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + def testFilteredElementsStats(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(101).filter( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index 6951564091..b1b4c23510 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -25,6 +25,14 @@ from tensorflow.python.platform import test class StatsDatasetTestBase(test.TestCase): """Base class for testing statistics gathered in `StatsAggregator`.""" + def _assertSummaryContains(self, summary_str, tag): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasCount(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() summary_proto.ParseFromString(summary_str) |