aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/all_reduce/python/all_reduce.py
blob: 16617b726675da3f76a60a1bbdcd50c5c08d8966 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to construct a TF subgraph implementing distributed All-Reduce."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import re

from tensorflow.contrib import nccl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def _flatten_tensors(tensors):
  """Check tensors for isomorphism and flatten.

  Args:
    tensors: list of T @{tf.Tensor} which must all have the same shape.

  Returns:
    tensors: a list of T @{tf.Tensor} which are flattened (1D) views of tensors
    shape: the original shape of each element of input tensors

  Raises:
    ValueError: tensors are empty or non-isomorphic.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  for tensor in tensors:
    shape = shape.merge_with(tensor.shape)
  if shape.ndims is None:
    raise ValueError("At least one of the tensors in 'tensors' must have "
                     "statically known rank.")
  if len(shape) != 1:
    reshaped = []
    for t in tensors:
      with ops.colocate_with(t):
        reshaped.append(array_ops.reshape(t, [-1]))
    tensors = reshaped
  return tensors, shape


def _reshape_tensors(tensors, shape):
  """Reshape tensors flattened by _flatten_tensors.

  Args:
    tensors: list of T @{tf.Tensor} of identical length 1D tensors.
    shape: list of integers describing the desired shape.  Product of
      the elements must equal the length of each tensor.

  Returns:
    list of T @{tf.Tensor} which are the reshaped inputs.
  """
  reshaped = []
  for t in tensors:
    with ops.colocate_with(t):
      reshaped.append(array_ops.reshape(t, shape))
  return reshaped


def _padded_split(tensor, pieces):
  """Like split for 1D tensors but pads-out case where len % pieces != 0.

  Args:
    tensor: T @{tf.Tensor} that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of T @{tf.Tensor} of length pieces, which hold the values of
      thin input tensor, in order.  The final tensor may
      be zero-padded on the end to make its size equal to those of all
      of the other tensors.

  Raises:
    ValueError: The input tensor is not 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape[0].value
  with ops.colocate_with(tensor):
    if tensor_len % pieces != 0:
      # pad to an even length
      chunk_size = 1 + tensor_len // pieces
      if pieces > tensor_len:
        # This is an edge case that should not come up in practice,
        # i.e. a different reduction algorithm would be better,
        # but we'll make it work just for completeness.
        pad_len = pieces - tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      elif (pieces - 1) * chunk_size >= tensor_len:
        # Another edge case of limited real interest.
        pad_len = (pieces * chunk_size) % tensor_len
        extended_whole = array_ops.concat(
            [tensor, array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        parts = array_ops.split(extended_whole, pieces)
        return parts, pad_len
      else:
        last_chunk_size = tensor_len - (pieces - 1) * chunk_size
        pad_len = chunk_size - last_chunk_size
        piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
        parts = array_ops.split(tensor, piece_lens)
        parts[-1] = array_ops.concat(
            [parts[-1], array_ops.zeros([pad_len], dtype=tensor.dtype)], 0)
        return parts, pad_len
    else:
      return array_ops.split(tensor, pieces), 0


def _strip_padding(tensors, pad_len):
  """Strip the suffix padding added by _padded_split.

  Args:
    tensors: list of T @{tf.Tensor} of identical length 1D tensors.
    pad_len: number of elements to be stripped from the end of each tensor.

  Returns:
    list of T @{tf.Tensor} which are the stripped inputs.

  Raises:
    ValueError: tensors must be a non-empty list of 1D tensors, and
      each must be longer than pad_len.
  """
  if not tensors:
    raise ValueError("tensors cannot be empty")
  shape = tensors[0].shape
  if len(shape) > 1:
    raise ValueError("tensors must be 1D")
  prefix_len = int(shape[0] - pad_len)
  if prefix_len < 0:
    raise ValueError("pad_len longer than tensor")
  stripped = []
  for t in tensors:
    with ops.colocate_with(t):
      stripped.append(array_ops.slice(t, [0], [prefix_len]))
  return stripped


def _ragged_split(tensor, pieces):
  """Like split for 1D tensors but allows case where len % pieces != 0.

  Args:
    tensor: T @{tf.Tensor} that must be 1D.
    pieces: a positive integer specifying the number of pieces into which
      tensor should be split.

  Returns:
    list of T @{tf.Tensor} of length pieces, which hold the values of
      the input tensor, in order.  The final tensor may be shorter
      than the others, which will all be of equal length.

  Raises:
    ValueError: input tensor must be 1D.
  """
  shape = tensor.shape
  if 1 != len(shape):
    raise ValueError("input tensor must be 1D")
  tensor_len = shape[0].value
  chunk_size = tensor_len // pieces
  with ops.colocate_with(tensor):
    if tensor_len != (pieces * chunk_size):
      # last piece will be short
      assert pieces > 1
      last_chunk_size = tensor_len - ((pieces - 1) * chunk_size)
      assert last_chunk_size > 0
      piece_lens = [chunk_size for _ in range(pieces - 1)] + [last_chunk_size]
      return array_ops.split(tensor, piece_lens)
    else:
      return array_ops.split(tensor, pieces)


def _ring_permutations(num_workers, num_subchunks, gpu_perm):
  """"Generate an array of device index arrays, one for each subchunk.

  In the basic ring reduction algorithm there are size(T)/num_devices
  data chunks and each device process one chunk per tick, i.e. sending
  one chunk and receiving one chunk.  The idea of subchunking is that
  each device processes num_subchunks smaller data regions per tick,
  and the ring rank permutation is different for each subchunk index
  so that a device is potentially sending to and receiving from
  num_subchunks different other devices at each tick.  Where multiple
  independent data channels exist between devices, this strategy
  supplies a method of using them in parallel.

  Args:
    num_workers: number of worker tasks
    num_subchunks: number of subchunks into which to divide each per-GPU chunk.
    gpu_perm: an array of integers in [0, num_gpus-1] giving the default
      ring order of GPUs at each worker.  Other permutations will be generated
      by rotating this array and splicing together per-worker instances.

  Raises:
    ValueError: the number of subchunks may not exceed the number of GPUs.

  Returns:
    pred_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
        preceding device in the permutation for that subchunk.  The
        device index of GPU i at worker j is i + (j * num_gpus).
    rank_by_s_d: list of lists that maps (by index) from (subchunk, dev) to
       local rank of device d in the permutation for that subchunk.
  """
  num_gpus = len(gpu_perm)
  devices = num_workers * num_gpus
  if devices == 0:
    return [], []
  if num_subchunks > num_gpus:
    raise ValueError(
        "num_subchunks %d must be <= num_gpus %d" % (num_subchunks, num_gpus))
  rotation_interval = max(1, int(num_gpus / num_subchunks))
  perms_by_s = []
  for s in range(0, num_subchunks):
    full_order = []
    offset = s * rotation_interval
    for w in range(0, num_workers):
      default_order = [(w * num_gpus) + i for i in gpu_perm]
      dev_order = default_order[offset:] + default_order[:offset]
      full_order += dev_order
    perms_by_s.append(full_order)
  pred_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  rank_by_s_d = [[-1 for d in range(0, devices)]
                 for s in range(0, num_subchunks)]
  for s in range(0, num_subchunks):
    for d in range(0, devices):
      for t in range(0, devices):
        if d == perms_by_s[s][t]:
          rank_by_s_d[s][d] = t
          pred_by_s_d[s][d] = perms_by_s[s][(t + devices - 1) % devices]
          break
  return (pred_by_s_d, rank_by_s_d)


def build_ring_all_reduce(input_tensors, num_workers, num_subchunks,
                          gpu_perm, red_op, un_op=None):
  """Construct a subgraph performing a ring-style all-reduce of input_tensors.

  Args:
    input_tensors: a list of T @{tf.Tensor} objects, which must all
      have the same shape and type.
    num_workers: number of worker tasks spanned by input_tensors.
    num_subchunks: number of subchunks each device should process in one tick.
    gpu_perm: a list of ints giving a ring-wise rank ordering of GPUs at
      each worker.  All workers must have the same number of
      GPUs with the same rank ordering.  If NVLINK is available, this should
      be a ring order supported by NVLINK edges.
    red_op: a binary operator for elementwise reduction.
    un_op: an optional unary operator to apply to fully reduced values.

  Raises:
    ValueError: empty input_tensors or they don't all have same
    size.

  Returns:
    a list of T @{tf.Tensor} identical sum-reductions of input_tensors.
  """
  if len(input_tensors) < 2:
    raise ValueError("input_tensors must be length 2 or longer")
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  (pred_by_s_d, rank_by_s_d) = _ring_permutations(
      num_workers, num_subchunks, gpu_perm)
  chunks_by_dev, pad_len = _build_ring_gather(
      input_tensors, devices,
      num_subchunks, pred_by_s_d, rank_by_s_d, red_op)
  if un_op:
    chunks_by_dev = _apply_unary_to_chunks(un_op, chunks_by_dev)
  output_tensors = _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                                       chunks_by_dev)
  if pad_len > 0:
    output_tensors = _strip_padding(output_tensors, pad_len)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_ring_gather(input_tensors, devices, num_subchunks,
                       pred_by_s_d, rank_by_s_d, red_op):
  """Construct a subgraph for the first (reduction) pass of ring all-reduce.

  Args:
    input_tensors: a list of T @{tf.Tensor} 1D input tensors of same
      shape and type.
    devices: array of device name strings
    num_subchunks: number of subchunks each device should process in one tick.
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    red_op: a binary operator for elementwise reduction

  Raises:
    ValueError: tensors must all be one dimensional.

  Returns:
    list of list of T @{tf.Tensor} of (partially) reduced values where
    exactly num_subchunks chunks at each device are fully reduced.
  """
  num_devices = len(input_tensors)
  if num_devices == 0:
    return []
  if num_devices == 1:
    return input_tensors
  shape = input_tensors[0].shape
  if 1 != len(shape):
    raise ValueError("input tensors must be 1D")
  num_chunks = num_devices * num_subchunks
  num_ticks = num_devices - 1
  # Initialize chunks_by_dev with splits of the input tensors.
  chunks_by_dev = []
  split_pad_len = 0
  for d in range(0, num_devices):
    with ops.device(devices[d]):
      splits, split_pad_len = _padded_split(input_tensors[d], num_chunks)
      chunks_by_dev.append(splits)
  # Reduction phase
  for tick in range(0, num_ticks):
    # One new partial reduction for every chunk
    new_partial_reductions = [None for _ in range(0, num_chunks)]
    # Compute reductions with respect to last tick's values
    for d in range(0, num_devices):
      with ops.device(devices[d]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (2 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          new_partial_reductions[chunk_index] = red_op(
              chunks_by_dev[pred_dev][chunk_index],
              chunks_by_dev[d][chunk_index])
    # Update chunks_by_dev with the new values at the end of the tick.
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (2 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = new_partial_reductions[chunk_index]
  return chunks_by_dev, split_pad_len


def _apply_unary_to_chunks(f, chunks_by_dev):
  """Apply a unary op to each tensor in chunks_by_dev, on same device.

  Args:
    f: a unary function over T @{tf.Tensor}.
    chunks_by_dev: list of lists of T @{tf.Tensor}.

  Returns:
    new list of lists of T @{tf.Tensor} with the same structure as
    chunks_by_dev containing the derived tensors.
  """
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append([f(t) for t in x])
  return output


def _build_ring_scatter(pred_by_s_d, rank_by_s_d,
                        chunks_by_dev):
  """Construct subgraph for second (scatter) pass of ring all-reduce.

  Args:
    pred_by_s_d: as produced by _ring_permutations
    rank_by_s_d: as produced by _ring_permutations
    chunks_by_dev: list of list of T @{tf.Tensor} indexed by ints
      (device, chunk)

  Raises:
    ValueError: chunks_by_dev is not well-formed

  Returns:
    list of T @{tf.Tensor} which are the fully reduced tensors, one
    at each device corresponding to the outer dimension of chunks_by_dev.
  """
  num_devices = len(chunks_by_dev)
  num_chunks = len(chunks_by_dev[0])
  if 0 != num_chunks % num_devices:
    raise ValueError(
        "Expect number of chunks per device to be divisible by num_devices")
  num_subchunks = int(num_chunks / num_devices)
  num_ticks = num_devices - 1
  for tick in range(0, num_ticks):
    passed_values = [None for _ in range(0, num_chunks)]
    for d in range(0, num_devices):
      with ops.colocate_with(chunks_by_dev[d][0]):
        for s in range(0, num_subchunks):
          rank = rank_by_s_d[s][d]
          seg_index = (rank + num_devices - (1 + tick)) % num_devices
          pred_dev = pred_by_s_d[s][d]
          chunk_index = (seg_index * num_subchunks) + s
          passed_values[chunk_index] = array_ops.identity(
              chunks_by_dev[pred_dev][chunk_index])
    for d in range(0, num_devices):
      for s in range(0, num_subchunks):
        rank = rank_by_s_d[s][d]
        seg_index = (rank + num_devices - (1 + tick)) % num_devices
        chunk_index = (seg_index * num_subchunks) + s
        chunks_by_dev[d][chunk_index] = passed_values[chunk_index]
  # Join chunks at each device.
  output = []
  for x in chunks_by_dev:
    with ops.colocate_with(x[0]):
      output.append(array_ops.concat(x, 0))
  return output


def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None):
  """Construct a subgraph for recursive halving-doubling all-reduce.

  The recursive halving-doubling algorithm is described in
  http://www.mcs.anl.gov/~thakur/papers/ijhpca-coll.pdf

  The concept is to arrange the participating n devices in
  a linear sequence where devices exchange data pairwise
  with one other device in each round.  During the gather
  phase there are lg(n) rounds where devices exchange
  increasingly smaller sub-tensors with another device
  at increasingly greater distances, until at the top
  each device has 1/n of the fully reduced values.  During the
  scatter phase each device exchanges its fully reduced
  sub-tensor (which doubles in length at each round)
  with one other device at increasingly smaller distances
  until each device has all of the fully reduced values.

  Note: this preliminary version requires that len(input_tensors) be a
    power of 2.  TODO(tucker): relax this restriction.  Also, the
    number of elements in each tensor must be divisible by 2^h where h
    is the number of hops in each phase.  This will also be relaxed in
    the future with edge-case specific logic.

  Args:
    input_tensors: list of T @{tf.Tensor} to be elementwise reduced.
    red_op: a binary elementwise reduction Op.
    un_op: an optional unary elementwise Op to apply to reduced values.

  Returns:
    list of T @{tf.Tensor} which are the fully reduced tensors, one
    at each device of input_tensors.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  devices = [t.device for t in input_tensors]
  input_tensors, shape = _flatten_tensors(input_tensors)
  reduced_shards = _build_recursive_hd_gather(input_tensors, devices, red_op)
  if un_op:
    reduced_shards = [un_op(t) for t in reduced_shards]
  output_tensors = _build_recursive_hd_scatter(reduced_shards, devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_recursive_hd_gather(input_tensors, devices, red_op):
  """Construct the gather phase of recursive halving-doubling all-reduce.

  Args:
    input_tensors: list of T @{tf.Tensor} to be elementwise reduced.
    devices: a list of strings naming the devices hosting input_tensors,
      which will also be used to host the (partial) reduction values.
    red_op: a binary elementwise reduction Op.

  Returns:
    list of T @{tf.Tensor} which are the fully reduced tensor shards.

  Raises:
    ValueError: num_devices not a power of 2, or tensor len not divisible
    by 2 the proper number of times.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  if num_devices != (2 ** num_hops):
    raise ValueError("num_devices must be a power of 2")
  chunks = input_tensors
  for h in range(0, num_hops):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_dev = devices[d]
      right_dev = devices[d + span]
      left_split = array_ops.split(chunks[d], 2)
      right_split = array_ops.split(chunks[d+span], 2)
      with ops.device(left_dev):
        new_chunks[d] = red_op(left_split[0], right_split[0])
      with ops.device(right_dev):
        new_chunks[d + span] = red_op(left_split[1], right_split[1])
    chunks = new_chunks
  return chunks


def _build_recursive_hd_scatter(input_tensors, devices):
  """Construct the scatter phase of recursive halving-doublng all-reduce.

  Args:
    input_tensors: list of T @{tf.Tensor} that are fully-reduced shards.
    devices: a list of strings naming the devices on which the reconstituted
      full tensors should be placed.

  Returns:
    list of T @{tf.Tensor} which are the fully reduced tensors.
  """
  num_devices = len(devices)
  num_hops = int(math.log(num_devices, 2))
  assert num_devices == (2 ** num_hops), "num_devices must be a power of 2"
  chunks = input_tensors
  for h in reversed(range(0, num_hops)):
    span = 2 ** h
    group_size = span * 2
    new_chunks = [[] for _ in devices]
    for d in range(0, num_devices):
      if (d % group_size) >= (group_size / 2):
        # skip right half of a pair
        continue
      left_idx = d
      right_idx = d + span
      left_dev = devices[left_idx]
      right_dev = devices[right_idx]
      with ops.device(left_dev):
        new_chunks[left_idx] = array_ops.concat([chunks[left_idx],
                                                 chunks[right_idx]], 0)
      with ops.device(right_dev):
        new_chunks[right_idx] = array_ops.concat([chunks[left_idx],
                                                  chunks[right_idx]], 0)
    chunks = new_chunks
  return chunks


def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None):
  """Construct a subgraph for shuffle all-reduce.

  Shuffle reduce is essentially the algorithm implemented when using
  parameter servers.  Suppose tensor length is n, there are d devices
  and g gather shards.  Each device sends a n/g length sub-tensor to
  each gather shard.  The gather shards perform a reduction across d
  fragments, then broadcast the result back to each device.  The
  devices then join the g fully reduced fragments they receive from
  the shards.  The gather shards could perform d-1 pairwise
  reductions, or one d-way reduction.  The first is better where
  reduction Op time is low compared to transmission time, the second
  better in the other case.

  Args:
    input_tensors: list of T @(tf.Tensor} values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: an n-array elementwise reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of T @{tf.Tensor} which are the fully reduced tensors.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  dst_devices = [t.device for t in input_tensors]
  reduced_shards = _build_shuffle_gather(input_tensors, gather_devices,
                                         red_op, un_op)
  output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices)
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op=None):
  """Construct the gather (concentrate and reduce) phase of shuffle all-reduce.

  Args:
    input_tensors: list of T @(tf.Tensor} values to be reduced.
    gather_devices: list of names of devices on which reduction shards
      should be placed.
    red_op: the binary reduction Op
    un_op: optional elementwise unary Op to be applied to fully-reduced values.

  Returns:
    list of T @{tf.Tensor} which are the fully reduced shards.

  Raises:
    ValueError: inputs not well-formed.
  """
  num_source_devices = len(input_tensors)
  num_gather_devices = len(gather_devices)
  shape = input_tensors[0].shape
  if len(shape) != 1:
    raise ValueError("input_tensors must be 1D")
  shards_by_source = []
  for d in range(0, num_source_devices):
    with ops.colocate_with(input_tensors[d]):
      shards_by_source.append(
          _ragged_split(input_tensors[d], num_gather_devices))
  reduced_shards = []
  for d in range(0, num_gather_devices):
    with ops.device(gather_devices[d]):
      values = [s[d] for s in shards_by_source]
      red_shard = red_op(values)
      if un_op:
        red_shard = un_op(red_shard)
      reduced_shards.append(red_shard)
  return reduced_shards


def _build_shuffle_scatter(reduced_shards, dst_devices):
  """Build the scatter phase of shuffle all-reduce.

  Args:
    reduced_shards:  list of T @(tf.Tensor} fully reduced shards
    dst_devices: list of names of devices at which the fully-reduced value
      should be reconstituted.

  Returns:
    list of T @{tf.Tensor} scattered tensors.
  """
  num_devices = len(dst_devices)
  out_tensors = []
  for d in range(0, num_devices):
    with ops.device(dst_devices[d]):
      out_tensors.append(array_ops.concat(reduced_shards, 0))
  return out_tensors


def _split_by_task(devices, values):
  """Partition devices and values by common task.

  Args:
    devices: list of device name strings
    values: list of T @{tf.tensor} of same length as devices.

  Returns:
    (per_task_devices, per_task_values) where both values are
    lists of lists with isomorphic structure: the outer list is
    indexed by task, and the inner list has length of the number
    of values belonging to that task.  per_task_devices contains
    the specific devices to which the values are local, and
    per_task_values contains the corresponding values.

  Raises:
    ValueError: devices must be same length as values.
  """
  num_devices = len(devices)
  if num_devices != len(values):
    raise ValueError("len(devices) must equal len(values)")
  pattern = re.compile(r"/task:(\d+)/")
  per_task_devices = []
  per_task_values = []
  for d in range(num_devices):
    m = pattern.search(devices[d])
    if m:
      index = int(m.group(1))
      while index >= len(per_task_devices):
        per_task_devices.append([])
        per_task_values.append([])
      per_task_devices[index].append(devices[d])
      per_task_values[index].append(values[d])
    else:
      assert False, "failed to parse device %s" % devices[d]
  return (per_task_devices, per_task_values)


def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
  """Build a subgraph that does one full all-reduce, using NCCL.

  Args:
    input_tensors: list of T @{tf.Tensor} of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.  Must be one of
      {tf.add}
    un_op: optional unary elementwise Op to apply to fully-reduce values.

  Returns:
    list of T @{tf.Tensor} of reduced values.

  Raises:
    ValueError: red_op not supported.
  """
  if red_op == math_ops.add:
    output_tensors = nccl.all_sum(input_tensors)
  else:
    raise ValueError("red_op not supported by NCCL all-reduce: ", red_op)
  if un_op:
    un_op_wrapped = []
    for t in output_tensors:
      with ops.colocate_with(t):
        un_op_wrapped.append(un_op(t))
    output_tensors = un_op_wrapped
  return output_tensors


def _build_nccl_hybrid(input_tensors, red_op, upper_level_f):
  """Construct a subgraph for NCCL hybrid all-reduce.

  Args:
    input_tensors: list of T @{tf.Tensor} of same-shape and type values to
      be reduced.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of T @{tf.Tensor} of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = [None for w in range(0, num_workers)]
  up_devices = up_values[:]
  down_values = up_values[:]
  # First stage: reduce within each worker using NCCL
  for w in range(0, num_workers):
    worker_values = build_nccl_all_reduce(per_worker_values[w], red_op)
    # NOTE: these reductions will not run to completion unless
    # every output value is used.  Since we only need one, we
    # need to put control dependencies on the rest.
    with ops.control_dependencies(worker_values):
      with ops.device(worker_values[0].device):
        up_values[w] = array_ops.identity(worker_values[0])
      up_devices[w] = per_worker_devices[w][0]
  # Second stage: Apply upper_level_f to reduce across first device at
  # each worker
  level_2_output = upper_level_f(up_values)
  # Third stage: propagate within each worker using NCCL Broadcast
  for w in range(0, num_workers):
    dst_tensors = []
    with ops.device(per_worker_devices[w][0]):
      broadcast_src = nccl.broadcast(array_ops.identity(level_2_output[w]))
    for d in per_worker_devices[w]:
      with ops.device(d):
        dst_tensors.append(array_ops.identity(broadcast_src))
    down_values[w] = dst_tensors
  output_tensors = [v for sublist in down_values for v in sublist]
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def _reduce_non_singleton(input_tensors, red_f, un_op):
  """If input_tenors has more than one element apply red_f, else apply un_op."""
  if len(input_tensors) > 1:
    return red_f(input_tensors)
  else:
    if not un_op:
      return input_tensors
    output_tensors = []
    for t in input_tensors:
      with ops.colocate_with(t):
        output_tensors.append(un_op(t))
    return output_tensors


def build_nccl_then_ring(input_tensors, subdiv, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Ring across workers."""
  def upper_builder(y):
    return build_ring_all_reduce(y, len(y), subdiv, [0], red_op, un_op)
  def upper_level_f(x):
    return _reduce_non_singleton(x, upper_builder, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_recursive_hd(input_tensors, red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Recursive-HD across workers."""
  upper_level_f = lambda x: build_recursive_hd_all_reduce(x, red_op, un_op)
  return _build_nccl_hybrid(input_tensors, red_op, upper_level_f)


def build_nccl_then_shuffle(input_tensors, gather_devices, nccl_red_op,
                            shuffle_red_op, un_op=None):
  """Construct hybrid of NCCL within workers, Shuffle across workers."""
  upper_level_f = lambda x: build_shuffle_all_reduce(x, gather_devices,
                                                     shuffle_red_op, un_op)
  return _build_nccl_hybrid(input_tensors, nccl_red_op, upper_level_f)


def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f):
  """Construct a subgraph for Shuffle hybrid all-reduce.

  Args:
    input_tensors: list of T @{tf.Tensor} of same-shape and type values to
      be reduced.
    gather_devices: list of device names on which to host gather shards.
    red_op: binary elementwise reduction operator.
    upper_level_f: function for reducing one value per worker, across
      workers.

  Returns:
    list of T @{tf.Tensor} of reduced values.

  Raises:
    ValueError: inputs not well-formed.
  """
  input_tensors, shape = _flatten_tensors(input_tensors)
  # First stage, reduce across each worker using gather_devices.
  devices = [t.device for t in input_tensors]
  per_worker_devices, per_worker_values = _split_by_task(devices, input_tensors)
  num_workers = len(per_worker_devices)
  up_values = []
  if len(gather_devices) != num_workers:
    raise ValueError("For shuffle hybrid, gather_devices must contain one "
                     "device per worker. ")
  for w in range(0, num_workers):
    reduced_shards = _build_shuffle_gather(
        per_worker_values[w], [gather_devices[w]], red_op)
    up_values.append(reduced_shards[0])
  # Second stage, apply upper_level_f.
  level_2_output = upper_level_f(up_values)
  # Third stage, apply shuffle scatter at each worker.
  output_tensors = []
  for w in range(0, num_workers):
    output_tensors += _build_shuffle_scatter(
        [level_2_output[w]], per_worker_devices[w])
  if len(shape) != 1:
    output_tensors = _reshape_tensors(output_tensors, shape)
  return output_tensors


def build_shuffle_then_ring(input_tensors, gather_devices, subdiv,
                            red_n_op, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Ring across workers."""
  def upper_builder(tensors):
    return build_ring_all_reduce(tensors, len(tensors), subdiv, [0],
                                 red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, gather_devices, red_n_op, upper_level_f)


def build_shuffle_then_shuffle(input_tensors, first_gather_devices,
                               second_gather_devices, red_op, un_op=None):
  """Construct hybrid of Shuffle within workers, Shuffle across workers."""
  def upper_builder(tensors):
    return build_shuffle_all_reduce(tensors, second_gather_devices,
                                    red_op, un_op)
  def upper_level_f(tensors):
    return _reduce_non_singleton(tensors, upper_builder, un_op)
  return _build_shuffle_hybrid(
      input_tensors, first_gather_devices, red_op, upper_level_f)