aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2018-04-30 12:41:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-30 12:43:36 -0700
commit8609ef4db1a2af0da0c2c20b26756031637de3ff (patch)
tree93b6e682c9eb03135ac507a43eeaeec0f7153076 /tensorflow/contrib/distribute/python/values.py
parent9d79acc6aae306e0444c193e945f0c87fe5bb509 (diff)
When a mirrored variable is fetched in cross-tower mode, fetch its primary variable.
This prevents errors like ValueError: Fetch argument MirroredVariable({'/job:localhost/replica:0/task:0/device:GPU:0': <tf.Variable 'global_step:0' shape=() dtype=int64>, '/job:localhost/replica:0/task:0/device:GPU:1': <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64>}) cannot be interpreted as a Tensor. (Device /job:localhost/replica:0/task:0/device:CPU:0 not found in ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1'] (current device )) I ran distribute/examples/resnet with and without the change and it fixed the problem. PiperOrigin-RevId: 194828672
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8cb5276579..466678ef2e 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -229,6 +229,12 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self.get()._as_graph_element()
+
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
pass