aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Keno Fischer <keno@juliacomputing.com>2018-09-16 18:39:50 -0400
committerGravatar Keno Fischer <keno@juliacomputing.com>2018-10-02 17:00:29 -0400
commita12b8c4afdca3ac2945d62b3b83ca2599ab360f9 (patch)
tree22e0daaa50d57108f202d9fcc2a0742595e7b4d2 /tensorflow/compiler
parenta6ee64cd216b3ac440262e1f4ec7872fe7026df6 (diff)
[xla] Improve validation of Broadcast shape
If one misreads the semantics of this instruction, it's easy to cause an out of bounds access into the dimensions here. Add an extra check to return a proper error to the user rather than crashing in that case. Ref #22130
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 50f39cbcb5..0f6ecd42f6 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -313,8 +313,9 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
operand_dimension < ShapeUtil::Rank(operand_shape);
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
- TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension))
+ TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) &&
+ (broadcast->shape().dimensions(output_dimension) ==
+ operand_shape.dimensions(operand_dimension)))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return Status::OK();