aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-30 11:18:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-30 11:22:15 -0800
commit4146ff1259c0b4ada8afbbad11a7b37d8373d1b9 (patch)
tree1a4cb649245215420c7a34ce97506327caa0d1c4 /tensorflow/compiler/xla/service/shape_inference_test.cc
parentea1c29552b01f3404e27999a27a1919b3accc594 (diff)
[XLA] Adds Dot with DotDimensionNumbers proto for specifying arbitrary contracting and batch dimensions.
PiperOrigin-RevId: 177481231
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc177
1 files changed, 159 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index be93c879c0..6e53d2d609 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -898,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) {
// scalar <dot> vector: error
TEST_F(ShapeInferenceTest, ScalarDotVector) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {});
+ ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("dot only supports rank"));
@@ -907,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) {
// 3D <dot> 2D: error
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
- auto inferred_status = ShapeInference::InferBinaryOpShape(
- BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status = ShapeInference::InferDotOpShape(
+ ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
- HasSubstr("dot only supports rank"));
+ HasSubstr("batch and contracting dimension number mismatch"));
}
// vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {});
+ ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
auto inferred_status_mismatch =
- ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {});
+ ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
// matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
- auto inferred_status = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {});
+ auto inferred_status_mismatch =
+ ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
// vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
- auto inferred_status = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {});
+ auto inferred_status_mismatch =
+ ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
// matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
- auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status_match =
+ ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(
ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
<< "inferred: "
<< ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
<< " expected: " << ShapeUtil::HumanString(matrix_64_48_);
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {});
+ auto inferred_status_mismatch =
+ ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
+// BatchMatMul with two batch dimensions and one contracting dimension.
+TEST_F(ShapeInferenceTest, DotGeneral) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
+ Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(3);
+ dot_dnums.add_lhs_batch_dimensions(0);
+ dot_dnums.add_lhs_batch_dimensions(1);
+
+ dot_dnums.add_rhs_contracting_dimensions(2);
+ dot_dnums.add_rhs_batch_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status_match =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(
+ ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
+ << "inferred: "
+ << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
+ << " expected: " << ShapeUtil::HumanString(output_shape);
+}
+
+// BatchMatMul with two contracting dimensions fails.
+TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
+ Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_lhs_contracting_dimensions(3);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_batch_dimensions(0);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must specify one contracting dimension for both "
+ "lhs and rhs"));
+}
+
+// BatchMatMul with different batch dimension sizes fails.
+TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_batch_dimensions(0);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("batch dimension numbers and sizes must match"));
+}
+
+// BatchMatMul with different batch dimension numbers fails.
+TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("batch dimension numbers and sizes must match"));
+}
+
+// BatchMatMul with out-of-range dimension numbers fails.
+TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(3);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("A dimension number is out of range"));
+}
+
+// BatchMatMul with non-unique dimension numbers fails.
+TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("A dimension number is not unique"));
+}
+
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
// Test variations of broadcasting a vector for a binary add with a
// matrix.