diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 37 |
1 files changed, 27 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 81f071ecc5..70edf7883f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -929,6 +929,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return InferClampShape(lhs, rhs, ehs); case HloOpcode::kSelect: return InferSelectShape(lhs, rhs, ehs); + case HloOpcode::kTupleSelect: + return InferTupleSelectShape(lhs, rhs, ehs); default: return InvalidArgument("Unknown operation %s.", HloOpcodeString(opcode).c_str()); @@ -2267,15 +2269,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, // broadcast from all operands, not just the predicate. /* static */ StatusOr<Shape> ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - bool compatible; - if (ShapeUtil::IsTuple(on_true)) { - // Select only defines the top-level buffer, so if it's a tuple, the two - // input must match exactly. - compatible = ShapeUtil::Compatible(on_true, on_false); - } else { - compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); - } - if (!compatible) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) { return InvalidArgument( "Operands to select must be the same shape; got %s and %s.", ShapeUtil::HumanString(on_true).c_str(), @@ -2287,7 +2281,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, ShapeUtil::HumanString(pred).c_str()); } if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) || - ShapeUtil::Rank(pred) == 0) { + ShapeUtil::IsScalar(pred)) { // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. @@ -2301,6 +2295,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } } +/* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape( + const Shape& pred, const Shape& on_true, const Shape& on_false) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + if (!ShapeUtil::Compatible(on_true, on_false)) { + return InvalidArgument( + "Operands to tuple-select must be the same shape; got %s and %s.", + ShapeUtil::HumanString(on_true).c_str(), + ShapeUtil::HumanString(on_false).c_str()); + } + if (pred.element_type() != PRED) { + return InvalidArgument( + "TupleSelect's pred operand must have PRED element type; got %s.", + ShapeUtil::HumanString(pred).c_str()); + } + if (!ShapeUtil::IsScalar(pred)) { + return InvalidArgument( + "TupleSelect operation with non-scalar predicate: %s.", + ShapeUtil::HumanString(pred).c_str()); + } + return on_true; +} + /* static */ StatusOr<Shape> ShapeInference::InferCallShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, const ProgramShape& to_apply) { |