author    Shivani Agrawal <shivaniagrawal@google.com>  2018-04-18 16:01:55 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-18 16:04:37 -0700
commit    e9d47fbff0d644a75c6f3dcdcb852685ef515b64 (patch)
tree      87b63378d33241f3fff820b0f70cdd89d2957406 /tensorflow/contrib/data
parent    695da2d928b5927c0a4f73e352a597a19886f2cb (diff)
Adds a dataset transformation function `set_stats_aggregator(..)`, which sets the given `stats_aggregator` for aggregating the input dataset's stats.
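Example usage, sketched from the updated tests in this change (the old `StatsAggregator.subscribe(iterator)` op is removed in favor of applying the transformation to the dataset itself):

    from tensorflow.contrib.data.python.ops import stats_ops
    from tensorflow.python.data.ops import dataset_ops

    # Create the aggregator up front and attach it to the dataset via apply().
    stats_aggregator = stats_ops.StatsAggregator()
    dataset = dataset_ops.Dataset.range(100).apply(
        stats_ops.latency_stats("record_latency")).apply(
            stats_ops.set_stats_aggregator(stats_aggregator))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    # Summaries for the collected stats; no separate subscribe op to run.
    summary_t = stats_aggregator.get_summary()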
PiperOrigin-RevId: 193432590
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py | 67
-rw-r--r--  tensorflow/contrib/data/python/ops/stats_ops.py | 61
2 files changed, 71 insertions, 57 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index 07bdf92044..7acbc676ce 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -50,17 +50,17 @@ class StatsDatasetTest(test.TestCase):
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
def testBytesProduced(self):
+ stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
- stats_ops.bytes_produced_stats("bytes_produced"))
+ stats_ops.bytes_produced_stats("bytes_produced")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
- stats_aggregator = stats_ops.StatsAggregator()
- stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
- sess.run([iterator.initializer, stats_aggregator_subscriber])
+ sess.run(iterator.initializer)
expected_sum = 0.0
for i in range(100):
self.assertAllEqual(
@@ -76,16 +76,16 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
def testLatencyStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency"))
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
- stats_aggregator = stats_ops.StatsAggregator()
- stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
- sess.run([iterator.initializer, stats_aggregator_subscriber])
+ sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
@@ -95,16 +95,15 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
def testReinitialize(self):
+ stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency"))
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
- stats_aggregator = stats_ops.StatsAggregator()
- stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
- sess.run(stats_aggregator_subscriber)
for j in range(5):
sess.run(iterator.initializer)
for i in range(100):
@@ -130,17 +129,17 @@ class StatsDatasetTest(test.TestCase):
sess.run(next_element)
def testMultipleTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
- stats_ops.latency_stats("record_latency_2"))
+ stats_ops.latency_stats("record_latency_2")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
- stats_aggregator = stats_ops.StatsAggregator()
- stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
- sess.run([iterator.initializer, stats_aggregator_subscriber])
+ sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
@@ -154,17 +153,17 @@ class StatsDatasetTest(test.TestCase):
sess.run(summary_t), "record_latency_2", 100.0)
def testRepeatedTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
stats_ops.latency_stats("record_latency")).apply(
- stats_ops.latency_stats("record_latency"))
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_initializable_iterator()
- stats_aggregator = stats_ops.StatsAggregator()
- stats_aggregator_subscriber = stats_aggregator.subscribe(iterator)
next_element = iterator.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
- sess.run([iterator.initializer, stats_aggregator_subscriber])
+ sess.run(iterator.initializer)
for i in range(100):
self.assertEqual(i, sess.run(next_element))
self._assertSummaryHasCount(
@@ -174,19 +173,17 @@ class StatsDatasetTest(test.TestCase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
def testMultipleIteratorsSameAggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency"))
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
iterator_0 = dataset.make_initializable_iterator()
iterator_1 = dataset.make_initializable_iterator()
- stats_aggregator = stats_ops.StatsAggregator()
- stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0),
- stats_aggregator.subscribe(iterator_1)]
next_element = iterator_0.get_next() + iterator_1.get_next()
summary_t = stats_aggregator.get_summary()
with self.test_session() as sess:
- sess.run([iterator_0.initializer, iterator_1.initializer,
- stats_aggregator_subscribers])
+ sess.run([iterator_0.initializer, iterator_1.initializer])
for i in range(100):
self.assertEqual(i * 2, sess.run(next_element))
self._assertSummaryHasCount(
@@ -195,20 +192,6 @@ class StatsDatasetTest(test.TestCase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
- def testMultipleStatsAggregatorsSameIteratorFail(self):
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency"))
- iterator = dataset.make_initializable_iterator()
- stats_aggregator_0 = stats_ops.StatsAggregator()
- stats_aggregator_1 = stats_ops.StatsAggregator()
-
- with self.test_session() as sess:
- sess.run(stats_aggregator_0.subscribe(iterator))
- # TODO(mrry): Consider making this allowable (and also allowing
- # aggregators to unsubscribe).
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(stats_aggregator_1.subscribe(iterator))
-
class StatsDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
@@ -253,5 +236,9 @@ class StatsDatasetSerializationTest(
None, num_outputs)
+# TODO(shivaniagrawal): Cannot checkpoint an input pipeline with the
+# transformation `stats_ops.set_stats_aggregator`, since we don't support
+# serializing StatsAggregator yet.
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index b5cf0fcfe9..d391720396 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
@@ -85,25 +84,53 @@ class StatsAggregator(object):
"""
return gen_dataset_ops.stats_aggregator_summary(self._resource)
- def subscribe(self, iterator):
- """Returns a @{tf.Operation} to associate this aggregator with `iterator`.
- Note: Each @{tf.data.Iterator} can be associated with at most one
- `StatsAggregator`. After running the operation that this function
- returns, all statistics recorded in the iteration of `iterator`
- will be stored in `stats_aggregator`.
+class _SetStatsAggregatorDataset(dataset_ops.Dataset):
+ """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
- Args:
- iterator: A @{tf.data.Iterator} object.
+ def __init__(self, input_dataset, stats_aggregator):
+ super(_SetStatsAggregatorDataset, self).__init__()
+ self._input_dataset = input_dataset
+ self._stats_aggregator = stats_aggregator
- Returns:
- A @{tf.Operation} that, when run, associates this aggregator with
- `iterator`.
- """
- if not isinstance(iterator, iterator_ops.Iterator):
- raise TypeError("`iterator` must be a `tf.data.Iterator` object.")
- return gen_dataset_ops.iterator_set_stats_aggregator(
- iterator._iterator_resource, self._resource) # pylint: disable=protected-access
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.set_stats_aggregator_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._stats_aggregator._resource, # pylint: disable=protected-access
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+# TODO(shivaniagrawal): Expose these methods in `tf.contrib.data`.
+def set_stats_aggregator(stats_aggregator):
+ """Set the given stats_aggregator for aggregating the input dataset stats.
+
+ Args:
+ stats_aggregator: A `StatsAggregator` object.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _SetStatsAggregatorDataset(dataset, stats_aggregator)
+
+ return _apply_fn
def bytes_produced_stats(tag):