diff options
author | 2018-05-01 14:27:33 -0700 | |
---|---|---|
committer | 2018-05-01 14:30:18 -0700 | |
commit | 46bf1e8934b3bc8edeff3f218a50b0ee5806e96b (patch) | |
tree | 36f8a33c084e65515fc4dd2e903603f450de4d2f /tensorflow/python/training/distribute.py | |
parent | 7cbbd3525b4232f2dc8cd117852c26ec472aa9b2 (diff) |
Make tower-local variables non-trainable even with the default
DistributionStrategy.
PiperOrigin-RevId: 194996819
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r-- | tensorflow/python/training/distribute.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 6aeecb31dd..c16b05102e 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -1127,8 +1127,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): def creator(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) - if kwargs.pop("tower_local_reduce_method", None) is not None: - kwargs["trainable"] = False + kwargs.pop("tower_local_reduce_method", None) return next_creator(*args, **kwargs) return _CurrentDistributionContext( @@ -1138,7 +1137,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): """Does not set to resource variables.""" def create_tower_local_variable(next_creator, *args, **kwargs): _require_distribution_strategy_scope(self) - kwargs["tower_local_reduce_method"] = reduce_method + kwargs["trainable"] = False return next_creator(*args, **kwargs) _require_distribution_strategy_scope(self) |