aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-09-28 10:56:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 11:04:36 -0700
commit0e926947421cc47546efb7f7e2dd8505fbe0ac45 (patch)
treeb76fbbe6fbcde653ec970382909e64d98514b8a8 /tensorflow/contrib/data
parent301e3043e67493ce3777d2b36b43d0210f7b920c (diff)
[tf.data] Throws appropriate error while trying to checkpoint input pipeline with associated stats_aggregator.
PiperOrigin-RevId: 214961678
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
index 14cd3e9c4a..a10f85263a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -90,6 +91,16 @@ class StatsDatasetSerializationTest(
lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
None, num_outputs)
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
if __name__ == "__main__":
test.main()