author    Igor Saprykin <isaprykin@google.com>  2018-04-30 12:41:12 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-30 12:43:36 -0700
commit    8609ef4db1a2af0da0c2c20b26756031637de3ff (patch)
tree      93b6e682c9eb03135ac507a43eeaeec0f7153076 /tensorflow
parent    9d79acc6aae306e0444c193e945f0c87fe5bb509 (diff)
When a mirrored variable is fetched in cross-tower mode, fetch its primary variable.
This prevents errors like:

    ValueError: Fetch argument MirroredVariable({'/job:localhost/replica:0/task:0/device:GPU:0': <tf.Variable 'global_step:0' shape=() dtype=int64>, '/job:localhost/replica:0/task:0/device:GPU:1': <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64>}) cannot be interpreted as a Tensor. (Device /job:localhost/replica:0/task:0/device:CPU:0 not found in ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1'] (current device ))

I ran distribute/examples/resnet with and without the change and it fixed the problem.

PiperOrigin-RevId: 194828672
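For context, here is a hedged sketch of the failing pattern the message describes. It is not code from this commit; the MirroredStrategy construction and the global-step helper are assumptions based on the 2018-era contrib API, and it presumes two visible GPUs as in the error above.

    # Hypothetical reproduction sketch; assumes the contrib MirroredStrategy
    # API of this era and two GPUs.
    import tensorflow as tf
    from tensorflow.contrib.distribute.python import mirrored_strategy

    strategy = mirrored_strategy.MirroredStrategy(
        ["/device:GPU:0", "/device:GPU:1"])
    with tf.Graph().as_default(), strategy.scope():
      # Variables created under the scope become MirroredVariables, e.g. the
      # global step used by distribute/examples/resnet.
      global_step = tf.train.get_or_create_global_step()
      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Before this change, fetching the MirroredVariable from cross-tower
        # context raised the ValueError quoted above; with _as_graph_element
        # defined, the fetch resolves to the primary variable.
        print(sess.run(global_step))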
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py        6
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py  16
2 files changed, 22 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8cb5276579..466678ef2e 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -229,6 +229,12 @@ class DistributedVariable(DistributedDelegate):
                               self._primary_var.op.type)
     return self.get().op
 
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribute_lib.get_cross_tower_context():
+      return self._primary_var._as_graph_element()
+    return self.get()._as_graph_element()
+
   def _should_act_as_resource_variable(self):
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
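Why one private hook is enough: when Session.run resolves a fetch that is not already a Tensor or Operation, Graph.as_graph_element falls back to the object's _as_graph_element() method, so a DistributedVariable now converts to its primary variable's graph element in cross-tower mode (and to the per-device value inside a tower). The sketch below is a simplified illustration of that fallback, not TensorFlow's actual session code; the helper name resolve_fetch is made up for this note.

    # Simplified illustration of the fetch-conversion fallback (assumed
    # behavior, not the real implementation).
    from tensorflow.python.framework import ops


    def resolve_fetch(graph, fetch):
      if isinstance(fetch, (ops.Tensor, ops.Operation)):
        return fetch
      conv_fn = getattr(fetch, "_as_graph_element", None)
      if conv_fn is None or not callable(conv_fn):
        raise ValueError(
            "Fetch argument %r cannot be interpreted as a Tensor." % fetch)
      # Whatever _as_graph_element() returns is then handled like any other
      # graph element (the primary variable, in cross-tower mode).
      return graph.as_graph_element(conv_fn())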
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index e96ce54741..1d4e801cd8 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import saver as saver_lib
@@ -582,6 +583,21 @@ class MirroredVariableTest(test.TestCase):
     save_path = self._save_normal()
     self._restore_mirrored(save_path)
 
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testFetchAMirroredVariable(self):
+    if context.num_gpus() < 1 or context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test or it's eager mode.")
+
+    with self.test_session(
+        graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
+            ["/device:GPU:0"]).scope():
+      with ops.device("/device:GPU:0"):
+        v = variable_scope.get_variable(
+            name="v", initializer=1., use_resource=True)
+      mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+      sess.run(variables_lib.global_variables_initializer())
+      sess.run({"complicated": mirrored})
+
 _devices = ["/device:GPU:0", "/device:CPU:0"]