aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Shivani Agrawal <shivaniagrawal@google.com>2018-01-26 16:51:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 16:55:19 -0800
commite95537708f070a98607393a8f60bc61f1611a77b (patch)
tree871fd63a3b7bc94f638a3e2042bb7d228ce35ea7 /tensorflow
parenta977a77299f292e556ace48c75251a5a11d118ff (diff)
[tf.data] Support for initializing all the tables of the given graph.
PiperOrigin-RevId: 183466905
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py19
2 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 1cf0202fd8..04a21f2b0f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -126,6 +126,7 @@ py_library(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:lookup_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
index 7cde6e05b2..701fc8247e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py
@@ -27,6 +27,7 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@@ -235,8 +236,7 @@ class DatasetSerializationTestBase(test.TestCase):
ds_fn, sparse_tensors=sparse_tensors)
with self.test_session(graph=g) as sess:
self._restore(saver, sess)
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
+ self._initialize(init_op, sess)
for _ in range(num_outputs):
actual.append(sess.run(get_next_op))
if verify_exhausted:
@@ -390,8 +390,7 @@ class DatasetSerializationTestBase(test.TestCase):
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
+ self._initialize(init_op, sess)
for _ in range(break_point):
sess.run(get_next_op)
with self.assertRaises(error):
@@ -493,12 +492,10 @@ class DatasetSerializationTestBase(test.TestCase):
with self.test_session(graph=g) as sess:
if ckpt_saved:
if init_before_restore:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
+ self._initialize(init_op, sess)
self._restore(saver, sess)
else:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
+ self._initialize(init_op, sess)
start = break_points[i - 1] if i > 0 else 0
end = break_points[i] if i < len(break_points) else num_outputs
num_iters = end - start
@@ -621,8 +618,14 @@ class DatasetSerializationTestBase(test.TestCase):
saver.save(sess, self._ckpt_path())
def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
saver.restore(sess, self._latest_ckpt())
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
def _import_meta_graph(self):
meta_file_path = self._ckpt_path() + ".meta"
return saver_lib.import_meta_graph(meta_file_path)