diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/slice_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/slice_test.cc | 51 |
1 files changed, 23 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 5e7d475662..97120df0c5 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -44,7 +44,7 @@ class SliceTest : public ClientLibraryTestBase { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1<NativeT>(constant); - builder.Slice(original, {2}, {4}, {1}); + builder.Slice(original, {2}, {4}); const std::vector<NativeT> expected = {static_cast<NativeT>(2), static_cast<NativeT>(3)}; @@ -55,7 +55,7 @@ class SliceTest : public ClientLibraryTestBase { XLA_TEST_F(SliceTest, SliceZeroToZeroF32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1<float>({}); - builder.Slice(original, {0}, {0}, {1}); + builder.Slice(original, {0}, {0}); ComputeAndCompareR1<float>(&builder, {}, {}); } @@ -64,7 +64,7 @@ XLA_TEST_F(SliceTest, SliceTenToZeroF32) { ComputationBuilder builder(client_, TestName()); std::vector<float> constant(10, 0.3); auto original = builder.ConstantR1<float>(constant); - builder.Slice(original, {7}, {7}, {1}); + builder.Slice(original, {7}, {7}); ComputeAndCompareR1<float>(&builder, {}, {}); } @@ -87,7 +87,7 @@ TEST_F(SliceTest, SliceTenToTen) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1<float>(values); - builder.Slice(original, {0}, {10}, {1}); + builder.Slice(original, {0}, {10}); ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001)); } @@ -98,7 +98,7 @@ TEST_F(SliceTest, SliceLastFourOf1024) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1<float>(values); - builder.Slice(original, {1024 - 4}, {1024}, {1}); + builder.Slice(original, {1024 - 4}, {1024}); const std::vector<float> expected = {1020, 1021, 1022, 1023}; ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001)); @@ -112,7 +112,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR1<float>(values); - builder.Slice(original, {7}, {7 + 1024}, {1}); + builder.Slice(original, {7}, {7 + 1024}); std::vector<float> expected(1024); std::iota(values.begin(), values.end(), 7.0); @@ -122,7 +122,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) { XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0)); - builder.Slice(original, {0, 0}, {0, 0}, {1, 1}); + builder.Slice(original, {0, 0}, {0, 0}); ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}); } @@ -130,7 +130,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) { XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20)); - builder.Slice(original, {0, 15}, {0, 20}, {1, 1}); + builder.Slice(original, {0, 15}, {0, 20}); ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {}); } @@ -138,7 +138,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) { XLA_TEST_F(SliceTest, Slice3x0to2x0F32) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0)); - builder.Slice(original, {1, 0}, {3, 0}, {1, 1}); + builder.Slice(original, {1, 0}, {3, 0}); ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {}); } @@ -153,7 +153,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D<float>(values); - builder.Slice(original, {128, 128}, {256, 256}, {1, 1}); + builder.Slice(original, {128, 128}, {256, 256}); Array2D<float> expected(128, 128); for (int row = 0; row < 128; ++row) { @@ -171,7 +171,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) { ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D<float>(values); - builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1}); + builder.Slice(original, {0, 3072}, {1, 4096}); Array2D<float> expected(1, 1024); std::iota(expected.data(), expected.data() + 1024, 3072.0); @@ -192,7 +192,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { } ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR2FromArray2D<float>(values); - builder.Slice(original, {0, 0}, {16, 2}, {1, 1}); + builder.Slice(original, {0, 0}, {16, 2}); ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001)); } @@ -204,7 +204,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}); ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR4FromArray4D(values); - builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); + builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}); ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001)); } @@ -213,7 +213,6 @@ struct R2Spec { int64 input_dim1; std::array<int64, 2> slice_starts; std::array<int64, 2> slice_limits; - std::array<int64, 2> slice_strides; Layout layout; }; @@ -229,7 +228,7 @@ TEST_P(SliceR2Test, DoIt) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2FromArray2D<int32>(input); - builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); + builder.Slice(a, spec.slice_starts, spec.slice_limits); std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits); @@ -240,23 +239,19 @@ TEST_P(SliceR2Test, DoIt) { INSTANTIATE_TEST_CASE_P( SliceR2TestInstantiation, SliceR2Test, ::testing::Values( - R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, - LayoutUtil::MakeLayout({0, 1})}, - R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}}, + R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})}, + R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})}, + R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})}, + R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})}, + R2Spec {256, 400, {{0, 300}}, {{256, 400}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, - LayoutUtil::MakeLayout({0, 1})}, - R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}}, + R2Spec {500, 400, {{111, 123}}, {{300, 257}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}}, + R2Spec {500, 400, {{111, 123}}, {{300, 400}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}}, + R2Spec {384, 512, {{128, 256}}, {{256, 384}}, LayoutUtil::MakeLayout({1, 0})}, - R2Spec {500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}}, - LayoutUtil::MakeLayout({1, 0})}, - R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}}, - LayoutUtil::MakeLayout({1, 0})}, - R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}}, + R2Spec {357, 512, {{111, 256}}, {{301, 384}}, LayoutUtil::MakeLayout({1, 0})} ) ); |