about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tpu/ops/cross_replica_ops.cc')
-rw-r--r-- tensorflow/contrib/tpu/ops/cross_replica_ops.cc 14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index 06553929dc..9ee5ecb123 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -21,9 +21,9 @@ namespace tensorflow {
REGISTER_OP("CrossReplicaSum")
.Input("input: T")
+ .Input("group_assignment: int32")
.Output("output: T")
.Attr("T: {bfloat16, float}")
- .Attr("group_assignment: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An Op to sum inputs across replicated TPU instances. Each
@@ -31,15 +31,17 @@ instance supplies its own input. If group_assignment is empty, the output of
each is the sum of all the inputs, otherwise the output of each is the sum of
the inputs belonging to the same group.
-For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
-group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1.
-Thus we get the outputs: `[A+C, B+D, A+C, B+D]`.
+For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`.
+Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
+and `B, D, F, H` as group 1. Thus we get the outputs:
+`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
input: The local input to the sum.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
-group_assignment: The list of group ids. `group_assignment[i]` represents the
- group id of replica i.
)doc");
} // namespace tensorflow