diff options
author | 2018-06-29 10:57:24 -0700 | |
---|---|---|
committer | 2018-06-29 10:59:52 -0700 | |
commit | ec0a702ff1f22b73cec2d8f14c7a84c5a02856fd (patch) | |
tree | bbe4a56f06e157fb69628fd40e4ab66f0ece27ea /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | aa060bebb0fb064460ae4c3e92a0272be4ea04de (diff) |
[XLA] Add key-value version of Sort HLO.
This is only currently implemented in the evaluator backend, and even that implementation is partial - the key and value type must match.
PiperOrigin-RevId: 202673122
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 096bbde922..d05e995a95 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -239,7 +239,6 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kNegate: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSign: - case HloOpcode::kSort: return shape; case HloOpcode::kNot: @@ -962,6 +961,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } return result; } + case HloOpcode::kSort: { + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } else if (operand_shapes.size() == 2) { + return ShapeUtil::MakeTupleShape( + {*operand_shapes[0], *operand_shapes[1]}); + } + return InvalidArgument("Unexpected number of operands for sort"); + } default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); |