diff options
-rw-r--r-- | tensorflow/python/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/python/training/monitored_session_test.py | 14 | ||||
-rw-r--r-- | tensorflow/python/training/saver.py | 4 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 17 |
4 files changed, 27 insertions, 23 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index bd8ef6944c..af99754776 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3634,6 +3634,7 @@ cuda_py_test(
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python/data/ops:dataset_ops",
     ],
+    tags = ["multi_gpu"],
 )
 
 py_test(
@@ -3787,11 +3788,16 @@ py_test(
     ],
 )
 
-cuda_py_test(
+py_test(
     name = "monitored_session_test",
     size = "medium",
     srcs = ["training/monitored_session_test.py"],
-    additional_deps = [
+    srcs_version = "PY2AND3",
+    tags = [
+        "no_windows",
+        "notsan",  # b/67945581
+    ],
+    deps = [
         ":array_ops",
         ":client_testlib",
         ":control_flow_ops",
@@ -3806,11 +3812,6 @@ cuda_py_test(
         "//tensorflow/contrib/testing:testing_py",
         "//tensorflow/core:protos_all_py",
     ],
-    tags = [
-        "multi_gpu",
-        "no_windows",
-        "notsan",  # b/67945581
-    ],
 )
 
 py_test(
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 349d8537cb..159b2d5c16 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -36,7 +36,6 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -1969,19 +1968,6 @@ class MonitoredSessionTest(test.TestCase):
       self.assertEqual(2, trace_the_exception['side_effect_counter'])
       self.assertNear(0.62, session.run(graph_state), 0.1)
 
-  def test_saver_on_a_gpu(self):
-    if not test_util.is_gpu_available():
-      return
-    with ops.Graph().as_default():
-      with self.test_session():
-        with ops.device('/gpu:0'):
-          variables.Variable(0)
-          saver_lib.Saver()
-
-        # TODO(b/36964652): Reproduces the issue that needs to be fixed.
-        with self.assertRaises(errors_impl.InvalidArgumentError):
-          monitored_session.MonitoredSession()
-
 
 class SingularMonitoredSessionTest(test.TestCase):
   """Tests SingularMonitoredSession."""
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index bd47736d4b..ba6301e785 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -349,7 +349,7 @@ class BaseSaverBuilder(object):
     last_device = None
     for shard, (device, saveables) in enumerate(per_device):
       last_device = device
-      with ops.device(device):
+      with ops.device(_set_cpu0(device)):
         sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                  num_shards_tensor)
         sharded_prefixes.append(sharded_filename)
@@ -357,7 +357,7 @@ class BaseSaverBuilder(object):
 
     with ops.control_dependencies([x.op for x in sharded_saves]):
       # Co-locates the merge step with the last device.
-      with ops.device(last_device):
+      with ops.device(_set_cpu0(last_device)):
         # V2 format write path consists of a metadata merge step. Once merged,
         # attempts to delete the temporary directory, "<user-fed prefix>_temp".
         merge_step = gen_io_ops.merge_v2_checkpoints(
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ffe933bb0f..207e4a2842 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -542,6 +542,23 @@ class SaverTest(test.TestCase):
       save = saver_module.Saver({"v0": v0_2})
       variables.global_variables_initializer().run()
 
+  def testSharedServerOnGPU(self):
+    if not test.is_gpu_available():
+      return
+    save_path = os.path.join(self.get_temp_dir(), "gpu")
+    with session.Session("", graph=ops_lib.Graph()) as sess:
+      with sess.graph.device(test.gpu_device_name()):
+        v0_1 = variables.Variable(123.45)
+      save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
+      variables.global_variables_initializer().run()
+      save.save(sess, save_path)
+
+    with session.Session("", graph=ops_lib.Graph()) as sess:
+      with sess.graph.device(test.gpu_device_name()):
+        v0_2 = variables.Variable(543.21)
+      save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
+      variables.global_variables_initializer().run()
+
   def testVariables(self):
     save_path = os.path.join(self.get_temp_dir(), "variables")
     with session.Session("", graph=ops_lib.Graph()) as sess: