diff options
author | weidankong <kongweidan84@gmail.com> | 2018-08-16 11:42:59 -0700 |
---|---|---|
committer | weidankong <kongweidan84@gmail.com> | 2018-08-16 11:42:59 -0700 |
commit | 7d9a839a26b7b801ffc53eff59688672021d6a43 (patch) | |
tree | 94be49e73c99c1330998806adba7b37aaab783b6 /tensorflow/contrib/opt | |
parent | 5b2ac79e45dd217fd2954e6b170b569c42162f3f (diff) |
fix feedback/copybara failure: change to self.get_temp_dir for saving checkpoint
Diffstat (limited to 'tensorflow/contrib/opt')
-rw-r--r-- | tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py index 0db368cc4e..5bf6a08de1 100644 --- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import portpicker from tensorflow.python.client import session from tensorflow.python.framework import constant_op @@ -196,15 +197,16 @@ class ElasticAverageOptimizerTest(test.TestCase): sessions[0].run(train_ops[0]) # save, data will be global value + outfile = os.path.join(test.get_temp_dir(), "model") savers[0].save(sessions[0]._sess._sess._sess._sess, - save_path='./model/model') + save_path=outfile) ops.reset_default_graph() # restore on a new graph with session.Session() as sess: v0 = variable_scope.get_variable(initializer=0.0, name="v0") v1 = variable_scope.get_variable(initializer=1.0, name="v1") sess.run(variables.local_variables_initializer()) saver_opt = saver.Saver(var_list=[v1, v0]) - saver_opt.restore(sess, './model/model') + saver_opt.restore(sess, outfile) self.assertAllEqual(2.0, sess.run(v0)) self.assertAllEqual(3.0, sess.run(v1)) @@ -249,8 +251,9 @@ class ElasticAverageOptimizerTest(test.TestCase): # part_0 of global_center copy part_0_g = sessions[0].run(part_0_g) + outfile = os.path.join(test.get_temp_dir(), "model") savers[0].save(sessions[0]._sess._sess._sess._sess, - save_path='./model/model') + save_path=outfile) # verify restore of partitioned_variables ops.reset_default_graph() # restore on a new graph @@ -265,7 +268,7 @@ class ElasticAverageOptimizerTest(test.TestCase): shape=[2, 4], initializer=init_ops.ones_initializer) s = saver.Saver(var_list=[partition_var]) - s.restore(sess, './model/model') + s.restore(sess, outfile) part_0 = g.get_tensor_by_name('partition_var/part_0:0') self.assertAllEqual(part_0_g, sess.run(part_0)) |