path: root/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
diff options
Diffstat (limited to 'tensorflow/contrib/tpu/ops/cross_replica_ops.cc')
1 files changed, 85 insertions, 4 deletions
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index 9ee5ecb123..ea8e0e00ed 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -18,6 +18,89 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+ .Input("input: T")
+ .Input("group_assignment: int32")
+ .Output("output: T")
+ .Attr("T: {bfloat16, float}")
+ .Attr("concat_dimension: int")
+ .Attr("split_dimension: int")
+ .Attr("split_count: int")
+ .SetShapeFn([](InferenceContext* c) {
+ // Shape function for the all-to-all exchange.
+ //
+ // The output has the same rank as the input. Along `concat_dimension`
+ // the size is multiplied by `split_count` (one block is gathered from
+ // every replica in the group); along `split_dimension` the size is divided
+ // by `split_count` (the input is cut into that many blocks). All other
+ // dimensions are passed through unchanged.
+ //
+ // Returns InvalidArgument when the input rank is unknown, when either
+ // dimension attribute is out of range, when `split_count` is not
+ // positive, or when a known split dimension is not evenly divisible by
+ // `split_count`.
+ ShapeHandle input = c->input(0);
+ if (!c->RankKnown(input)) {
+ return errors::InvalidArgument("input's rank is unknown.");
+ }
+ const int64 rank = c->Rank(input);
+ int concat_dimension;
+ int split_dimension;
+ int split_count;
+ TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
+ if (concat_dimension < 0 || concat_dimension >= rank) {
+ return errors::InvalidArgument("concat_dimension ", concat_dimension,
+ " is out of range of input rank ", rank);
+ }
+ TF_RETURN_IF_ERROR(c->GetAttr("split_dimension", &split_dimension));
+ if (split_dimension < 0 || split_dimension >= rank) {
+ return errors::InvalidArgument("split_dimension ", split_dimension,
+ " is out of range of input rank ", rank);
+ }
+ TF_RETURN_IF_ERROR(c->GetAttr("split_count", &split_count));
+ // Reject non-positive counts up front so the division below is well
+ // defined.
+ if (split_count < 1) {
+ return errors::InvalidArgument("split_count ", split_count,
+ " must be at least one.");
+ }
+ // Widen once so the value converts cleanly to the dimension-or-constant
+ // arguments of Multiply/Divide.
+ const int64 split_factor = split_count;
+ std::vector<DimensionHandle> dims;
+ dims.resize(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ dims[i] = c->Dim(input, i);
+ if (i == concat_dimension) {
+ // Blocks from all `split_count` replicas are stacked along this axis.
+ // Multiply propagates an unknown dimension as unknown.
+ TF_RETURN_IF_ERROR(c->Multiply(dims[i], split_factor, &dims[i]));
+ }
+ if (i == split_dimension) {
+ // The local input is cut into `split_count` equal blocks along this
+ // axis; a known size that does not divide evenly is an error.
+ TF_RETURN_IF_ERROR(c->Divide(dims[i], split_factor,
+ /*evenly_divisible=*/true, &dims[i]));
+ }
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+An Op to exchange data across TPU replicas. On each replica, the input is
+split into `split_count` blocks along `split_dimension` and sent to the other
+replicas given group_assignment. After receiving `split_count` - 1 blocks from
+other replicas, we concatenate the blocks along `concat_dimension` as the
+output.
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+group_assignment=`[[0, 1]]`
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+input: The local input to the exchange.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
+concat_dimension: The dimension number to concatenate.
+split_dimension: The dimension number to split.
+split_count: The number of splits; this number must be equal to the sub-group
+ size (group_assignment.get_shape()[1]).
+output: The exchanged result.
+T: The type of elements to be exchanged.
.Input("input: T")
@@ -26,10 +109,8 @@ REGISTER_OP("CrossReplicaSum")
.Attr("T: {bfloat16, float}")
-An Op to sum inputs across replicated TPU instances. Each
-instance supplies its own input. If group_assignment is empty, the output of
-each is the sum of all the inputs, otherwise the output of each is the sum of
-the inputs belonging to the same group.
+An Op to sum inputs across replicated TPU instances. Each instance supplies its
+own input.
For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`.
Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,