aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/tpu_strategy.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/distribute/python/tpu_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1ae12ae98a..bc53898539 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import nest
@@ -137,9 +138,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def get_finalize_ops(self):
return [tpu.shutdown_system()]
- def _reduce(self, method_string, value, destinations):
+ def _reduce(self, aggregation, value, destinations):
del destinations # TPU is graph mode only. Rely on implicit Send/Recv.
- if method_string == 'mean':
+ if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self._num_cores_per_host)
return tpu_ops.cross_replica_sum(value)