diff options
author | Priya Gupta <priyag@google.com> | 2018-05-15 13:20:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-15 13:22:27 -0700 |
commit | f17620153c47370f30a84b99eaba82bef8cd7d8e (patch) | |
tree | f5c30f0e31d4241cb7d0c572a3195819fa1f5a44 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 103638433f16f31dbde3480504c4c0a33273cc64 (diff) |
Handle delayed variable initialization in MirroredStrategy. Test with RNN layer.
Bug reported and solution suggested in #19069
PiperOrigin-RevId: 196718454
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 8237b23dbb..89f2c431fe 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -111,10 +111,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): kwargs["name"] = "%s/replica_%d" % (var0name, i) # Initialize replicas with the same value: if context.executing_eagerly(): - initial_value = index[devices[0]].value() + kwargs["initial_value"] = array_ops.identity( + index[devices[0]].value()) else: - initial_value = index[devices[0]].initial_value - kwargs["initial_value"] = array_ops.identity(initial_value) + def initial_value_fn(device=d): + with ops.device(device): + return array_ops.identity(index[devices[0]].initial_value) + kwargs["initial_value"] = initial_value_fn with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT): v = next_creator(*args, **kwargs) assert not isinstance(v, values.DistributedVariable) |