aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-09-10 10:49:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 10:59:00 -0700
commitb5c0161db4546dd8a71239ab563cd7398c9cff2c (patch)
treecbd1d379d350e61ca4c17eebb90315772cd289df /tensorflow/contrib/data
parenta0bec62c0219e143a8b0d8e3dd3fb5b577db388e (diff)
Automated rollback of commit e258e52d2c4060fc26fda43e4ce068d5ba2ab1ff
PiperOrigin-RevId: 212294062
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py25
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py10
2 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 43067b4245..e25570c5ad 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -75,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+ def testPrefetchBufferUtilization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ float(i + 1))
+ self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+ 0, 1)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ 100)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 9a13acf8f0..2f5a44408f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+ def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertLessEqual(min_value, value.histo.min)
+ self.assertGreaterEqual(max_value, value.histo.max)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasSum(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)