diff options
author | Michael Kuperstein <mkuper@google.com> | 2018-07-17 14:24:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-17 14:28:24 -0700 |
commit | 8c5d2127182e0fadc0dcd6e97cb4acfba3a4c343 (patch) | |
tree | dfcdb842e4871717c64f15b05e749acbb8a64ac7 /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | 2f93ac4891f81137ce5fc40a8bbb2714b6cf2151 (diff) |
[XLA] Shape inference should verify the shapes of sort keys and sort values match.
PiperOrigin-RevId: 204974328
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 214146cf68..35df792b07 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -970,6 +970,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, if (operand_shapes.size() == 1) { return *operand_shapes[0]; } else if (operand_shapes.size() == 2) { + if (!ShapeUtil::SameDimensions(*operand_shapes[0], + *operand_shapes[1])) { + return InvalidArgument( + "Sort keys and values dimensions must match. " + "Keys shape is: %s\n, Values shape is: %s", + ShapeUtil::HumanString(*operand_shapes[0]).c_str(), + ShapeUtil::HumanString(*operand_shapes[1]).c_str()); + } return ShapeUtil::MakeTupleShape( {*operand_shapes[0], *operand_shapes[1]}); } |