diff options
author | Igor Saprykin <isaprykin@google.com> | 2018-04-30 12:41:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-30 12:43:36 -0700 |
commit | 8609ef4db1a2af0da0c2c20b26756031637de3ff (patch) | |
tree | 93b6e682c9eb03135ac507a43eeaeec0f7153076 /tensorflow | |
parent | 9d79acc6aae306e0444c193e945f0c87fe5bb509 (diff) |
When a mirrored variable is fetched in cross-tower mode, fetch its primary variable.
This prevents errors like
ValueError: Fetch argument MirroredVariable({'/job:localhost/replica:0/task:0/device:GPU:0': <tf.Variable 'global_step:0' shape=() dtype=int64>, '/job:localhost/replica:0/task:0/device:GPU:1': <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64>}) cannot be interpreted as a Tensor. (Device /job:localhost/replica:0/task:0/device:CPU:0 not found in ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1'] (current device ))
I ran distribute/examples/resnet with and without the change and it fixed the problem.
PiperOrigin-RevId: 194828672
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/distribute/python/values_test.py | 16 |
2 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 8cb5276579..466678ef2e 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -229,6 +229,12 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index e96ce54741..1d4e801cd8 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib @@ -582,6 +583,21 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) + @test_util.run_in_graph_and_eager_modes(config=config) + def testFetchAMirroredVariable(self): + if context.num_gpus() < 1 or context.executing_eagerly(): + self.skipTest("A GPU is not available for this test or it's eager mode.") + + with self.test_session( + graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( + ["/device:GPU:0"]).scope(): + with ops.device("/device:GPU:0"): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + sess.run(variables_lib.global_variables_initializer()) + sess.run({"complicated": mirrored}) + _devices = ["/device:GPU:0", "/device:CPU:0"] |