diff options
author | 2018-01-26 16:51:25 -0800 | |
---|---|---|
committer | 2018-01-26 16:55:19 -0800 | |
commit | e95537708f070a98607393a8f60bc61f1611a77b (patch) | |
tree | 871fd63a3b7bc94f638a3e2042bb7d228ce35ea7 /tensorflow | |
parent | a977a77299f292e556ace48c75251a5a11d118ff (diff) |
[tf.data] Support for initializing all the tables of the given graph.
PiperOrigin-RevId: 183466905
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py | 19 |
2 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1cf0202fd8..04a21f2b0f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -126,6 +126,7 @@ py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index 7cde6e05b2..701fc8247e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -27,6 +27,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -235,8 +236,7 @@ class DatasetSerializationTestBase(test.TestCase): ds_fn, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: self._restore(saver, sess) - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) if verify_exhausted: @@ -390,8 +390,7 @@ class DatasetSerializationTestBase(test.TestCase): init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) for _ in range(break_point): sess.run(get_next_op) with self.assertRaises(error): @@ -493,12 +492,10 @@ class DatasetSerializationTestBase(test.TestCase): with self.test_session(graph=g) as sess: if ckpt_saved: if init_before_restore: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) self._restore(saver, sess) else: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) start = break_points[i - 1] if i > 0 else 0 end = break_points[i] if i < len(break_points) else num_outputs num_iters = end - start @@ -621,8 +618,14 @@ class DatasetSerializationTestBase(test.TestCase): saver.save(sess, self._ckpt_path()) def _restore(self, saver, sess): + sess.run(lookup_ops.tables_initializer()) saver.restore(sess, self._latest_ckpt()) + def _initialize(self, init_op, sess): + sess.run(variables.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + sess.run(init_op) + def _import_meta_graph(self): meta_file_path = self._ckpt_path() + ".meta" return saver_lib.import_meta_graph(meta_file_path) |