diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/broadcast_simple_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/broadcast_simple_test.cc | 176 |
1 files changed, 130 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 5fdd1018a4..50dd574624 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -58,7 +59,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D<float>* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr<GlobalData> r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); @@ -71,7 +72,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D<float>* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array)->Relayout( LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr<GlobalData> r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); @@ -156,6 +157,86 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); } +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1<float>(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {1}); + + Array2D<float> expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + + ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1<float>(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {2, 2}), {0}); + + Array2D<float> expected(2, 2); + expected(0, 0) = 1; + expected(0, 1) = 1; + expected(1, 0) = 2; + expected(1, 1) = 2; + + ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 1}); + + Array3D<float> expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 1.0; + expected(1, 0, 1) = 2.0; + expected(0, 1, 0) = 5.0; + expected(1, 1, 0) = 6.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), + ShapeUtil::MakeShape(F32, {2, 2, 2}), {0, 2}); + + Array3D<float> expected(2, 2, 2); + expected(0, 0, 0) = 1.0; + expected(1, 0, 0) = 2.0; + expected(0, 0, 1) = 5.0; + expected(1, 0, 1) = 6.0; + expected(0, 1, 0) = 1.0; + expected(1, 1, 0) = 2.0; + expected(1, 1, 1) = 6.0; + expected(0, 1, 1) = 5.0; + + ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001)); +} + +XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) { + XlaBuilder b(TestName()); + BroadcastInDim(ConstantR1<float>(&b, {1, 2}), + ShapeUtil::MakeShape(F32, {3, 2}), {1}); + + Array2D<float> expected(3, 2); + expected(0, 0) = 1; + expected(0, 1) = 2; + expected(1, 0) = 1; + expected(1, 1) = 2; + expected(2, 0) = 1; + expected(2, 1) = 2; + + ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001)); +} + // Tests implicit broadcasting of PREDs. XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { XlaBuilder b(TestName()); @@ -210,13 +291,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2<float>(&b, {{1.0, 5.0}}), - ConstantLiteral(&b, *Literal::CreateR3<float>( + ConstantLiteral(&b, *LiteralUtil::CreateR3<float>( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); auto expected = - Literal::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -285,7 +366,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = Literal::CreateR3FromArray3D(expected_array); + auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -310,7 +391,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { Add(r3h, r1h); auto expected = - Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -318,39 +399,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}, {2}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}, {2}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}})); + auto r1 = + ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -358,40 +440,40 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { XlaBuilder b(TestName()); auto r1 = - ConstantLiteral(&b, *Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}})); + ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { XlaBuilder b(TestName()); - auto r1 = - ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}})); + auto r1 = ConstantLiteral( + &b, *LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR3<float>({{{1}}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR3<float>({{{1}}})); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1); auto expected = - Literal::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -532,7 +614,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = Literal::CreateR2FromArray2D(expected_array); + auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -546,22 +628,24 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}})); - auto r2 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}})); Add(r2, r1); - auto expected = Literal::CreateR2<float>({{2, 4}, {4, 6}}); + auto expected = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { XlaBuilder b(TestName()); - auto r1 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1}, {2}})); - auto r2 = ConstantLiteral(&b, *Literal::CreateR2<float>({{1, 2}, {3, 4}})); + auto r1 = ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1}, {2}})); + auto r2 = + ConstantLiteral(&b, *LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}})); Add(r2, r1); - auto expected = Literal::CreateR2<float>({{2, 3}, {5, 6}}); + auto expected = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -570,11 +654,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { XlaBuilder b(TestName()); auto r1 = ConstantR1<float>(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r3, r1, {0}); - auto expected = - Literal::CreateR3<float>({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3<float>( + {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -583,11 +667,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { XlaBuilder b(TestName()); auto r1 = ConstantR1<float>(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {1}); - auto expected = - Literal::CreateR3<float>({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = LiteralUtil::CreateR3<float>( + {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -596,11 +680,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { XlaBuilder b(TestName()); auto r1 = ConstantR1<float>(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); Add(r1, r3, {2}); - auto expected = - Literal::CreateR3<float>({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = LiteralUtil::CreateR3<float>( + {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -611,7 +695,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = ConstantR1<float>(&b, {100, 200}); auto r1_2 = ConstantR1<float>(&b, {10, 20}); auto r3 = ConstantLiteral( - &b, *Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + &b, *LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = Add(r1_0, r3, {0}); r3 = Add(r3, r1_1, {1}); @@ -619,7 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } r3 = Mul(r3, ConstantR0<float>(&b, -2)); - auto expected = Literal::CreateR3<float>( + auto expected = LiteralUtil::CreateR3<float>( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -640,7 +724,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { } r3 = Mul(r3, ConstantR0<float>(&b, -1)); - auto expected = Literal::CreateR3<float>( + auto expected = LiteralUtil::CreateR3<float>( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -653,7 +737,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { XlaBuilder b(TestName()); Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}), - ConstantLiteral(&b, *Literal::CreateR3<float>( + ConstantLiteral(&b, *LiteralUtil::CreateR3<float>( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); |