author    Youlong Cheng <ylc@google.com>  2018-09-24 19:47:26 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-24 19:52:14 -0700
commit    e9cdf9f412a3aea324a4a1655d3bffb87abaff0d (patch)
tree      24559e0a02d601cf8c2916b1d7cf4c8ca51918e0 /tensorflow/contrib/tpu
parent    bb1c131aad55e336d25fd297ecd8582773d6476f (diff)
[TF:XLA] Introduce CollectivePermute op.
PiperOrigin-RevId: 214373714
Diffstat (limited to 'tensorflow/contrib/tpu')
-rw-r--r--  tensorflow/contrib/tpu/ops/cross_replica_ops.cc  20
-rw-r--r--  tensorflow/contrib/tpu/python/ops/tpu_ops.py     27
2 files changed, 46 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index ea8e0e00ed..87e3a5946c 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -125,4 +125,24 @@ output: The sum of all the distributed inputs.
T: The type of elements to be summed.
)doc");
+REGISTER_OP("CollectivePermute")
+ .Input("input: T")
+ .Input("source_target_pairs: int32")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+An Op to permute tensors across replicated TPU instances. Each instance
+supplies its own input.
+
+For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` yields the outputs:
+`[D, A, B, C]`.
+
+input: The local input to be permuted. Currently only supports float and
+  bfloat16.
+source_target_pairs: A tensor with shape [num_pairs, 2].
+output: The permuted input.
+T: The type of elements to be exchanged.
+)doc");
} // namespace tensorflow
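To make the documented semantics concrete, here is a pure-Python sketch (illustrative only, not part of the commit) that simulates what CollectivePermute computes for a given source_target_pairs list, including the zero-fill for replicas that receive no data (documented in the Python wrapper below):

import numpy as np

def simulate_collective_permute(inputs, source_target_pairs):
  """Simulates the permutation: output[b] = input[a] for each pair [a, b]."""
  # Replicas that do not appear as a target receive a zero tensor with the
  # same shape and dtype as their input, matching the documented behavior.
  outputs = [np.zeros_like(x) for x in inputs]
  for source, target in source_target_pairs:
    outputs[target] = inputs[source]
  return outputs

# The 4-replica ring from the doc comment: [A, B, C, D] -> [D, A, B, C].
a, b, c, d = (np.array(float(i)) for i in range(4))
print(simulate_collective_permute([a, b, c, d],
                                  [[0, 1], [1, 2], [2, 3], [3, 0]]))
# -> [array(3.), array(0.), array(1.), array(2.)], i.e. [D, A, B, C]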
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index d92a0652bb..a1aee69691 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -95,7 +95,7 @@ if platform.system() != "Windows":
]
def cross_replica_sum(x, group_assignment=None, name=None):
- """Sum the input tensor accorss replicas according to group_assignment.
+ """Sum the input tensor across replicas according to group_assignment.
Args:
x: The local tensor to the sum.
@@ -112,6 +112,31 @@ if platform.system() != "Windows":
return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+ def collective_permute(x, source_target_pairs, name=None):
+ """Permute the input tensor across replicas given source_target_pairs.
+
+ For each source_target_pair <a, b>, this op sends replica a's input to
+ replica b. Each replica id may appear at most once in the source column
+ and at most once in the target column. On any replica whose id does not
+ appear in the target column, the op returns a zero tensor with the same
+ shape and dtype as the input x.
+
+ For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
+ source_target_pairs=`[[0,1],[1,2],[2,3]]` yields the outputs:
+ `[0, A, B, C]`.
+
+ Args:
+ x: The local tensor to be permuted.
+ source_target_pairs: A 2-D list of ints with shape [num_pairs, 2], where
+ source_target_pairs[i][0] is the source replica id and
+ source_target_pairs[i][1] is the target replica id.
+ name: Optional op name.
+
+ Returns:
+ The permuted `Tensor`.
+ """
+ return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
+
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
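For reference, a minimal usage sketch of the new Python wrapper (illustrative, not part of the commit diff). It assumes a TF 1.x TPU runtime; TPU system initialization and session setup are omitted, and names such as ring_shift and NUM_REPLICAS are invented for the example:

import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.contrib.tpu.python.ops import tpu_ops

NUM_REPLICAS = 4

def ring_shift(x):
  # Send each replica's tensor to the next replica in a ring:
  # 0 -> 1, 1 -> 2, 2 -> 3, 3 -> 0 (the example from the docstring).
  pairs = [[i, (i + 1) % NUM_REPLICAS] for i in range(NUM_REPLICAS)]
  return tpu_ops.collective_permute(x, pairs)

# One input list per replica; replica i contributes the scalar float(i).
per_replica_inputs = [[tf.constant(float(i))] for i in range(NUM_REPLICAS)]
outputs = tpu.replicate(ring_shift, per_replica_inputs)
# After the permute, replica i holds the value sent by replica (i - 1) % 4.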