diff options
Diffstat (limited to 'tensorflow/contrib/tpu/ops/cross_replica_ops.cc')
-rw-r--r-- | tensorflow/contrib/tpu/ops/cross_replica_ops.cc | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc index 06553929dc..9ee5ecb123 100644 --- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -21,9 +21,9 @@ namespace tensorflow { REGISTER_OP("CrossReplicaSum") .Input("input: T") + .Input("group_assignment: int32") .Output("output: T") .Attr("T: {bfloat16, float}") - .Attr("group_assignment: list(int) = []") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( An Op to sum inputs across replicated TPU instances. Each @@ -31,15 +31,17 @@ instance supplies its own input. If group_assignment is empty, the output of each is the sum of all the inputs, otherwise the output of each is the sum of the inputs belonging to the same group. -For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing -group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1. -Thus we get the outputs: `[A+C, B+D, A+C, B+D]`. +For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`. +Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0, +and `B, D, F, H` as group 1. Thus we get the outputs: +`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`. input: The local input to the sum. +group_assignment: An int32 tensor with shape + [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the + replica ids in the ith subgroup. output: The sum of all the distributed inputs. T: The type of elements to be summed. -group_assignment: The list of group ids. `group_assignment[i]` represents the - group id of replica i. )doc"); } // namespace tensorflow |