aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-10 20:48:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 20:51:11 -0700
commit0c0428e41289392be095bb07f5daa1a0c4557c8c (patch)
treea7c7468d3f2f4c0e17000faf77f010a3e5f0d840
parentf3180f3827ef1340f51408385f139143da55f07f (diff)
[XLA] Redesign: implment and test CrossReplicaSum.
PiperOrigin-RevId: 192397189
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 40bafdb5c1..3b96bc72be 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1559,7 +1559,17 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
}
XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
+
+ return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
+ {operand});
+ });
}
XlaOp XlaBuilder::SelectAndScatter(