path: root/tensorflow/compiler/xla/service/shape_inference.cc
author    Mark Heffernan <meheff@google.com>  2018-06-13 14:19:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-13 14:24:30 -0700
commit    e1296c15a32cac020160a1c89002dc561333c66b (patch)
tree      62ef334470d3484d166ea583eddea10fc24d1718 /tensorflow/compiler/xla/service/shape_inference.cc
parent    bf920de58a3ccb2cfe6642be9c487c3fcb13ccae (diff)
Fix assumptions that a Shape must be a tuple or an array.
A TOKEN primitive type was added with cl/199215963, and XLA also has an OPAQUE primitive type. However, in many places XLA assumes a shape is either a tuple or an array. This CL fixes many of those instances, but some may remain; instances were identified by searching for IsTuple or IsArray, so the set of fixes is not exhaustive.

This change also opportunistically addresses a couple of potential points of confusion in the ShapeUtil interface:

(1) Rename ShapeUtil::HasZeroElements to ShapeUtil::IsZeroElementArray. The point of confusion is that tuples can also have zero elements, and HasZeroElements would check-fail on tuple shapes. The method no longer check-fails if the given shape is not an array.

(2) ShapeUtil::IsNil now returns true only for empty tuples. Previously it also returned true for zero-element array shapes, which was confusing because ShapeUtil::MakeNil creates an empty tuple.

PiperOrigin-RevId: 200452672
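For illustration, here is a minimal stand-alone sketch of the post-CL semantics of the two ShapeUtil methods named in points (1) and (2). The simplified Shape struct below is an assumption for this sketch, not the real xla::Shape:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for xla::Shape, reduced to what the sketch needs.
struct Shape {
  bool is_tuple = false;            // TUPLE primitive type
  bool is_array = false;            // any array primitive type
  std::vector<Shape> tuple_shapes;  // populated when is_tuple
  std::vector<int64_t> dimensions;  // populated when is_array
};

// (1) Renamed from HasZeroElements. Per the commit message, it no longer
// check-fails on non-arrays: tuples, TOKEN, and OPAQUE simply yield false.
bool IsZeroElementArray(const Shape& shape) {
  if (!shape.is_array) return false;
  int64_t elements = 1;
  for (int64_t dim : shape.dimensions) elements *= dim;
  return elements == 0;
}

// (2) Now true only for the empty tuple (what ShapeUtil::MakeNil creates),
// no longer for zero-element arrays.
bool IsNil(const Shape& shape) {
  return shape.is_tuple && shape.tuple_shapes.empty();
}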
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 141
1 file changed, 62 insertions(+), 79 deletions(-)
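A note on the central behavioral fix in the diff below: the old ExpectNotTupleOrOpaque helper only rejected tuple and opaque shapes, so a TOKEN-shaped operand passed validation even though it is not an array; the new ExpectArray rejects everything that is not an array. A hypothetical sketch of that hole, using stand-in enum values and helpers that mirror xla::PrimitiveType rather than the real API:

enum PrimitiveType { F32, TUPLE, OPAQUE, TOKEN };

bool IsTuple(PrimitiveType t)  { return t == TUPLE; }
bool IsOpaque(PrimitiveType t) { return t == OPAQUE; }
bool IsArray(PrimitiveType t)  { return t != TUPLE && t != OPAQUE && t != TOKEN; }

// Old check: only tuples and opaques were rejected, so TOKEN slipped through.
bool OldCheckPasses(PrimitiveType t) { return !IsTuple(t) && !IsOpaque(t); }

// New check: only genuine arrays pass.
bool NewCheckPasses(PrimitiveType t) { return IsArray(t); }

// OldCheckPasses(TOKEN) == true   (the hole: TOKEN is not an array)
// NewCheckPasses(TOKEN) == false  (fixed by this CL)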
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index bd98e86b08..e25f5e67c7 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -49,19 +49,13 @@ bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectNotTupleOrOpaque(const Shape& shape,
- tensorflow::StringPiece op_type) {
- if (ShapeUtil::IsTuple(shape)) {
- return InvalidArgument("Expected non-tuple argument for %s, but got %s.",
+Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument("Expected array argument for %s, but got %s.",
std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
- } else if (ShapeUtil::IsOpaque(shape)) {
- return InvalidArgument("Expected non-opaque argument for %s, but got %s.",
- std::string(op_type).c_str(),
- ShapeUtil::HumanString(shape).c_str());
- } else {
- return Status::OK();
}
+ return Status::OK();
}
Status VerifyReducerShape(const ProgramShape& reducer_shape,
@@ -198,8 +192,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
}
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(shape, "operand of unary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
switch (opcode) {
@@ -289,8 +282,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const Shape* arg_shape = nullptr;
PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
for (const Shape* shape : arg_shapes) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
+ TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
if (!arg_shape) {
arg_shape = shape;
element_type = arg_shape->element_type();
@@ -337,7 +329,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
-/* static */ StatusOr<Shape> ShapeInference::InferTokenShape(
+/* static */ StatusOr<Shape> ShapeInference::InferGenerateTokenShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
@@ -358,12 +350,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
return InvalidArgument(
- "Convert does not allow tuples, so cannot convert from %s to %s.",
+ "Convert does not allow non-arrays, so cannot convert from %s to %s.",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
@@ -380,7 +373,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
@@ -427,7 +421,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
const Shape& operand_shape, const Shape& padding_value_shape,
const PaddingConfig& padding_config) {
- if (ShapeUtil::IsTuple(operand_shape)) {
+ if (!ShapeUtil::IsArray(operand_shape)) {
return InvalidArgument(
"Pad operation does not support tuple-shape operands.");
}
@@ -566,8 +560,8 @@ Status ValidateDotDimensionNumbers(
/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
auto fail = [lhs, rhs](const string& addendum) -> Status {
string message = tensorflow::strings::Printf(
@@ -786,10 +780,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -853,12 +845,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- HloOpcodeString(opcode))));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
+ HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
+ HloOpcodeString(opcode))));
switch (opcode) {
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -984,15 +976,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// All arguments must have the same shape.
const Shape* arg_shape = arg_shapes[0];
for (size_t i = 1; i < arg_shapes.size(); ++i) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
+ TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
continue;
}
- if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
- !ShapeUtil::IsTuple(*arg_shape) &&
- ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
+ if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
*arg_shape)) {
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
continue;
@@ -1075,11 +1064,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& scale_shape,
const Shape& offset_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm training"));
+ ExpectArray(operand_shape, "operand of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm training"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1181,11 +1170,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& offset_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm inference"));
+ ExpectArray(operand_shape, "operand of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1328,16 +1317,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,
const Shape& output_grad_shape, int64 feature_index) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
+ ExpectArray(scale_shape, "scale input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- output_grad_shape, "output_grad input of batch norm grad"));
+ ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
@@ -1486,8 +1472,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
const ConvolutionDimensionNumbers& dnums) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -1722,7 +1708,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum"));
+ ExpectArray(*operand_shape, "operand of cross replica sum"));
}
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
@@ -1764,8 +1750,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
operand_shape.element_type()));
return InferWindowOutputShape(operand_shape, window,
@@ -1778,7 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Window& window, const Shape& source_shape,
const Shape& init_value_shape, const ProgramShape& scatter_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
+ ExpectArray(operand_shape, "operand of select-and-scatter"));
// Check if the select function has a proper shape of (T,T) -> PRED.
if (select_shape.parameters_size() != 2) {
@@ -1843,7 +1828,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
Join(starts, ",").c_str(), Join(limits, ",").c_str(),
Join(strides, ",").c_str());
};
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
@@ -1902,10 +1887,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
- "start indices of dynamic slice"));
+ ExpectArray(start_indices_shape, "start indices of dynamic slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
@@ -1963,11 +1947,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& update_shape,
const Shape& start_indices_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
+ ExpectArray(operand_shape, "operand of dynamic update slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- start_indices_shape, "start indices of dynamic update slice"));
+ ExpectArray(update_shape, "update of dynamic update slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
+ "start indices of dynamic update slice"));
VLOG(2) << tensorflow::strings::Printf(
"updating slice of shape %s at dynamic start_indices %s with update "
@@ -2035,8 +2019,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
}
@@ -2166,7 +2149,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
return InvalidArgument("Broadcast with negative dimension size %lld.",
@@ -2185,7 +2168,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
ShapeUtil::MakeShape(operand.element_type(), new_sizes);
@@ -2217,7 +2200,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
std::iota(indices.begin(), indices.end(), 0);
@@ -2238,9 +2221,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// "degenerate" cases, as with binary elementwise ops.
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
const Shape& min, const Shape& operand, const Shape& max) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
+ TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
+ TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
@@ -2439,9 +2422,9 @@ static Status ValidateGatherDimensionNumbers(
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(input_shape, "input tensor operand gather op"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
return InvalidArgument(