aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-05-15 13:20:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-15 13:22:27 -0700
commitf17620153c47370f30a84b99eaba82bef8cd7d8e (patch)
treef5c30f0e31d4241cb7d0c572a3195819fa1f5a44 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent103638433f16f31dbde3480504c4c0a33273cc64 (diff)
Handle delayed variable initialization in MirroredStrategy. Test with RNN layer.
Bug reported and solution suggested in #19069 PiperOrigin-RevId: 196718454
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 8237b23dbb..89f2c431fe 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -111,10 +111,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
kwargs["name"] = "%s/replica_%d" % (var0name, i)
# Initialize replicas with the same value:
if context.executing_eagerly():
- initial_value = index[devices[0]].value()
+ kwargs["initial_value"] = array_ops.identity(
+ index[devices[0]].value())
else:
- initial_value = index[devices[0]].initial_value
- kwargs["initial_value"] = array_ops.identity(initial_value)
+ def initial_value_fn(device=d):
+ with ops.device(device):
+ return array_ops.identity(index[devices[0]].initial_value)
+ kwargs["initial_value"] = initial_value_fn
with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
assert not isinstance(v, values.DistributedVariable)