diff options
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r-- | tensorflow/python/training/saver_test.py | 7 |
1 files changed, 2 insertions, 5 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 0afc1ba70f..0a14af04de 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -413,20 +413,17 @@ class SaverTest(test.TestCase): return save_path = os.path.join(self.get_temp_dir(), "gpu") with session.Session("", graph=ops_lib.Graph()) as sess: - with sess.graph.device("/gpu:0"): + with sess.graph.device(test.gpu_device_name()): v0_1 = variables.Variable(123.45) save = saver_module.Saver({"v0": v0_1}) variables.global_variables_initializer().run() save.save(sess, save_path) with session.Session("", graph=ops_lib.Graph()) as sess: - with sess.graph.device("/gpu:0"): + with sess.graph.device(test.gpu_device_name()): v0_2 = variables.Variable(543.21) save = saver_module.Saver({"v0": v0_2}) variables.global_variables_initializer().run() - self.assertAllClose(543.21, v0_2.eval()) - save.restore(sess, save_path) - self.assertAllClose(123.45, v0_2.eval()) def testVariables(self): save_path = os.path.join(self.get_temp_dir(), "variables") |