aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar Saurabh Saxena <srbs@google.com>2017-10-23 11:07:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-23 11:11:17 -0700
commit1038927c096ecc81ca48665871d1be390444b121 (patch)
tree40c7ff20843bc62f248153b5b85e8116e16c3f4c /tensorflow/python/kernel_tests
parent57f3e529d935e6b08a6c0a3a418ad367d9314fde (diff)
Add SerializeIterator op that serializes an IteratorResource into a variant tensor.
Add DeserializeIterator op that builds IteratorResource from a variant tensor. Move BundleReaderWrapper and BundleWriterWrapper from dataset.h to iterator_ops.cc. Add generic key-value store interfaces IteratorStateReader and IteratorStateWriter for reading/writing state of iterators. Get rid of IteratorBundleReader and IteratorBundleWriter. PiperOrigin-RevId: 173140858
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/BUILD5
-rw-r--r--tensorflow/python/kernel_tests/iterator_ops_test.py29
-rw-r--r--tensorflow/python/kernel_tests/range_dataset_op_test.py67
-rw-r--r--tensorflow/python/kernel_tests/reader_dataset_ops_test.py26
4 files changed, 79 insertions, 48 deletions
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 0e36c3498a..b02bae95fd 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -2886,7 +2886,9 @@ tf_py_test(
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
@@ -2907,7 +2909,9 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
@@ -3022,6 +3026,7 @@ tf_py_test(
"//tensorflow/python:function",
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index b5ec9f7db0..2128ef4ae1 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@@ -538,9 +539,23 @@ class IteratorTest(test.TestCase):
def testIncorrectIteratorRestore(self):
- def _iterator_checkpoint_prefix():
+ def _path():
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def _build_range_dataset_graph():
start = 1
stop = 10
@@ -548,22 +563,18 @@ class IteratorTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = _iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
def _build_reader_dataset_graph():
filenames = ["test"] # Does not exist but we don't care in this test.
- path = _iterator_checkpoint_prefix()
iterator = readers.FixedLengthRecordDataset(
filenames, 1, 0, 0).make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
# Saving iterator for RangeDataset graph.
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 8291967155..0c530522b8 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -27,6 +27,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -169,6 +171,21 @@ class RangeDatasetTest(test.TestCase):
def _iterator_checkpoint_prefix(self):
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def testSaveRestore(self):
def _build_graph(start, stop):
@@ -176,10 +193,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -222,14 +237,13 @@ class RangeDatasetTest(test.TestCase):
def testRestoreWithoutBuildingDatasetGraph(self):
- def _build_graph(start, stop, num_epochs, path):
+ def _build_graph(start, stop, num_epochs):
dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -238,10 +252,8 @@ class RangeDatasetTest(test.TestCase):
num_epochs = 5
break_point = 5
break_epoch = 3
- path = self._iterator_checkpoint_prefix()
with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs,
- path)
+ init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
with self.test_session(graph=g) as sess:
sess.run(variables.global_variables_initializer())
sess.run(init_op)
@@ -258,8 +270,7 @@ class RangeDatasetTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types,
output_shapes)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ restore_op = self._restore_op(iterator._iterator_resource)
get_next = iterator.get_next()
with self.test_session(graph=g) as sess:
sess.run(restore_op)
@@ -278,10 +289,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -319,10 +328,8 @@ class RangeDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
# Saving and restoring in different sessions.
@@ -355,10 +362,8 @@ class RangeDatasetTest(test.TestCase):
stop).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@@ -400,10 +405,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
@@ -447,10 +450,8 @@ class RangeDatasetTest(test.TestCase):
start, stop).repeat(num_epochs).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- path = self._iterator_checkpoint_prefix()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next, save_op, restore_op
start = 2
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index 38420328ef..c8e7333b4b 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -31,6 +31,8 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -273,18 +275,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
def _iterator_checkpoint_path(self):
return os.path.join(self.get_temp_dir(), "iterator")
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_path(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
def _build_iterator_graph(self, num_epochs):
filenames = self._createFiles()
- path = self._iterator_checkpoint_path()
dataset = (readers.FixedLengthRecordDataset(
filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
.repeat(num_epochs))
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next_op = iterator.get_next()
- save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path)
- restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource,
- path)
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
return init_op, get_next_op, save_op, restore_op
def _restore_iterator(self):
@@ -292,8 +307,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
output_shapes = tensor_shape.scalar()
iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
get_next = iterator.get_next()
- restore_op = gen_dataset_ops.restore_iterator(
- iterator._iterator_resource, self._iterator_checkpoint_path())
+ restore_op = self._restore_op(iterator._iterator_resource)
return restore_op, get_next
def testSaveRestore(self):