aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/broadcast_simple_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc89
1 files changed, 43 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index fe4267c73b..dde19fb65d 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -60,10 +60,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
- client_->TransferToServer(*r3_data).ConsumeValueOrDie();
+ client_->TransferToServer(r3_data).ConsumeValueOrDie();
return r3_global_data;
}
@@ -74,10 +74,10 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
- client_->TransferToServer(*r2_data).ConsumeValueOrDie();
+ client_->TransferToServer(r2_data).ConsumeValueOrDie();
return r2_global_data;
}
@@ -293,7 +293,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
@@ -301,7 +301,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
{{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R3ImplicitBroadcastSpec {
@@ -370,8 +370,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
- {r3_implicit_global_data.get(), r3_global_data.get()},
+ &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
ErrorSpec(1e-7, 1e-7));
}
@@ -395,89 +394,89 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
+ ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
auto r1 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
struct R2ImplicitBroadcastSpec {
@@ -618,7 +617,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
- &builder, *expected,
+ &builder, expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
r2_implicit_global_data2.get()},
ErrorSpec(1e-6, 1e-6));
@@ -630,65 +629,63 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
- auto r2 =
- ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 = ConstantLiteral(&b, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
auto expected = LiteralUtil::CreateR3<float>(
{{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
@@ -697,7 +694,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -709,7 +706,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
@@ -730,7 +727,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
- ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+ ComputeAndCompareLiteral(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
@@ -739,7 +736,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
+ ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});