aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-08 17:39:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 17:42:02 -0700
commit80459fe0fdcb86b286311559c65a7ec43525e278 (patch)
treeccdb1d36d150c4b8a261c9c48d2a0b2ef7fac637 /tensorflow/compiler/xla/service/shape_inference.cc
parentf81f62a0d35ccf7c4e83e09510447d93933ef87e (diff)
Cleanup shape_inference.
PiperOrigin-RevId: 199876297
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc346
1 files changed, 99 insertions, 247 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index fdc7f41759..bd98e86b08 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -44,129 +44,6 @@ namespace xla {
namespace {
-// Return the UnaryOperation proto enum value associated with the given HLO
-// opcode.
-UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAbs:
- return UNOP_ABS;
- case HloOpcode::kCeil:
- return UNOP_CEIL;
- case HloOpcode::kClz:
- return UNOP_CLZ;
- case HloOpcode::kCos:
- return UNOP_COS;
- case HloOpcode::kExp:
- return UNOP_EXP;
- case HloOpcode::kExpm1:
- return UNOP_EXPM1;
- case HloOpcode::kFloor:
- return UNOP_FLOOR;
- case HloOpcode::kImag:
- return UNOP_IMAG;
- case HloOpcode::kIsFinite:
- return UNOP_IS_FINITE;
- case HloOpcode::kLog:
- return UNOP_LOG;
- case HloOpcode::kLog1p:
- return UNOP_LOG1P;
- case HloOpcode::kNot:
- return UNOP_NOT;
- case HloOpcode::kNegate:
- return UNOP_NEGATE;
- case HloOpcode::kReal:
- return UNOP_REAL;
- case HloOpcode::kRoundNearestAfz:
- return UNOP_ROUND_NEAREST_AFZ;
- case HloOpcode::kSign:
- return UNOP_SIGN;
- case HloOpcode::kSin:
- return UNOP_SIN;
- case HloOpcode::kSort:
- return UNOP_SORT;
- case HloOpcode::kTanh:
- return UNOP_TANH;
- default:
- LOG(FATAL) << "Unhandled opcode for conversion to unary operation: "
- << opcode;
- }
-}
-
-// Return the BinaryOperation proto enum value associated with the given HLO
-// opcode.
-BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kAtan2:
- return BINOP_ATAN2;
- case HloOpcode::kComplex:
- return BINOP_COMPLEX;
- case HloOpcode::kMultiply:
- return BINOP_MUL;
- case HloOpcode::kAdd:
- return BINOP_ADD;
- case HloOpcode::kSubtract:
- return BINOP_SUB;
- case HloOpcode::kDivide:
- return BINOP_DIV;
- case HloOpcode::kEq:
- return BINOP_EQ;
- case HloOpcode::kGe:
- return BINOP_GE;
- case HloOpcode::kGt:
- return BINOP_GT;
- case HloOpcode::kLe:
- return BINOP_LE;
- case HloOpcode::kLt:
- return BINOP_LT;
- case HloOpcode::kNe:
- return BINOP_NE;
- case HloOpcode::kMaximum:
- return BINOP_MAX;
- case HloOpcode::kMinimum:
- return BINOP_MIN;
- case HloOpcode::kPower:
- return BINOP_POW;
- case HloOpcode::kRemainder:
- return BINOP_REM;
- case HloOpcode::kOr:
- return BINOP_OR;
- case HloOpcode::kAnd:
- return BINOP_AND;
- case HloOpcode::kShiftLeft:
- return BINOP_SHIFT_LEFT;
- case HloOpcode::kShiftRightArithmetic:
- return BINOP_SHIFT_RIGHT_ARITHMETIC;
- case HloOpcode::kShiftRightLogical:
- return BINOP_SHIFT_RIGHT_LOGICAL;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
-// Return the TernaryOperation proto enum value associated with the given HLO
-// opcode.
-TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kClamp:
- return TRIOP_CLAMP;
- case HloOpcode::kSelect:
- return TRIOP_SELECT;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
-// Return the VariadicOperation proto enum value associated with the given HLO
-// opcode.
-VariadicOperation OpcodeToVariadicOperation(HloOpcode opcode) {
- switch (opcode) {
- case HloOpcode::kTuple:
- return VAROP_TUPLE;
- default:
- LOG(FATAL) << "unhandled opcode " << opcode;
- }
-}
-
// Returns true if no element is present in slice more than once.
bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
@@ -321,84 +198,81 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
}
- return InferUnaryOpShape(OpcodeToUnaryOperation(opcode), shape);
-}
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(shape, "operand of unary operation"));
-/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
- UnaryOperation operation, const Shape& arg) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
-
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(arg));
- switch (operation) {
- case UNOP_FLOOR:
- case UNOP_CEIL:
- if (!ShapeUtil::ElementIsFloating(arg)) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+ switch (opcode) {
+ case HloOpcode::kFloor:
+ case HloOpcode::kCeil:
+ if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
"Expected element type in shape to be floating for floor/ceil "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
- case UNOP_COS:
- case UNOP_SIN:
- case UNOP_EXP:
- case UNOP_EXPM1:
- case UNOP_LOG:
- case UNOP_LOG1P:
- case UNOP_TANH:
- if (!ShapeUtil::ElementIsFloating(arg) &&
- !ShapeUtil::ElementIsComplex(arg)) {
+ return shape;
+ case HloOpcode::kCos:
+ case HloOpcode::kSin:
+ case HloOpcode::kExp:
+ case HloOpcode::kExpm1:
+ case HloOpcode::kLog:
+ case HloOpcode::kLog1p:
+ case HloOpcode::kTanh:
+ if (!ShapeUtil::ElementIsFloating(shape) &&
+ !ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
"sin/cos/exp/log/tanh operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
- case UNOP_REAL:
- case UNOP_IMAG:
- if (!ShapeUtil::ElementIsComplex(arg)) {
+ return shape;
+ case HloOpcode::kReal:
+ case HloOpcode::kImag:
+ if (!ShapeUtil::ElementIsComplex(shape)) {
return InvalidArgument(
"Expected element type in shape to be complex for real/imag "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(arg, F32);
- case UNOP_ABS:
- if (ShapeUtil::ElementIsComplex(arg)) {
+ return ShapeUtil::ChangeElementType(shape, F32);
+ case HloOpcode::kAbs:
+ if (ShapeUtil::ElementIsComplex(shape)) {
return ShapeUtil::ChangeElementType(
- arg, primitive_util::ComplexComponentType(arg.element_type()));
+ shape, primitive_util::ComplexComponentType(shape.element_type()));
}
- return arg;
- case UNOP_CLZ:
- case UNOP_NEGATE:
- case UNOP_ROUND_NEAREST_AFZ:
- case UNOP_SIGN:
- case UNOP_SORT:
- return arg;
-
- case UNOP_NOT:
- if (arg.element_type() != PRED &&
- !primitive_util::IsIntegralType(arg.element_type())) {
+ return shape;
+ case HloOpcode::kClz:
+ case HloOpcode::kNegate:
+ case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kSign:
+ case HloOpcode::kSort:
+ return shape;
+
+ case HloOpcode::kNot:
+ if (shape.element_type() != PRED &&
+ !primitive_util::IsIntegralType(shape.element_type())) {
return InvalidArgument(
"Expected pred or an integral element type in argument to Not "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return arg;
+ return shape;
- case UNOP_IS_FINITE:
- if (!ShapeUtil::ElementIsFloating(arg)) {
+ case HloOpcode::kIsFinite:
+ if (!ShapeUtil::ElementIsFloating(shape)) {
return InvalidArgument(
- "Expected element type in shape to be floating point for IsFinite "
+ "Expected element type in shape to be floating "
+ "point for IsFinite "
"operation; got %s.",
- PrimitiveType_Name(arg.element_type()).c_str());
+ PrimitiveType_Name(shape.element_type()).c_str());
}
- return ShapeUtil::ChangeElementType(arg, PRED);
+ return ShapeUtil::ChangeElementType(shape, PRED);
default:
return InvalidArgument(
"Unknown operation for unary shape inference: \"%s\".",
- UnaryOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -779,8 +653,9 @@ Status ValidateDotDimensionNumbers(
}
/* static */ StatusOr<Shape>
-ShapeInference::InferDegenerateDimensionBroadcastShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
+ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
+ const Shape& lhs,
+ const Shape& rhs) {
TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
// The shapes have to be compatible. That is, if some dimension d has a
@@ -798,7 +673,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
} else {
return InvalidArgument(
"Binary op %s with incompatible shapes: %s and %s.",
- BinaryOperation_Name(operation).c_str(),
+ HloOpcodeString(operation).c_str(),
ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
@@ -808,8 +683,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
- BinaryOperation operation, const Shape& smaller_shape,
- const Shape& larger_shape,
+ const Shape& smaller_shape, const Shape& larger_shape,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
// Reject "magic" inference for binops on different shapes, requiring
@@ -910,7 +784,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
@@ -920,8 +794,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Binary op %s with different element types: %s and %s.",
- BinaryOperation_Name(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(),
+ HloOpcodeString(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(),
ShapeUtil::HumanString(rhs).c_str());
}
@@ -954,10 +827,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
// After InDim broadcasting, perform degenerate dimensions broadcasting.
- TF_ASSIGN_OR_RETURN(
- Shape indim_broadcast_shape,
- InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
- broadcast_dimensions));
+ TF_ASSIGN_OR_RETURN(Shape indim_broadcast_shape,
+ InferInDimBroadcastShape(smaller_shape, larger_shape,
+ broadcast_dimensions));
return InferDegenerateDimensionBroadcastShape(
operation, indim_broadcast_shape, larger_shape);
@@ -966,51 +838,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const HloInstruction* lhs, const HloInstruction* rhs) {
- return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs->shape(),
- rhs->shape(), /*broadcast_dimensions=*/{});
+ return InferBinaryOpShape(opcode, lhs->shape(), rhs->shape(),
+ /*broadcast_dimensions=*/{});
}
/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- return InferBinaryOpShape(OpcodeToBinaryOperation(opcode), lhs, rhs,
- broadcast_dimensions);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
- BinaryOperation operation, const Shape& lhs, const Shape& rhs,
- tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
VLOG(2) << tensorflow::strings::Printf(
"inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
- BinaryOperation_Name(operation).c_str(),
- ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
+ HloOpcodeString(opcode).c_str(), ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str(),
Join(broadcast_dimensions, ", ").c_str());
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- BinaryOperation_Name(operation))));
+ HloOpcodeString(opcode))));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- BinaryOperation_Name(operation))));
- switch (operation) {
- case BINOP_MAX:
- case BINOP_MIN:
- case BINOP_SUB:
- case BINOP_ADD:
- case BINOP_ATAN2:
- case BINOP_POW:
- case BINOP_DIV:
- case BINOP_REM:
- case BINOP_MUL:
- case BINOP_SHIFT_LEFT:
- case BINOP_SHIFT_RIGHT_ARITHMETIC:
- case BINOP_SHIFT_RIGHT_LOGICAL:
- return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ HloOpcodeString(opcode))));
+ switch (opcode) {
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kAdd:
+ case HloOpcode::kAtan2:
+ case HloOpcode::kPower:
+ case HloOpcode::kDivide:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kShiftLeft:
+ case HloOpcode::kShiftRightArithmetic:
+ case HloOpcode::kShiftRightLogical:
+ return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
- case BINOP_COMPLEX: {
+ case HloOpcode::kComplex: {
if (!ShapeUtil::ElementIsFloating(lhs)) {
return InvalidArgument(
"Expected element type in shape to be floating for complex compose "
@@ -1018,7 +883,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(lhs.element_type()).c_str());
}
TF_ASSIGN_OR_RETURN(const Shape& shape,
- InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
@@ -1026,8 +891,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return Unimplemented("Complex component type is not implemented.");
}
}
- case BINOP_AND:
- case BINOP_OR:
+ case HloOpcode::kAnd:
+ case HloOpcode::kOr:
if (lhs.element_type() != PRED &&
!primitive_util::IsIntegralType(lhs.element_type())) {
return InvalidArgument(
@@ -1035,24 +900,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"got %s.",
PrimitiveType_Name(lhs.element_type()).c_str());
}
- return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ return InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions);
- case BINOP_EQ:
- case BINOP_GE:
- case BINOP_GT:
- case BINOP_LE:
- case BINOP_LT:
- case BINOP_NE: {
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kNe: {
TF_ASSIGN_OR_RETURN(const Shape& shape,
- InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ InferElementwiseBinaryOpShape(opcode, lhs, rhs,
broadcast_dimensions));
return ShapeUtil::ChangeElementType(shape, PRED);
}
default:
return Unimplemented(
"Binary op shape inference: %s; lhs: %s; rhs: %s is not implemented.",
- BinaryOperation_Name(operation).c_str(),
- lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
+ HloOpcodeString(opcode).c_str(), lhs.ShortDebugString().c_str(),
+ rhs.ShortDebugString().c_str());
}
}
@@ -1064,23 +929,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
HloOpcode opcode, const Shape& lhs, const Shape& rhs, const Shape& ehs) {
- return InferTernaryOpShape(OpcodeToTernaryOperation(opcode), lhs, rhs, ehs);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
- TernaryOperation operation, const Shape& lhs, const Shape& rhs,
- const Shape& ehs) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(ehs));
- switch (operation) {
- case TRIOP_CLAMP:
+ switch (opcode) {
+ case HloOpcode::kClamp:
return InferClampShape(lhs, rhs, ehs);
- case TRIOP_SELECT:
+ case HloOpcode::kSelect:
return InferSelectShape(lhs, rhs, ehs);
default:
return InvalidArgument("Unknown operation %s.",
- TernaryOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}
@@ -1097,18 +956,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
HloOpcode opcode,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
- return InferVariadicOpShape(OpcodeToVariadicOperation(opcode),
- operand_shapes);
-}
-
-/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
- VariadicOperation operation,
- tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* shape : operand_shapes) {
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(*shape));
}
- switch (operation) {
- case VAROP_TUPLE: {
+ switch (opcode) {
+ case HloOpcode::kTuple: {
Shape result = ShapeUtil::MakeTupleShape({});
for (const Shape* shape : operand_shapes) {
ShapeUtil::AppendShapeToTuple(*shape, &result);
@@ -1117,7 +969,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
default:
return InvalidArgument("Unknown operation %s.",
- VariadicOperation_Name(operation).c_str());
+ HloOpcodeString(opcode).c_str());
}
}