Diffstat (limited to 'tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py')
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py  692
1 file changed, 692 insertions, 0 deletions
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
new file mode 100644
index 0000000000..7f435b8239
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
@@ -0,0 +1,692 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing serializable datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
+
+
+def remove_variants(get_next_op):
+  """Remove variants from a nest structure, so sess.run will execute."""
+  # TODO(b/72408568): Remove this once session.run can get variant tensors.
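+  # For example (illustrative only, not part of the original change): a
+  # structure such as {"dense": <float32 Tensor>, "handle": <variant Tensor>}
+  # becomes {"dense": <float32 Tensor>, "handle": ()}; only tensors of dtype
+  # variant are replaced.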
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
+
+class DatasetSerializationTestBase(test.TestCase):
+ """Base class for testing serializable datasets."""
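+
+  # Example usage (illustrative sketch, not part of the original change). A
+  # concrete test case would subclass this base and call `run_core_tests`
+  # with two 0-argument dataset constructors; `make_ds` below is a
+  # hypothetical helper that returns a Dataset producing the given number of
+  # elements.
+  #
+  #   class MyDatasetSerializationTest(DatasetSerializationTestBase):
+  #
+  #     def testCore(self):
+  #       self.run_core_tests(lambda: make_ds(10), lambda: make_ds(20), 10)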
+
+ def tearDown(self):
+ self._delete_ckpt()
+
+ # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
+ # (deprecated) saveable `SparseTensorSliceDataset`, once the API
+  # `from_sparse_tensor_slices()` and related tests are deleted.
+ def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+ """Runs the core tests.
+
+ Args:
+ ds_fn1: 0-argument function that returns a Dataset.
+ ds_fn2: 0-argument function that returns a Dataset different from
+ ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+ num_outputs: Total number of outputs expected from this Dataset.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_unused_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_fully_used_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_exhausted_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_init_before_restore(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_multiple_breaks(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_reset_restored_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_restore_in_empty_graph(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ if ds_fn2:
+ self.verify_restore_in_modified_graph(
+ ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_unused_iterator(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that saving and restoring an unused iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [0],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_fully_used_iterator(self, ds_fn, num_outputs,
+ sparse_tensors=False):
+ """Verifies that saving and restoring a fully used iterator works.
+
+ Note that this only checks saving and restoring an iterator from which
+ `num_outputs` items have been produced but does not check for an
+    exhausted iterator, i.e., one that has already raised an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
+ """Verifies that saving and restoring an exhausted iterator works.
+
+    An exhausted iterator is one that has already raised an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ actual = self.gen_outputs(
+ ds_fn, [],
+ 0,
+ ckpt_saved=True,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ self.assertEqual(len(actual), 0)
+
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that restoring into an already initialized iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ init_before_restore=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to save/restore at multiple break points.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+      num_breaks: The number of break points. These are spread uniformly over
+        [0, num_outputs], with both endpoints included.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to re-initialize a restored iterator.
+
+ This is useful when restoring a training checkpoint during validation.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Collect ground truth containing all outputs.
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Skip some items and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Restore from checkpoint and then run init_op.
+ with ops.Graph().as_default() as g:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ self._initialize(init_op, sess)
+ for _ in range(num_outputs):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ self.match(expected, actual)
+
+ def verify_restore_in_modified_graph(self,
+ ds_fn1,
+ ds_fn2,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in a modified graph.
+
+ Builds an input pipeline using ds_fn1, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new graph using ds_fn2, restores
+ the checkpoint from ds_fn1 and verifies that the restore is successful.
+
+ Args:
+ ds_fn1: See `run_core_tests`.
+ ds_fn2: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+    # Skip the first `break_point` items, then collect the remaining items
+    # produced from ds_fn1 in `expected`.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn1, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn1 and save checkpoint.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build graph for ds_fn2 but load checkpoint for ds_fn1.
+ with ops.Graph().as_default() as g:
+ _, get_next_op, saver = self._build_graph(
+ ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_restore_in_empty_graph(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in an empty graph.
+
+ Builds an input pipeline using ds_fn, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new empty graph, restores
+ the checkpoint from ds_fn and verifies that the restore is successful.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+    # Skip the first `break_point` items, then collect the remaining items
+    # produced from ds_fn in `expected`.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build an empty graph but load checkpoint for ds_fn.
+ with ops.Graph().as_default() as g:
+ get_next_op, saver = self._build_empty_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_error_on_save(self,
+ ds_fn,
+ num_outputs,
+ error,
+ break_point=None,
+ sparse_tensors=False):
+ """Attempts to save a non-saveable iterator.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+      error: The error expected to be raised when saving the iterator.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+
+ break_point = num_outputs // 2 if not break_point else break_point
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._initialize(init_op, sess)
+ for _ in range(break_point):
+ sess.run(get_next_op)
+ with self.assertRaises(error):
+ self._save(sess, saver)
+
+ def verify_run_with_breaks(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that ds_fn() produces the same outputs with and without breaks.
+
+ 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ *without* stopping at break points.
+ 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ with stopping at break points.
+
+ Deep matches outputs from 1 and 2.
+
+ Args:
+ ds_fn: See `gen_outputs`.
+ break_points: See `gen_outputs`.
+ num_outputs: See `gen_outputs`.
+ init_before_restore: See `gen_outputs`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ self.match(expected, actual)
+
+ def gen_outputs(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ ckpt_saved=False,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
+ """Generates elements from input dataset while stopping at break points.
+
+ Produces `num_outputs` outputs and saves the state of the iterator in the
+ Saver checkpoint.
+
+ Args:
+ ds_fn: 0-argument function that returns the dataset.
+ break_points: A list of integers. For each `break_point` in
+ `break_points`, we produce outputs till `break_point` number of items
+ have been produced and then checkpoint the state. The current graph
+ and session are destroyed and a new graph and session are used to
+ produce outputs till next checkpoint or till `num_outputs` elements
+ have been produced. `break_point` must be <= `num_outputs`.
+ num_outputs: The total number of outputs to produce from the iterator.
+ ckpt_saved: Whether a checkpoint already exists. If False, we build the
+ graph from ds_fn.
+ init_before_restore: Whether init should be called before saver.restore.
+ This is just so that we can verify that restoring an already initialized
+ iterator works.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+ verify_exhausted: Whether to verify that the iterator has been exhausted
+ after producing `num_outputs` elements.
+ save_checkpoint_at_end: Whether to save a checkpoint after producing all
+        outputs. If False, checkpoints are saved at each break point but not
+        at the end. Note that checkpoints overwrite each other, so there is
+        always only a single checkpoint available. Defaults to True.
+
+ Returns:
+ A list of `num_outputs` items.
+ """
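+    # Illustrative walk-through (not part of the original change): starting
+    # without a saved checkpoint, with break_points=[3, 7] and num_outputs=10
+    # the loop below runs three times:
+    #   run 1: initialize, produce items 0..2, save a checkpoint;
+    #   run 2: new graph and session, restore, produce items 3..6, save;
+    #   run 3: new graph and session, restore, produce items 7..9 and, if
+    #     `verify_exhausted` is True, check that OutOfRangeError is raised.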
+ outputs = []
+
+ def get_ops():
+ if ckpt_saved:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ else:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ return init_op, get_next_op, saver
+
+ for i in range(len(break_points) + 1):
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ if ckpt_saved:
+ if init_before_restore:
+ self._initialize(init_op, sess)
+ self._restore(saver, sess)
+ else:
+ self._initialize(init_op, sess)
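+          # Produce the items in [start, end): from the previous break point
+          # (or 0 on the first pass) up to the next break point (or
+          # `num_outputs` on the last pass).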
+ start = break_points[i - 1] if i > 0 else 0
+ end = break_points[i] if i < len(break_points) else num_outputs
+ num_iters = end - start
+ for _ in range(num_iters):
+ outputs.append(sess.run(get_next_op))
+ if i == len(break_points) and verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
+
+ return outputs
+
+ def match(self, expected, actual):
+ """Matches nested structures.
+
+ Recursively matches shape and values of `expected` and `actual`.
+ Handles scalars, numpy arrays and other python sequence containers
+ e.g. list, dict.
+
+ Args:
+ expected: Nested structure 1.
+ actual: Nested structure 2.
+
+ Raises:
+ AssertionError if matching fails.
+ """
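+    # For example (illustrative): match([1, np.array([2, 3])], [1, [2, 3]])
+    # passes because numpy arrays are converted to lists before comparison,
+    # whereas match({"a": 1}, {"a": 2}) raises an AssertionError.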
+ if isinstance(expected, np.ndarray):
+ expected = expected.tolist()
+ if isinstance(actual, np.ndarray):
+ actual = actual.tolist()
+ self.assertEqual(type(expected), type(actual))
+
+ if nest.is_sequence(expected):
+ self.assertEqual(len(expected), len(actual))
+ if isinstance(expected, dict):
+ for key1, key2 in zip(sorted(expected), sorted(actual)):
+ self.assertEqual(key1, key2)
+ self.match(expected[key1], actual[key2])
+ else:
+ for item1, item2 in zip(expected, actual):
+ self.match(item1, item2)
+ else:
+ self.assertEqual(expected, actual)
+
+ def does_not_match(self, expected, actual):
+ with self.assertRaises(AssertionError):
+ self.match(expected, actual)
+
+ def gen_break_points(self, num_outputs, num_samples=10):
+    """Generates `num_samples` break points in [0, num_outputs]."""
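+    # E.g. num_outputs=100 with the default num_samples=10 yields roughly
+    # evenly spaced integer break points [0, 11, 22, ..., 88, 100].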
+ return np.linspace(0, num_outputs, num_samples, dtype=int)
+
+ def _build_graph(self, ds_fn, sparse_tensors=False):
+ iterator = ds_fn().make_initializable_iterator()
+
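+    # Register the iterator's state as a saveable object so that the Saver
+    # created below includes it when saving and restoring checkpoints.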
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ init_op = iterator.initializer
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
+ saver = saver_lib.Saver(allow_empty=True)
+ return init_op, get_next, saver
+
+ def _build_empty_graph(self, ds_fn, sparse_tensors=False):
+ iterator = iterator_ops.Iterator.from_structure(
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ saver = saver_lib.Saver(allow_empty=True)
+ return get_next, saver
+
+ def _add_iterator_ops_to_collection(self,
+ init_op,
+ get_next,
+ ds_fn,
+ sparse_tensors=False):
+ ops.add_to_collection("iterator_ops", init_op)
+    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
+    # do not support tuples, we flatten the tensors and restore the shape in
+    # `_get_iterator_ops_from_collection`.
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ ops.add_to_collection("iterator_ops", get_next.indices)
+ ops.add_to_collection("iterator_ops", get_next.values)
+ ops.add_to_collection("iterator_ops", get_next.dense_shape)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
+
+ def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
+ all_ops = ops.get_collection("iterator_ops")
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ init_op, indices, values, dense_shape = all_ops
+ return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
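+    # all_ops[0] is the init_op; the remaining entries are the flattened
+    # `get_next` tensors, with three consecutive entries (indices, values,
+    # dense_shape) per SparseTensor.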
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
+
+ def _get_output_types(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_types
+
+ def _get_output_shapes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_shapes
+
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _latest_ckpt(self):
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
+
+ def _save(self, sess, saver):
+ saver.save(sess, self._ckpt_path())
+
+ def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
+ saver.restore(sess, self._latest_ckpt())
+
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
+ def _import_meta_graph(self):
+ meta_file_path = self._ckpt_path() + ".meta"
+ return saver_lib.import_meta_graph(meta_file_path)
+
+ def _delete_ckpt(self):
+ # Remove all checkpoint files.
+ prefix = self._ckpt_path()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+    # Use an explicit loop: in Python 3, `map` is lazy and would not actually
+    # delete the files.
+    for f in files:
+      gfile.Remove(f)