aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-08-02 16:55:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 16:59:30 -0700
commitbb3ed5ee461988f1020b9768a42ce27966ec08dc (patch)
tree0b2fdb35eda6c64dfc9a365587eee46b28c1d758 /tensorflow/contrib/data
parent4bc5c6c77daea8d5e60c22732b56495a0cd6c681 (diff)
Experimental Cl which adds `LatencyStatsDataset` op after each `Dataset` op to record latency on each edge of dataset input pipeline.
PiperOrigin-RevId: 207190025
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD15
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py31
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py27
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py44
4 files changed, 92 insertions, 25 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 2de1a79d28..751cde2b10 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -209,8 +209,11 @@ py_test(
size = "small",
srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
+ ":stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -431,8 +434,8 @@ py_test(
tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
+ ":stats_dataset_test_base",
"//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -442,6 +445,16 @@ py_test(
],
)
+py_library(
+ name = "stats_dataset_test_base",
+ srcs = ["stats_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "threadpool_dataset_ops_test",
size = "small",
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
index d8156dc9c7..2427935c73 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
@@ -19,7 +19,9 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -160,5 +162,34 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(get_next)
+class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testLatencyStatsOptimization(self):
+
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.from_tensors(1).apply(
+ optimization.assert_next(
+ ["LatencyStats", "Map", "LatencyStats", "Prefetch",
+ "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
+ optimization.optimize(["latency_all_edges"])).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertEqual(1 * 1, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_TensorDataset/_1", 1)
+ self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
+ 1)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_PrefetchDataset/_6", 1)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index b4945685c1..a41d21f8c1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.core.framework import summary_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -29,28 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class StatsDatasetTestBase(test.TestCase):
-
- def _assertSummaryHasCount(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.num)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasSum(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.sum)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
-
-class StatsDatasetTest(StatsDatasetTestBase):
+class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testBytesProduced(self):
stats_aggregator = stats_ops.StatsAggregator()
@@ -197,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase):
class FeatureStatsDatasetTest(
- StatsDatasetTestBase,
+ stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
def testFeaturesStats(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
new file mode 100644
index 0000000000..9a13acf8f0
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -0,0 +1,44 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing the input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.python.platform import test
+
+
+class StatsDatasetTestBase(test.TestCase):
+ """Base class for testing statistics gathered in `StatsAggregator`."""
+
+ def _assertSummaryHasCount(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.num)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasSum(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.sum)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))