aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/distribute.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-01 14:27:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 14:30:18 -0700
commit46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (patch)
tree36f8a33c084e65515fc4dd2e903603f450de4d2f /tensorflow/python/training/distribute.py
parent7cbbd3525b4232f2dc8cd117852c26ec472aa9b2 (diff)
Make tower-local variables non-trainable even with the default
DistributionStrategy. PiperOrigin-RevId: 194996819
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r--tensorflow/python/training/distribute.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 6aeecb31dd..c16b05102e 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -1127,8 +1127,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- if kwargs.pop("tower_local_reduce_method", None) is not None:
- kwargs["trainable"] = False
+ kwargs.pop("tower_local_reduce_method", None)
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(
@@ -1138,7 +1137,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
"""Does not set to resource variables."""
def create_tower_local_variable(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
- kwargs["tower_local_reduce_method"] = reduce_method
+ kwargs["trainable"] = False
return next_creator(*args, **kwargs)
_require_distribution_strategy_scope(self)