aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/training/saver.py7
-rw-r--r--tensorflow/python/training/saver_test.py18
2 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 1ee975fbe4..11510d9928 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -126,7 +126,12 @@ class BaseSaverBuilder(object):
def f():
with ops.device(v.device):
x = v.read_value()
- with ops.device("/device:CPU:0"):
+ # To allow variables placed on non-CPU devices to be checkpointed,
+ # we copy them to CPU on the same machine first.
+ device_spec = pydev.DeviceSpec().parse_from_string(v.device)
+ device_spec.merge_from(
+ pydev.DeviceSpec().parse_from_string("/device:CPU:0"))
+ with ops.device(device_spec.to_string()):
return array_ops.identity(x)
return f
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ae9c244aaf..ecce8ae6bd 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -174,6 +174,24 @@ class SaverTest(test.TestCase):
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
+  def testResourceColocation(self):
+    """A GPU-placed resource variable is saved via the CPU of its own job."""
+    partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
+    with ops_lib.device("/job:ps/device:GPU:0"):
+      v = variable_scope.get_variable(
+          "v0", shape=[10, 2], partitioner=partitioner, use_resource=True)
+    # Built outside the device scope above: only v.device can supply "/job:ps".
+    saver_module.Saver({"v0": v}).build()
+    save_op = None
+    for op in ops_lib.get_default_graph().get_operations():
+      if op.type == "SaveV2":
+        save_op = op
+        break
+    self.assertIsNotNone(save_op)
+    for save_inp in save_op.inputs[3:]:
+      # inputs[3:] are the tensors being saved; each must be on the job's CPU.
+      self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
+
def testResourceVariableReadOpsAddedDeterministically(self):
graph_defs = []
num_graphs = 10