-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops.py       |  6
-rw-r--r--  tensorflow/contrib/distribute/python/cross_tower_ops_test.py  | 10
2 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 3a7addf221..dd74d5eed7 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -756,7 +756,7 @@ class CollectiveAllReduce(CrossTowerOps):
     )
     super(CollectiveAllReduce, self).__init__()
 
-  # TODO(yuefengz, tucker): is index slices supported by collective ops?
+  # TODO(yuefengz, tucker): is indexed slices supported by collective ops?
   def _reduce(self, aggregation, per_device_value, destinations):
     all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
     if destinations is None or _devices_match(per_device_value, destinations):
@@ -768,8 +768,10 @@ class CollectiveAllReduce(CrossTowerOps):
         if d in all_reduced._index:
           index[d] = all_reduced._index[d]
         else:
-          with ops.device(d):
+          with ops.control_dependencies(list(
+              all_reduced._index.values())), ops.device(d):
             index[d] = array_ops.identity(list(all_reduced._index.values())[0])
+
       return value_lib.Mirrored(index)
 
   def _batch_reduce(self, aggregation, value_destination_pairs):
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index aec53b01d7..3508c9d599 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -417,7 +417,7 @@ class MultiWorkerCollectiveAllReduceTest(
         devices = ["/device:GPU:%d" % i for i in range(num_gpus)]
       else:
         devices = ["/device:CPU:0"]
-      return collective_all_reduce_ops, devices, "local"
+      return collective_all_reduce_ops, devices, ""
     else:
       collective_all_reduce_ops = cross_tower_ops_lib.CollectiveAllReduce(
           3, num_gpus, collective_keys=collective_keys)
@@ -476,7 +476,7 @@ class MultiWorkerCollectiveAllReduceTest(
     destination_list = devices
 
     all_destinations = [
-        None, destination_mirrored, destination_different, destination_str,
+        destination_different, None, destination_mirrored, destination_str,
         destination_list
     ]
 
@@ -540,6 +540,12 @@ class MultiWorkerCollectiveAllReduceTest(
     self._run_between_graph_clients(self._test_reduction, self._cluster_spec,
                                     num_gpus)
 
+  # Collective ops doesn't support strategy with one device.
+  def testReductionLocal(self, num_gpus=2):
+    if context.num_gpus() < num_gpus:
+      return
+    self._test_reduction(None, None, num_gpus, local_mode=True)
+
 
 if __name__ == "__main__":
   test.main()
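
The substantive change in _reduce is that the copy of the reduced value to a destination device that did not take part in the all-reduce is now placed under a control dependency on every reduced tensor, so the broadcast identity op cannot run before the collective has finished producing its results. Below is a minimal sketch of that pattern, assuming TF graph mode and hypothetical device names and helper names; it is not the actual CollectiveAllReduce implementation.

# Sketch of "broadcast only after the reduction is done" using control
# dependencies. `reduced_per_device` maps device names to reduced tensors;
# `extra_destination` is a device that did not participate in the reduction.
import tensorflow as tf

def broadcast_after_reduce(reduced_per_device, extra_destination):
  index = dict(reduced_per_device)
  # The control dependency keeps the copy from racing ahead of the
  # collective: tf.identity cannot run until all reduced values exist.
  with tf.control_dependencies(list(reduced_per_device.values())):
    with tf.device(extra_destination):
      index[extra_destination] = tf.identity(
          list(reduced_per_device.values())[0])
  return index

Without the dependency, the identity op only depends on the single tensor it copies, so in a multi-worker graph it could be scheduled before the other workers' reduced values are ready; the added edge serializes the copy behind the whole reduction.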