diff options
-rw-r--r-- | tensorflow/python/training/saver.py | 7 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 18 |
2 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 1ee975fbe4..11510d9928 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -126,7 +126,12 @@ class BaseSaverBuilder(object):
           def f():
             with ops.device(v.device):
               x = v.read_value()
-            with ops.device("/device:CPU:0"):
+            # To allow variables placed on non-CPU devices to be checkpointed,
+            # we copy them to CPU on the same machine first.
+            device_spec = pydev.DeviceSpec().parse_from_string(v.device)
+            device_spec.merge_from(
+                pydev.DeviceSpec().parse_from_string("/device:CPU:0"))
+            with ops.device(device_spec.to_string()):
               return array_ops.identity(x)
           return f
 
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ae9c244aaf..ecce8ae6bd 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -174,6 +174,24 @@ class SaverTest(test.TestCase):
   def testResourceBasic(self):
     self.basicSaveRestore(resource_variable_ops.ResourceVariable)
 
+  def testResourceColocation(self):
+    partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
+    with ops_lib.device("/job:ps/device:GPU:0"):
+      v = variable_scope.get_variable("v0",
+                                      shape=[10, 2],
+                                      partitioner=partitioner,
+                                      use_resource=True)
+      saver_module.Saver({"v0": v}).build()
+      save_op = None
+      for op in ops_lib.get_default_graph().get_operations():
+        if op.type == "SaveV2":
+          save_op = op
+          break
+      assert save_op is not None
+      for save_inp in save_op.inputs[3:]:
+        # Input to SaveV2 op is placed on CPU of the same device as the Variable.
+        self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
+
   def testResourceVariableReadOpsAddedDeterministically(self):
     graph_defs = []
     num_graphs = 10