diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-09-28 10:56:02 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 11:04:36 -0700 |
commit | 0e926947421cc47546efb7f7e2dd8505fbe0ac45 (patch) | |
tree | b76fbbe6fbcde653ec970382909e64d98514b8a8 /tensorflow/contrib/data | |
parent | 301e3043e67493ce3777d2b36b43d0210f7b920c (diff) |
[tf.data] Throws appropriate error while trying to checkpoint input pipeline with associated stats_aggregator.
PiperOrigin-RevId: 214961678
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py index 14cd3e9c4a..a10f85263a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -90,6 +91,16 @@ class StatsDatasetSerializationTest( lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), None, num_outputs) + def _build_dataset_stats_aggregator(self): + stats_aggregator = stats_ops.StatsAggregator() + return dataset_ops.Dataset.range(10).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + + def test_set_stats_aggregator_not_support_checkpointing(self): + with self.assertRaisesRegexp(errors.UnimplementedError, + "does not support checkpointing"): + self.run_core_tests(self._build_dataset_stats_aggregator, None, 10) + if __name__ == "__main__": test.main() |