Diffstat (limited to 'tensorflow/contrib/distribute/python/one_device_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/one_device_strategy.py  12
1 file changed, 4 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index a580dac96c..dbd3514aec 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.distribute.python import values
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.training import distribute as distribute_lib
@@ -43,11 +44,6 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     self._default_device = device
 
   def _create_variable(self, next_creator, *args, **kwargs):
-    # No need to distinguish tower-local variables when not mirroring,
-    # we just enforce that they are not trainable.
-    if kwargs.pop("tower_local_reduce_method", None) is not None:
-      kwargs["trainable"] = False
-
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
       with ops.device(self._device):
@@ -80,15 +76,15 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
     with ops.device(self._device):
       return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
 
-  def _reduce(self, method_string, value, destinations):
+  def _reduce(self, aggregation, value, destinations):
     if not isinstance(value, values.MapOutput):
       return value
     l = value.get()
     assert l
     with ops.device(self._device):
-      if method_string == "sum":
+      if aggregation == vs.VariableAggregation.SUM:
         return math_ops.add_n(l)
-      elif method_string == "mean":
+      elif aggregation == vs.VariableAggregation.MEAN:
         return math_ops.add_n(l) / len(l)
       else:
         assert False
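
The second hunk replaces the string-based reduction dispatch ("sum"/"mean") with the vs.VariableAggregation enum imported at the top of the file. Below is a minimal standalone sketch of that dispatch, assuming a hypothetical helper name reduce_values and a plain list of per-device tensors; it is not the TensorFlow implementation itself:

from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs

def reduce_values(aggregation, per_device_values):
  # per_device_values is assumed to be a non-empty list of tensors,
  # e.g. the list returned by values.MapOutput.get().
  assert per_device_values
  if aggregation == vs.VariableAggregation.SUM:
    # Element-wise sum across the per-device values.
    return math_ops.add_n(per_device_values)
  elif aggregation == vs.VariableAggregation.MEAN:
    # Sum, then divide by the number of values to get the mean.
    return math_ops.add_n(per_device_values) / len(per_device_values)
  else:
    raise ValueError("Unsupported aggregation: %r" % (aggregation,))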