aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-09-21 08:33:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 08:37:28 -0700
commitbf574689f1aa631bbbb19801050152690772108d (patch)
tree74e103849722252fb8979bd5e33f19f22fae09a2 /tensorflow/contrib/data
parentdf31f7a0d6570055eebf0a19449aecbecdb748fa (diff)
[data-stats] Collects prefetch `buffer_size` and `buffer_capacity` as scalar, if stats_aggregator is associated with dataset.
PiperOrigin-RevId: 213989745
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py24
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py8
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 14c5cffdf4..be8ae5e955 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -93,6 +93,8 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
summary_str = sess.run(summary_t)
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
float(i + 1))
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
0, 1)
with self.assertRaises(errors.OutOfRangeError):
@@ -101,6 +103,28 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
100)
+ def testPrefetchBufferScalars(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasScalarValue(summary_str,
+ "Prefetch::buffer_capacity", 0)
+ self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+ 0)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testFilteredElementsStats(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(101).filter(
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 6951564091..b1b4c23510 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -25,6 +25,14 @@ from tensorflow.python.platform import test
class StatsDatasetTestBase(test.TestCase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
+ def _assertSummaryContains(self, summary_str, tag):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasCount(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)