aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/dot_operation_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/dot_operation_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc121
1 files changed, 66 insertions, 55 deletions
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 33d79aebb1..d86fd7cc2d 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -67,15 +67,16 @@ XLA_TEST_F(DotOperationTest, DotOfInputTupleElem) {
XlaOp param;
auto param_data = CreateParameterAndTransferLiteral(
0,
- *Literal::MakeTuple({Literal::CreateR2<float>({{1, 2}, {3, 4}}).get(),
- Literal::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
+ *LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}).get(),
+ LiteralUtil::CreateR2<float>({{5, 6}, {7, 8}}).get()}),
"arg0", &builder, &param);
auto lhs = GetTupleElement(param, 0);
auto rhs = GetTupleElement(param, 1);
Dot(lhs, rhs);
ComputeAndCompareLiteral(&builder,
- *Literal::CreateR2<float>({{19, 22}, {43, 50}}),
+ *LiteralUtil::CreateR2<float>({{19, 22}, {43, 50}}),
{param_data.get()});
}
@@ -194,11 +195,11 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, FusedDot) {
auto lhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f, 2.0f, 3.0f, 4.0f}, {-1.0f, -2.0f, -3.0f, -4.0f}}))
.ConsumeValueOrDie();
auto rhs_handle = this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR2FromArray2D<T>(
{{1.0f}, {2.0f}, {3.0f}, {4.0f}}))
.ConsumeValueOrDie();
@@ -217,14 +218,14 @@ class SquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f}, {3.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -286,9 +287,10 @@ void ParametricDotTest::TestImpl() {
std::unique_ptr<Array2D<NativeT>> dot_lhs_data =
MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.k);
- std::unique_ptr<Literal> dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout(
- *dot_lhs_data, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(param.dot_lhs_row_major)));
+ std::unique_ptr<Literal> dot_lhs_lit =
+ LiteralUtil::CreateR2FromArray2DWithLayout(
+ *dot_lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(
+ param.dot_lhs_row_major)));
std::unique_ptr<GlobalData> dot_lhs_handle =
client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie();
@@ -297,7 +299,7 @@ void ParametricDotTest::TestImpl() {
Layout rhs_layout = LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.dot_rhs_row_major));
std::unique_ptr<Literal> dot_rhs_lit =
- Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
+ LiteralUtil::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout);
std::unique_ptr<GlobalData> dot_rhs_handle =
client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie();
@@ -307,7 +309,7 @@ void ParametricDotTest::TestImpl() {
if (param.has_addend) {
addend_data = MakeLinspaceArray2D<NativeT>(0.0, 1.0, param.m, param.n);
- addend_lit = Literal::CreateR2FromArray2DWithLayout(
+ addend_lit = LiteralUtil::CreateR2FromArray2DWithLayout(
*addend_data, LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(param.addend_row_major)));
addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie();
@@ -476,14 +478,14 @@ class NonsquareMatrixDot : public DotOperationTest {
void TestImpl(bool lhs_row_major, bool rhs_row_major) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 2.0f, 3.0f}, {3.0f, -4.0f, -1.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(lhs_row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateFromArrayWithLayout<T>(
+ ->TransferToServer(*LiteralUtil::CreateFromArrayWithLayout<T>(
{{1.0f, 6.0f}, {2.0f, 3.0f}, {7.0f, -4.0f}},
LayoutUtil::MakeLayout(
MinorToMajorForIsRowMajor(rhs_row_major))))
@@ -510,12 +512,12 @@ XLA_TYPED_TEST(NonsquareMatrixDot, TestTT) { this->TestImpl(true, true); }
XLA_TEST_F(DotOperationTest, MatrixVectorC64) {
auto lhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
auto rhs_handle =
client_
- ->TransferToServer(*Literal::CreateR2WithLayout<complex64>(
+ ->TransferToServer(*LiteralUtil::CreateR2WithLayout<complex64>(
{{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}},
LayoutUtil::MakeLayout({1, 0})))
.ConsumeValueOrDie();
@@ -583,7 +585,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
auto x_data = this->client_
- ->TransferToServer(*Literal::CreateR4FromArray4D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
{{{{1000.0f, 100.0f}, {10.0f, 1.0f}},
{{2000.0f, 200.0f}, {20.0f, 2.0f}}},
{{{3000.0f, 300.0f}, {30.0f, 3.0f}},
@@ -591,7 +593,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*Literal::CreateR4FromArray4D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
{{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
{{{11.0f, 22.0f}, {33.0f, 44.0f}},
{{55.0f, 66.0f}, {77.0f, 88.0f}}}}))
@@ -629,13 +631,13 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
auto x_data =
this->client_
- ->TransferToServer(*Literal::CreateR3FromArray3D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}}))
.ConsumeValueOrDie();
auto y_data =
this->client_
- ->TransferToServer(*Literal::CreateR3FromArray3D<T>(
+ ->TransferToServer(*LiteralUtil::CreateR3FromArray3D<T>(
{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}}))
.ConsumeValueOrDie();
@@ -664,15 +666,17 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
}
auto lhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- *lhs, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(row_major))))
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ *lhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
auto rhs_handle =
this->client_
- ->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
- *rhs, LayoutUtil::MakeLayout(
- MinorToMajorForIsRowMajor(row_major))))
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<T>(
+ *rhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
.ConsumeValueOrDie();
XlaBuilder builder(this->TestName());
@@ -733,15 +737,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{53.0f, 74.0f}, {45.0f, 66.0f}});
this->template ComputeAndCompareR2<T>(
@@ -782,15 +786,15 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64,
TF_ASSERT_OK_AND_ASSIGN(
auto arg_0_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_0_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_0_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_1_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_1_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_1_value_array)));
TF_ASSERT_OK_AND_ASSIGN(
auto arg_2_value,
this->client_->TransferToServer(
- *Literal::CreateR2FromArray2D<T>(*arg_2_value_array)));
+ *LiteralUtil::CreateR2FromArray2D<T>(*arg_2_value_array)));
Array2D<T> expected({{38.0f, 36.0f}, {93.0f, 91.0f}});
this->template ComputeAndCompareR2<T>(
@@ -853,10 +857,9 @@ XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSClassicMM) {
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSReverseMM)))) {
+
+ DotOfGatherOptimizationWithConstRHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -883,10 +886,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSReverseMM)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSReverseMM) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
@@ -913,10 +913,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSRows)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -948,10 +945,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSRows)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSRows) {
std::unique_ptr<Array2D<float>> constant_lhs_array(
new Array2D<float>({{1.0, 2.0},
{3.0, 4.0},
@@ -983,10 +977,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstRHSCols)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstRHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1010,10 +1001,7 @@ XLA_TEST_F(DotOperationTest,
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-// TODO (b/69062148) Enable when Dot implements general contracting dimensions.
-XLA_TEST_F(DotOperationTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(DISABLED_ON_INTERPRETER(
- DotOfGatherOptimizationWithConstLHSCols)))) {
+XLA_TEST_F(DotOperationTest, DotOfGatherOptimizationWithConstLHSCols) {
std::unique_ptr<Array2D<float>> constant_lhs_array(new Array2D<float>(
{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}}));
std::unique_ptr<Array2D<float>> constant_rhs_array(
@@ -1036,5 +1024,28 @@ XLA_TEST_F(DotOperationTest,
Array2D<float> expected({{168.0}, {168.0}});
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
+
+XLA_TEST_F(DotOperationTest, DotRank2AndRank2NonDefaultContractionDims) {
+ XlaBuilder builder(TestName());
+
+ Array2D<float> lhs_array({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ auto lhs_constant = ConstantR2FromArray2D(&builder, lhs_array);
+
+ Array2D<float> rhs_array({{5.0f, 6.0f}, {7.0f, 8.0f}});
+ auto rhs_constant = ConstantR2FromArray2D(&builder, rhs_array);
+
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ DotGeneral(lhs_constant, rhs_constant, dot_dnums);
+
+ Array2D<float> expected({
+ {26.f, 30.f},
+ {38.f, 44.f},
+ });
+
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
} // namespace
} // namespace xla