aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-07-17 14:24:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 14:28:24 -0700
commit8c5d2127182e0fadc0dcd6e97cb4acfba3a4c343 (patch)
treedfcdb842e4871717c64f15b05e749acbb8a64ac7 /tensorflow/compiler/xla/service/shape_inference.cc
parent2f93ac4891f81137ce5fc40a8bbb2714b6cf2151 (diff)
[XLA] Shape inference should verify the shapes of sort keys and sort values match.
PiperOrigin-RevId: 204974328
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 214146cf68..35df792b07 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -970,6 +970,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
} else if (operand_shapes.size() == 2) {
+ if (!ShapeUtil::SameDimensions(*operand_shapes[0],
+ *operand_shapes[1])) {
+ return InvalidArgument(
+ "Sort keys and values dimensions must match. "
+ "Keys shape is: %s\n, Values shape is: %s",
+ ShapeUtil::HumanString(*operand_shapes[0]).c_str(),
+ ShapeUtil::HumanString(*operand_shapes[1]).c_str());
+ }
return ShapeUtil::MakeTupleShape(
{*operand_shapes[0], *operand_shapes[1]});
}