diff options
author | 2018-07-19 15:58:10 -0700 | |
---|---|---|
committer | 2018-07-19 16:07:36 -0700 | |
commit | 8f130ff5b021efb94946ed9deb1341890763fd3f (patch) | |
tree | e32a2b674caaf746ffaa70ace6383f30fbc0e120 /tensorflow | |
parent | 3647625e531e713ad9a7fb0f3c5b68863ae4e7b8 (diff) |
Fix ResourceVariable placement during checkpointing to correctly colocate the
copy of the variable on the same machine. Addresses Issue #20914.
PiperOrigin-RevId: 205317119
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/training/saver.py | 7 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 18 |
2 files changed, 24 insertions, 1 deletion
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 1ee975fbe4..11510d9928 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -126,7 +126,12 @@ class BaseSaverBuilder(object):
           def f():
             with ops.device(v.device):
               x = v.read_value()
-              with ops.device("/device:CPU:0"):
+              # To allow variables placed on non-CPU devices to be checkpointed,
+              # we copy them to CPU on the same machine first.
+              device_spec = pydev.DeviceSpec().parse_from_string(v.device)
+              device_spec.merge_from(
+                  pydev.DeviceSpec().parse_from_string("/device:CPU:0"))
+              with ops.device(device_spec.to_string()):
                 return array_ops.identity(x)
           return f
 
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ae9c244aaf..ecce8ae6bd 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -174,6 +174,24 @@ class SaverTest(test.TestCase):
   def testResourceBasic(self):
     self.basicSaveRestore(resource_variable_ops.ResourceVariable)
 
+  def testResourceColocation(self):
+    partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
+    with ops_lib.device("/job:ps/device:GPU:0"):
+      v = variable_scope.get_variable("v0",
+                                      shape=[10, 2],
+                                      partitioner=partitioner,
+                                      use_resource=True)
+      saver_module.Saver({"v0": v}).build()
+      save_op = None
+      for op in ops_lib.get_default_graph().get_operations():
+        if op.type == "SaveV2":
+          save_op = op
+          break
+      assert save_op is not None
+      for save_inp in save_op.inputs[3:]:
+        # Input to SaveV2 op is placed on CPU of the same device as the Variable.
+        self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
+
   def testResourceVariableReadOpsAddedDeterministically(self):
     graph_defs = []
     num_graphs = 10