diff options
author | 2018-03-07 03:44:48 -0800 | |
---|---|---|
committer | 2018-03-07 03:48:18 -0800 | |
commit | 4f0aa15e9635c33ca37f3aa714b10f4ca3199e7f (patch) | |
tree | ca4a4ebe930eff77fa011cfb29381676f8d90f01 /tensorflow/compiler/xla/service/shape_inference.cc | |
parent | c0824a4eeaffa7e30119fef21a5b689c972e6657 (diff) |
Fix ShapeUtil::CompatibleIgnoringElementType for scalar vs tuple comparision
Previously if the lhs was a scalar and the rhs was a tuple of arbitrary
shape it reported them as compatible what is clearly wrong.
PiperOrigin-RevId: 188155575
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index c54cb3b48d..915baecc56 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2394,7 +2394,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "Select's pred operand must have PRED element type; got %s.", ShapeUtil::HumanString(pred).c_str()); } - if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) { + if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || + ShapeUtil::Rank(pred) == 0) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. |