aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
author A. Unique TensorFlower <gardener@tensorflow.org> 2018-07-19 15:58:10 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-19 16:07:36 -0700
commit 8f130ff5b021efb94946ed9deb1341890763fd3f (patch)
tree e32a2b674caaf746ffaa70ace6383f30fbc0e120 /tensorflow
parent 3647625e531e713ad9a7fb0f3c5b68863ae4e7b8 (diff)
Fix ResourceVariable placement during checkpointing to correctly colocate the
copy of the variable on the same machine. Addresses Issue #20914.

PiperOrigin-RevId: 205317119
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/python/training/saver.py 7
-rw-r--r-- tensorflow/python/training/saver_test.py 18
2 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 1ee975fbe4..11510d9928 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -126,7 +126,12 @@ class BaseSaverBuilder(object):
def f():
with ops.device(v.device):
x = v.read_value()
- with ops.device("/device:CPU:0"):
+ # To allow variables placed on non-CPU devices to be checkpointed,
+ # we copy them to CPU on the same machine first.
+ device_spec = pydev.DeviceSpec().parse_from_string(v.device)
+ device_spec.merge_from(
+ pydev.DeviceSpec().parse_from_string("/device:CPU:0"))
+ with ops.device(device_spec.to_string()):
return array_ops.identity(x)
return f
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index ae9c244aaf..ecce8ae6bd 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -174,6 +174,24 @@ class SaverTest(test.TestCase):
def testResourceBasic(self):
self.basicSaveRestore(resource_variable_ops.ResourceVariable)
+ def testResourceColocation(self):
+ partitioner = partitioned_variables.fixed_size_partitioner(num_shards=2)
+ with ops_lib.device("/job:ps/device:GPU:0"):
+ v = variable_scope.get_variable("v0",
+ shape=[10, 2],
+ partitioner=partitioner,
+ use_resource=True)
+ saver_module.Saver({"v0": v}).build()
+ save_op = None
+ for op in ops_lib.get_default_graph().get_operations():
+ if op.type == "SaveV2":
+ save_op = op
+ break
+ assert save_op is not None
+ for save_inp in save_op.inputs[3:]:
+ # Input to SaveV2 op is placed on CPU of the same device as the Variable.
+ self.assertEqual("/job:ps/device:CPU:0", save_inp.device)
+
def testResourceVariableReadOpsAddedDeterministically(self):
graph_defs = []
num_graphs = 10