aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-19 13:39:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 13:43:04 -0700
commitf9af1e1f742210615a9eed4866cf6744419fde24 (patch)
tree4d53ce30e56df361926b5446017a576c58becc0c /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentca226664780bf980848ffe3552d215568139ed6d (diff)
Disable caching_device for mirrored variables.
PiperOrigin-RevId: 201232817
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 900aa10e93..c1b4b870a5 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -109,6 +109,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if tower_local is not None:
kwargs["trainable"] = False
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
# TODO(josh11b,apassos): It would be better if variable initialization
# was never recorded on the tape instead of having to do this manually
# here.