aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc74
1 files changed, 55 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 096bbde922..35df792b07 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -69,11 +69,11 @@ Status VerifyReducerShape(const ProgramShape& reducer_shape,
}
const Shape& accumulator_shape = reducer_shape.result();
- if (ShapeUtil::Rank(accumulator_shape) != 0) {
+ if (!ShapeUtil::IsArray(accumulator_shape) ||
+ ShapeUtil::Rank(accumulator_shape) != 0) {
return InvalidArgument(
- "Reduction function must have rank 0 (rank %lld reduction function "
- "given).",
- ShapeUtil::Rank(accumulator_shape));
+ "Reduction function must produce a scalar but has shape: %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str());
}
// Check that the accumulator can be passed in as the first argument.
@@ -222,13 +222,16 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
case HloOpcode::kReal:
case HloOpcode::kImag:
- if (!ShapeUtil::ElementIsComplex(shape)) {
+ if (ShapeUtil::ElementIsComplex(shape)) {
+ return ShapeUtil::ComplexComponentShape(shape);
+ } else if (ShapeUtil::ElementIsFloating(shape)) {
+ return shape;
+ } else {
return InvalidArgument(
- "Expected element type in shape to be complex for real/imag "
- "operation; got %s.",
+ "Expected element type in shape to be floating or complex for "
+ "real/imag operation; got %s.",
PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(shape, F32);
case HloOpcode::kAbs:
if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
@@ -239,7 +242,6 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
case HloOpcode::kNegate:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kSign:
- case HloOpcode::kSort:
return shape;
case HloOpcode::kNot:
@@ -930,6 +932,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InferClampShape(lhs, rhs, ehs);
case HloOpcode::kSelect:
return InferSelectShape(lhs, rhs, ehs);
+ case HloOpcode::kTupleSelect:
+ return InferTupleSelectShape(lhs, rhs, ehs);
default:
return InvalidArgument("Unknown operation %s.",
HloOpcodeString(opcode).c_str());
@@ -962,6 +966,23 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
return result;
}
+ case HloOpcode::kSort: {
+ if (operand_shapes.size() == 1) {
+ return *operand_shapes[0];
+ } else if (operand_shapes.size() == 2) {
+ if (!ShapeUtil::SameDimensions(*operand_shapes[0],
+ *operand_shapes[1])) {
+ return InvalidArgument(
+ "Sort keys and values dimensions must match. "
+ "Keys shape is: %s\n, Values shape is: %s",
+ ShapeUtil::HumanString(*operand_shapes[0]).c_str(),
+ ShapeUtil::HumanString(*operand_shapes[1]).c_str());
+ }
+ return ShapeUtil::MakeTupleShape(
+ {*operand_shapes[0], *operand_shapes[1]});
+ }
+ return InvalidArgument("Unexpected number of operands for sort");
+ }
default:
return InvalidArgument("Unknown operation %s.",
HloOpcodeString(opcode).c_str());
@@ -2259,15 +2280,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// broadcast from all operands, not just the predicate.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
const Shape& pred, const Shape& on_true, const Shape& on_false) {
- bool compatible;
- if (ShapeUtil::IsTuple(on_true)) {
- // Select only defines the top-level buffer, so if it's a tuple, the two
- // input must match exactly.
- compatible = ShapeUtil::Compatible(on_true, on_false);
- } else {
- compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
- }
- if (!compatible) {
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false)) {
return InvalidArgument(
"Operands to select must be the same shape; got %s and %s.",
ShapeUtil::HumanString(on_true).c_str(),
@@ -2279,7 +2292,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(pred).c_str());
}
if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
- ShapeUtil::Rank(pred) == 0) {
+ ShapeUtil::IsScalar(pred)) {
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
@@ -2293,6 +2306,29 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}
}
+/* static */ StatusOr<Shape> ShapeInference::InferTupleSelectShape(
+ const Shape& pred, const Shape& on_true, const Shape& on_false) {
+ // Select only defines the top-level buffer, so if it's a tuple, the two
+ // input must match exactly.
+ if (!ShapeUtil::Compatible(on_true, on_false)) {
+ return InvalidArgument(
+ "Operands to tuple-select must be the same shape; got %s and %s.",
+ ShapeUtil::HumanString(on_true).c_str(),
+ ShapeUtil::HumanString(on_false).c_str());
+ }
+ if (pred.element_type() != PRED) {
+ return InvalidArgument(
+ "TupleSelect's pred operand must have PRED element type; got %s.",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ if (!ShapeUtil::IsScalar(pred)) {
+ return InvalidArgument(
+ "TupleSelect operation with non-scalar predicate: %s.",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ return on_true;
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
const ProgramShape& to_apply) {