aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-12-07 12:10:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-07 12:13:51 -0800
commit9620b2df63854538357bf41f4d9761499e8e573d (patch)
tree5ac6a251072a175fce6edff47005741da510cac4
parente35da0b306af78a08d9dc313aed2c9acbbab194c (diff)
Fix the issue with sharded saver on GPU.
`ShardedFilename` and `MergeV2Checkpoints/checkpoint_prefixes` operations were placed on GPU even though there are no GPU kernels for them. PiperOrigin-RevId: 178276605
-rw-r--r--tensorflow/python/BUILD15
-rw-r--r--tensorflow/python/training/monitored_session_test.py14
-rw-r--r--tensorflow/python/training/saver.py4
-rw-r--r--tensorflow/python/training/saver_test.py17
4 files changed, 27 insertions, 23 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index bd8ef6944c..af99754776 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3634,6 +3634,7 @@ cuda_py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/ops:dataset_ops",
],
+ tags = ["multi_gpu"],
)
py_test(
@@ -3787,11 +3788,16 @@ py_test(
],
)
-cuda_py_test(
+py_test(
name = "monitored_session_test",
size = "medium",
srcs = ["training/monitored_session_test.py"],
- additional_deps = [
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_windows",
+ "notsan", # b/67945581
+ ],
+ deps = [
":array_ops",
":client_testlib",
":control_flow_ops",
@@ -3806,11 +3812,6 @@ cuda_py_test(
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/core:protos_all_py",
],
- tags = [
- "multi_gpu",
- "no_windows",
- "notsan", # b/67945581
- ],
)
py_test(
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 349d8537cb..159b2d5c16 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -36,7 +36,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
@@ -1969,19 +1968,6 @@ class MonitoredSessionTest(test.TestCase):
self.assertEqual(2, trace_the_exception['side_effect_counter'])
self.assertNear(0.62, session.run(graph_state), 0.1)
- def test_saver_on_a_gpu(self):
- if not test_util.is_gpu_available():
- return
- with ops.Graph().as_default():
- with self.test_session():
- with ops.device('/gpu:0'):
- variables.Variable(0)
- saver_lib.Saver()
-
- # TODO(b/36964652): Reproduces the issue that needs to be fixed.
- with self.assertRaises(errors_impl.InvalidArgumentError):
- monitored_session.MonitoredSession()
-
class SingularMonitoredSessionTest(test.TestCase):
"""Tests SingularMonitoredSession."""
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index bd47736d4b..ba6301e785 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -349,7 +349,7 @@ class BaseSaverBuilder(object):
last_device = None
for shard, (device, saveables) in enumerate(per_device):
last_device = device
- with ops.device(device):
+ with ops.device(_set_cpu0(device)):
sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
num_shards_tensor)
sharded_prefixes.append(sharded_filename)
@@ -357,7 +357,7 @@ class BaseSaverBuilder(object):
with ops.control_dependencies([x.op for x in sharded_saves]):
# Co-locates the merge step with the last device.
- with ops.device(last_device):
+ with ops.device(_set_cpu0(last_device)):
# V2 format write path consists of a metadata merge step. Once merged,
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
merge_step = gen_io_ops.merge_v2_checkpoints(
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ffe933bb0f..207e4a2842 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -542,6 +542,23 @@ class SaverTest(test.TestCase):
save = saver_module.Saver({"v0": v0_2})
variables.global_variables_initializer().run()
+ def testSharedServerOnGPU(self):
+ if not test.is_gpu_available():
+ return
+ save_path = os.path.join(self.get_temp_dir(), "gpu")
+ with session.Session("", graph=ops_lib.Graph()) as sess:
+ with sess.graph.device(test.gpu_device_name()):
+ v0_1 = variables.Variable(123.45)
+ save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
+ variables.global_variables_initializer().run()
+ save.save(sess, save_path)
+
+ with session.Session("", graph=ops_lib.Graph()) as sess:
+ with sess.graph.device(test.gpu_device_name()):
+ v0_2 = variables.Variable(543.21)
+ save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
+ variables.global_variables_initializer().run()
+
def testVariables(self):
save_path = os.path.join(self.get_temp_dir(), "variables")
with session.Session("", graph=ops_lib.Graph()) as sess: