diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/tpu_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/tpu_strategy.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 1ae12ae98a..bc53898539 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope as vs from tensorflow.python.util import nest @@ -137,9 +138,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def get_finalize_ops(self): return [tpu.shutdown_system()] - def _reduce(self, method_string, value, destinations): + def _reduce(self, aggregation, value, destinations): del destinations # TPU is graph mode only. Rely on implicit Send/Recv. - if method_string == 'mean': + if aggregation == vs.VariableAggregation.MEAN: # TODO(jhseu): Revisit once we support model-parallelism. value *= (1. / self._num_cores_per_host) return tpu_ops.cross_replica_sum(value) |