aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-06-29 10:57:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 10:59:52 -0700
commitec0a702ff1f22b73cec2d8f14c7a84c5a02856fd (patch)
treebbe4a56f06e157fb69628fd40e4ab66f0ece27ea /tensorflow/compiler/xla/service/shape_inference.cc
parentaa060bebb0fb064460ae4c3e92a0272be4ea04de (diff)
[XLA] Add key-value version of Sort HLO.
This is only currently implemented in the evaluator backend, and even that implementation is partial - the key and value type must match. PiperOrigin-RevId: 202673122
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 096bbde922..d05e995a95 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -239,7 +239,6 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
case HloOpcode::kNegate:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSign:
- case HloOpcode::kSort:
return shape;
case HloOpcode::kNot:
@@ -962,6 +961,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
return result;
}
+ case HloOpcode::kSort: {
+ if (operand_shapes.size() == 1) {
+ return *operand_shapes[0];
+ } else if (operand_shapes.size() == 2) {
+ return ShapeUtil::MakeTupleShape(
+ {*operand_shapes[0], *operand_shapes[1]});
+ }
+ return InvalidArgument("Unexpected number of operands for sort");
+ }
default:
return InvalidArgument("Unknown operation %s.",
HloOpcodeString(opcode).c_str());