aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/broadcast_simple_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc176
1 files changed, 130 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 5fdd1018a4..50dd574624 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
Array3D<float>* r3_array, float start, float end, int seed) {
*r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r3_array->FillRandom(start, end, seed);
- auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout(
+ auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r3_global_data =
client_->TransferToServer(*r3_data).ConsumeValueOrDie();
@@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
Array2D<float>* r2_array, float start, float end, int seed) {
*r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
r2_array->FillRandom(start, end, seed);
- auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout(
+ auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout(
LayoutUtil::MakeLayout(minor_to_major));
std::unique_ptr<GlobalData> r2_global_data =
client_->TransferToServer(*r2_data).ConsumeValueOrDie();
@@ -156,6 +157,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {2, 2}), {1});
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {2, 2}), {0});
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 1;
+ expected(1, 0) = 2;
+ expected(1, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}),
+ ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1});
+
+ Array3D<float> expected(2, 2, 2);
+ expected(0, 0, 0) = 1.0;
+ expected(1, 0, 0) = 2.0;
+ expected(0, 0, 1) = 1.0;
+ expected(1, 0, 1) = 2.0;
+ expected(0, 1, 0) = 5.0;
+ expected(1, 1, 0) = 6.0;
+ expected(1, 1, 1) = 6.0;
+ expected(0, 1, 1) = 5.0;
+
+ ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}),
+ ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2});
+
+ Array3D<float> expected(2, 2, 2);
+ expected(0, 0, 0) = 1.0;
+ expected(1, 0, 0) = 2.0;
+ expected(0, 0, 1) = 5.0;
+ expected(1, 0, 1) = 6.0;
+ expected(0, 1, 0) = 1.0;
+ expected(1, 1, 0) = 2.0;
+ expected(1, 1, 1) = 6.0;
+ expected(0, 1, 1) = 5.0;
+
+ ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
+ XlaBuilder b(TestName());
+ BroadcastInDim(ConstantR1<float>(&b, {1, 2}),
+ ShapeUtil::MakeShape(F32, {3, 2}), {1});
+
+ Array2D<float> expected(3, 2);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+ expected(2, 0) = 1;
+ expected(2, 1) = 2;
+
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
// Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
XlaBuilder b(TestName());
@@ -210,13 +291,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}}),
- ConstantLiteral(&b, *Literal::CreateR3<float>(
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});
auto expected =
- Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
- {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
+ LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
+ {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -285,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
}
}
}
- auto expected = Literal::CreateR3FromArray3D(expected_array);
+ auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r3_implicit_global_data.get(), r3_global_data.get()},
@@ -310,7 +391,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
Add(r3h, r1h);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()},
ErrorSpec(0.0001));
@@ -318,39 +399,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}, {2}}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
+ auto r1 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -358,40 +440,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
XlaBuilder b(TestName());
auto r1 =
- ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
XlaBuilder b(TestName());
- auto r1 =
- ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
+ auto r1 = ConstantLiteral(
+ &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}}));
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1);
auto expected =
- Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
+ LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -532,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
*v = ApplyOpToFloats(spec.op2, tmp, v3);
});
- auto expected = Literal::CreateR2FromArray2D(expected_array);
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
ComputeAndCompareLiteral(
&builder, *expected,
{r2_implicit_global_data1.get(), r2_global_data.get(),
@@ -546,22 +628,24 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}}));
- auto r2 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}}));
+ auto r2 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
- auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
XlaBuilder b(TestName());
- auto r1 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1}, {2}}));
- auto r2 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}, {3, 4}}));
+ auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}}));
+ auto r2 =
+ ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
Add(r2, r1);
- auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -570,11 +654,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r3, r1, {0});
- auto expected =
- Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -583,11 +667,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {1});
- auto expected =
- Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -596,11 +680,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
XlaBuilder b(TestName());
auto r1 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
Add(r1, r3, {2});
- auto expected =
- Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
+ auto expected = LiteralUtil::CreateR3<float>(
+ {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});
ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
}
@@ -611,7 +695,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
auto r1_1 = ConstantR1<float>(&b, {100, 200});
auto r1_2 = ConstantR1<float>(&b, {10, 20});
auto r3 = ConstantLiteral(
- &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
+ &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
for (int i = 0; i < 3; ++i) {
r3 = Add(r1_0, r3, {0});
r3 = Add(r3, r1_1, {1});
@@ -619,7 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
}
r3 = Mul(r3, ConstantR0<float>(&b, -2));
- auto expected = Literal::CreateR3<float>(
+ auto expected = LiteralUtil::CreateR3<float>(
{{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
{{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});
@@ -640,7 +724,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
}
r3 = Mul(r3, ConstantR0<float>(&b, -1));
- auto expected = Literal::CreateR3<float>(
+ auto expected = LiteralUtil::CreateR3<float>(
{{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
{{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});
@@ -653,7 +737,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XlaBuilder b(TestName());
Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
- ConstantLiteral(&b, *Literal::CreateR3<float>(
+ ConstantLiteral(&b, *LiteralUtil::CreateR3<float>(
{{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
/*broadcast_dimensions=*/{1, 2});