aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-07 03:44:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-07 03:48:18 -0800
commit4f0aa15e9635c33ca37f3aa714b10f4ca3199e7f (patch)
treeca4a4ebe930eff77fa011cfb29381676f8d90f01 /tensorflow/compiler/xla
parentc0824a4eeaffa7e30119fef21a5b689c972e6657 (diff)
Fix ShapeUtil::CompatibleIgnoringElementType for scalar vs tuple comparision
Previously if the lhs was a scalar and the rhs was a tuple of arbitrary shape it reported them as compatible what is clearly wrong. PiperOrigin-RevId: 188155575
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc3
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc13
-rw-r--r--tensorflow/compiler/xla/shape_util.cc15
-rw-r--r--tensorflow/compiler/xla/shape_util.h1
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc12
5 files changed, 36 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index c54cb3b48d..915baecc56 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2394,7 +2394,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
"Select's pred operand must have PRED element type; got %s.",
ShapeUtil::HumanString(pred).c_str());
}
- if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) {
+ if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
+ ShapeUtil::Rank(pred) == 0) {
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 06735e9442..0dca30a804 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -3315,20 +3315,23 @@ void ComputationLowerer::Visit(
HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs());
HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs());
auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
-
- if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
- if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
+ if (debug_options_.xla_eliminate_hlo_implicit_broadcast() &&
+ !ShapeUtil::IsTuple(request.output_shape())) {
+ if (!ShapeUtil::IsTuple(lhs->shape()) &&
+ !ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
// lhs side is being implicitly broadcast. Change to explicit.
lhs =
ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
}
- if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
+ if (!ShapeUtil::IsTuple(rhs->shape()) &&
+ !ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
rhs =
ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
}
- if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
+ if (!ShapeUtil::IsTuple(ehs->shape()) &&
+ !ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) {
ehs =
ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape());
}
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 9810e818f6..4f604e6f7c 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -609,6 +609,8 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
const Shape& rhs) {
+ CHECK(ShapeUtil::IsArray(lhs));
+ CHECK(ShapeUtil::IsArray(rhs));
return ContainersEqual(lhs.dimensions(), rhs.dimensions());
}
@@ -617,7 +619,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
}
- return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
+ if (lhs.element_type() == OPAQUE) {
+ return rhs.element_type() == OPAQUE;
+ }
+ return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
@@ -627,7 +632,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringElementType);
}
- return SameDimensions(lhs, rhs);
+ if (lhs.element_type() == OPAQUE) {
+ return rhs.element_type() == OPAQUE;
+ }
+ return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
@@ -637,6 +645,9 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringFpPrecision);
}
+ if (lhs.element_type() == OPAQUE) {
+ return rhs.element_type() == OPAQUE;
+ }
if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return CompatibleIgnoringElementType(lhs, rhs);
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 92b365e072..3e130a02e2 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -209,6 +209,7 @@ class ShapeUtil {
// Returns whether the LHS and RHS shapes have the same dimensions; note: does
// not check element type.
+ // Precondition: IsArray(lhs) && IsArray(rhs)
static bool SameDimensions(const Shape& lhs, const Shape& rhs);
// Returns whether the lhs and rhs shapes have the same element type.
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index a357415698..424cfe37ea 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -238,6 +238,18 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {
EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
}
+TEST(ShapeUtilTest, IncompatibleScalarVsTuple) {
+ Shape shape1 = ShapeUtil::MakeShape(F32, {});
+ Shape shape2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(U32, {})});
+ EXPECT_FALSE(ShapeUtil::Compatible(shape1, shape2));
+ EXPECT_FALSE(ShapeUtil::Compatible(shape2, shape1));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape1, shape2));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(shape2, shape1));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
+ EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape2, shape1));
+}
+
TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) {
Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30});
shape1.mutable_layout()->add_padded_dimensions(10);