diff options
author | Youlong Cheng <ylc@google.com> | 2018-09-24 19:47:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 19:52:14 -0700 |
commit | e9cdf9f412a3aea324a4a1655d3bffb87abaff0d (patch) | |
tree | 24559e0a02d601cf8c2916b1d7cf4c8ca51918e0 /tensorflow/contrib/tpu | |
parent | bb1c131aad55e336d25fd297ecd8582773d6476f (diff) |
[TF:XLA] Introduce CollectivePermute op.
PiperOrigin-RevId: 214373714
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r-- | tensorflow/contrib/tpu/ops/cross_replica_ops.cc | 20 | ||||
-rw-r--r-- | tensorflow/contrib/tpu/python/ops/tpu_ops.py | 27 |
2 files changed, 46 insertions, 1 deletions
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index ea8e0e00ed..87e3a5946c 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -125,4 +125,24 @@ output: The sum of all the distributed inputs. T: The type of elements to be summed. )doc"); +REGISTER_OP("CollectivePermute") + .Input("input: T") + .Input("source_target_pairs: int32") + .Output("output: T") + .Attr("T: numbertype") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +An Op to permute tensors across replicated TPU instances. Each instance +supplies its own input. + +For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing +source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: +`[D, A, B, C]`. + +input: The local input to be permuted. Currently only supports float and + bfloat16. +source_target_pairs: A tensor with shape [num_pairs, 2]. +output: The permuted input. +T: The type of elements to be exchanged. +)doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index d92a0652bb..a1aee69691 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -95,7 +95,7 @@ if platform.system() != "Windows": ] def cross_replica_sum(x, group_assignment=None, name=None): - """Sum the input tensor accorss replicas according to group_assignment. + """Sum the input tensor across replicas according to group_assignment. Args: x: The local tensor to the sum. @@ -112,6 +112,31 @@ if platform.system() != "Windows": return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name) + def collective_permute(x, source_target_pairs, name=None): + """Permute the input tensor across replicas given source_target_pairs. + + For each source_target_pair <a, b>, we send replica a's input to replica b. + Each replica id must only appear once in the source column. Also it must + only appear once in the target column. + For the replica id not in the target column, this op returns a zero tensor + with the same shape and dtype of the input x. + + For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing + source_target_pairs=`[[0,1],[1,2],[2,3]]` gets the outputs: + `[0, A, B, C]`. + + Args: + x: The local tensor to be permuted. + source_target_pairs: 2d int lists with shape [num_pairs, 2]. + source_target_pairs[i][0] represents the source replica id and + source_target_pairs[i][1] represents the target replica id. + name: Optional op name. + + Returns: + A `Tensor` which is permuted. + """ + return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name) + @ops.RegisterGradient("CrossReplicaSum") def _cross_replica_sum_grad(op, grad): # The gradient of a cross replica sum is also a cross-replica sum. |