aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/opt
diff options
context:
space:
mode:
authorGravatar weidankong <kongweidan84@gmail.com>2018-08-16 11:42:59 -0700
committerGravatar weidankong <kongweidan84@gmail.com>2018-08-16 11:42:59 -0700
commit7d9a839a26b7b801ffc53eff59688672021d6a43 (patch)
tree94be49e73c99c1330998806adba7b37aaab783b6 /tensorflow/contrib/opt
parent5b2ac79e45dd217fd2954e6b170b569c42162f3f (diff)
fix feedback/copybara failure: change to self.get_temp_dir for saving checkpoint
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
index 0db368cc4e..5bf6a08de1 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import portpicker
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
@@ -196,15 +197,16 @@ class ElasticAverageOptimizerTest(test.TestCase):
sessions[0].run(train_ops[0])
# save, data will be global value
+ outfile = os.path.join(test.get_temp_dir(), "model")
savers[0].save(sessions[0]._sess._sess._sess._sess,
- save_path='./model/model')
+ save_path=outfile)
ops.reset_default_graph() # restore on a new graph
with session.Session() as sess:
v0 = variable_scope.get_variable(initializer=0.0, name="v0")
v1 = variable_scope.get_variable(initializer=1.0, name="v1")
sess.run(variables.local_variables_initializer())
saver_opt = saver.Saver(var_list=[v1, v0])
- saver_opt.restore(sess, './model/model')
+ saver_opt.restore(sess, outfile)
self.assertAllEqual(2.0, sess.run(v0))
self.assertAllEqual(3.0, sess.run(v1))
@@ -249,8 +251,9 @@ class ElasticAverageOptimizerTest(test.TestCase):
# part_0 of global_center copy
part_0_g = sessions[0].run(part_0_g)
+ outfile = os.path.join(test.get_temp_dir(), "model")
savers[0].save(sessions[0]._sess._sess._sess._sess,
- save_path='./model/model')
+ save_path=outfile)
# verify restore of partitioned_variables
ops.reset_default_graph() # restore on a new graph
@@ -265,7 +268,7 @@ class ElasticAverageOptimizerTest(test.TestCase):
shape=[2, 4],
initializer=init_ops.ones_initializer)
s = saver.Saver(var_list=[partition_var])
- s.restore(sess, './model/model')
+ s.restore(sess, outfile)
part_0 = g.get_tensor_by_name('partition_var/part_0:0')
self.assertAllEqual(part_0_g, sess.run(part_0))