From 4146ff1259c0b4ada8afbbad11a7b37d8373d1b9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Nov 2017 11:18:54 -0800 Subject: [XLA] Adds Dot with DotDimensionNumbers proto for specifying arbitrary contracting and batch dimensions. PiperOrigin-RevId: 177481231 --- .../compiler/xla/service/shape_inference_test.cc | 177 ++++++++++++++++++--- 1 file changed, 159 insertions(+), 18 deletions(-) (limited to 'tensorflow/compiler/xla/service/shape_inference_test.cc') diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index be93c879c0..6e53d2d609 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -898,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -907,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + 
HasSubstr("batch and contracting dimension number mismatch")); } // vector vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, 
dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. 
+TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. 
+TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with out-of-range dimension numbers fails. 
+TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. -- cgit v1.2.3