diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-03-29 19:56:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-29 19:59:30 -0700 |
commit | 9451c12d62b272789947f475554601295ada4083 (patch) | |
tree | aae01bbe8c523fdeba6de2059e3f621583f7bdc1 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 566f9041e19831a4eb8904654ddd365fd8f234c0 (diff) |
Internal change
PiperOrigin-RevId: 191024677
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 8cf83c52d8..eb0edb3a11 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -78,9 +78,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): [device_util.canonicalize(d) for d in devices]) self._device_index = values.PerDevice( dict((d, i) for i, d in enumerate(devices))) - self.cross_tower_ops = ( - cross_tower_ops or - cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) + self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device def _create_variable(self, next_creator, *args, **kwargs): @@ -149,7 +147,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _broadcast(self, tensor, destinations): # TODO(josh11b): In eager mode, use one thread per device, or async mode. - return self.cross_tower_ops.broadcast(tensor, destinations or self._devices) + return self._get_cross_tower_ops().broadcast(tensor, destinations or + self._devices) def _call_for_each_tower(self, fn, *args, **kwargs): """Run `fn` in separate threads, once per tower/worker device. @@ -272,16 +271,28 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # in addition to PerDevice data. return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()}) + def configure(self, session_config=None): + if self._cross_tower_ops is None: + self._cross_tower_ops = cross_tower_ops_lib.choose_the_best( + self._devices, session_config=session_config) + + def _get_cross_tower_ops(self): + if self._cross_tower_ops is None: + self._cross_tower_ops = ( + cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps()) + return self._cross_tower_ops + def _reduce(self, method_string, value, destinations): if len(self._devices) == 1 and not isinstance(value, values.PerDevice): value = values.PerDevice({self._devices[0]: value}) assert isinstance(value, values.PerDevice) - return self.cross_tower_ops.reduce( + + return self._get_cross_tower_ops().reduce( method_string, value, destinations=destinations) def _batch_reduce(self, method_string, value_destination_pairs): - return self.cross_tower_ops.batch_reduce(method_string, - value_destination_pairs) + return self._get_cross_tower_ops().batch_reduce(method_string, + value_destination_pairs) def _update(self, var, fn, *args, **kwargs): # TODO(josh11b): Also support TowerLocalVariables here? If so, args and |