From b5c0161db4546dd8a71239ab563cd7398c9cff2c Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Mon, 10 Sep 2018 10:49:18 -0700 Subject: Automated rollback of commit e258e52d2c4060fc26fda43e4ce068d5ba2ab1ff PiperOrigin-RevId: 212294062 --- .../python/kernel_tests/stats_dataset_ops_test.py | 25 ++++++++++++++++++++++ .../python/kernel_tests/stats_dataset_test_base.py | 10 +++++++++ 2 files changed, 35 insertions(+) (limited to 'tensorflow/contrib/data') diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 43067b4245..e25570c5ad 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -75,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): sess.run(next_element) self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + def testPrefetchBufferUtilization(self): + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch( + -1).apply(stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + float(i + 1)) + self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization", + 0, 1) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization", + 100) + def testReinitialize(self): stats_aggregator = stats_ops.StatsAggregator() dataset = dataset_ops.Dataset.range(100).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index 9a13acf8f0..2f5a44408f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase): return self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertLessEqual(min_value, value.histo.min) + self.assertGreaterEqual(max_value, value.histo.max) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + def _assertSummaryHasSum(self, summary_str, tag, expected_value): summary_proto = summary_pb2.Summary() summary_proto.ParseFromString(summary_str) -- cgit v1.2.3