diff options
author | 2018-04-10 20:48:57 -0700 | |
---|---|---|
committer | 2018-04-10 20:51:11 -0700 | |
commit | 0c0428e41289392be095bb07f5daa1a0c4557c8c (patch) | |
tree | a7c7468d3f2f4c0e17000faf77f010a3e5f0d840 | |
parent | f3180f3827ef1340f51408385f139143da55f07f (diff) |
[XLA] Redesign: implment and test CrossReplicaSum.
PiperOrigin-RevId: 192397189
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 40bafdb5c1..3b96bc72be 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1559,7 +1559,17 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, } XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand)); + TF_ASSIGN_OR_RETURN( + *instr.mutable_shape(), + ShapeInference::InferCrossReplicaSumShape({&operand_shape})); + + return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum, + {operand}); + }); } XlaOp XlaBuilder::SelectAndScatter( |