diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/one_device_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/one_device_strategy.py | 12 |
1 files changed, 4 insertions, 8 deletions
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py index a580dac96c..dbd3514aec 100644 --- a/tensorflow/contrib/distribute/python/one_device_strategy.py +++ b/tensorflow/contrib/distribute/python/one_device_strategy.py @@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import values from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.training import distribute as distribute_lib @@ -43,11 +44,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): self._default_device = device def _create_variable(self, next_creator, *args, **kwargs): - # No need to distinguish tower-local variables when not mirroring, - # we just enforce that they are not trainable. - if kwargs.pop("tower_local_reduce_method", None) is not None: - kwargs["trainable"] = False - colocate_with = kwargs.pop("colocate_with", None) if colocate_with is None: with ops.device(self._device): @@ -80,15 +76,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy): with ops.device(self._device): return values.MapOutput([fn(m, *args, **kwargs) for m in map_over]) - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): if not isinstance(value, values.MapOutput): return value l = value.get() assert l with ops.device(self._device): - if method_string == "sum": + if aggregation == vs.VariableAggregation.SUM: return math_ops.add_n(l) - elif method_string == "mean": + elif aggregation == vs.VariableAggregation.MEAN: return math_ops.add_n(l) / len(l) else: assert False |