aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-03-29 19:56:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-29 19:59:30 -0700
commit9451c12d62b272789947f475554601295ada4083 (patch)
treeaae01bbe8c523fdeba6de2059e3f621583f7bdc1 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent566f9041e19831a4eb8904654ddd365fd8f234c0 (diff)
Internal change
PiperOrigin-RevId: 191024677
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py25
1 files changed, 18 insertions, 7 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 8cf83c52d8..eb0edb3a11 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -78,9 +78,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
[device_util.canonicalize(d) for d in devices])
self._device_index = values.PerDevice(
dict((d, i) for i, d in enumerate(devices)))
- self.cross_tower_ops = (
- cross_tower_ops or
- cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
+ self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
def _create_variable(self, next_creator, *args, **kwargs):
@@ -149,7 +147,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _broadcast(self, tensor, destinations):
# TODO(josh11b): In eager mode, use one thread per device, or async mode.
- return self.cross_tower_ops.broadcast(tensor, destinations or self._devices)
+ return self._get_cross_tower_ops().broadcast(tensor, destinations or
+ self._devices)
def _call_for_each_tower(self, fn, *args, **kwargs):
"""Run `fn` in separate threads, once per tower/worker device.
@@ -272,16 +271,28 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# in addition to PerDevice data.
return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
+ def configure(self, session_config=None):
+ if self._cross_tower_ops is None:
+ self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
+ self._devices, session_config=session_config)
+
+ def _get_cross_tower_ops(self):
+ if self._cross_tower_ops is None:
+ self._cross_tower_ops = (
+ cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps())
+ return self._cross_tower_ops
+
def _reduce(self, method_string, value, destinations):
if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
value = values.PerDevice({self._devices[0]: value})
assert isinstance(value, values.PerDevice)
- return self.cross_tower_ops.reduce(
+
+ return self._get_cross_tower_ops().reduce(
method_string, value, destinations=destinations)
def _batch_reduce(self, method_string, value_destination_pairs):
- return self.cross_tower_ops.batch_reduce(method_string,
- value_destination_pairs)
+ return self._get_cross_tower_ops().batch_reduce(method_string,
+ value_destination_pairs)
def _update(self, var, fn, *args, **kwargs):
# TODO(josh11b): Also support TowerLocalVariables here? If so, args and