aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/slice_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/slice_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc51
1 files changed, 23 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 5e7d475662..97120df0c5 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -44,7 +44,7 @@ class SliceTest : public ClientLibraryTestBase {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<NativeT>(constant);
- builder.Slice(original, {2}, {4}, {1});
+ builder.Slice(original, {2}, {4});
const std::vector<NativeT> expected = {static_cast<NativeT>(2),
static_cast<NativeT>(3)};
@@ -55,7 +55,7 @@ class SliceTest : public ClientLibraryTestBase {
XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>({});
- builder.Slice(original, {0}, {0}, {1});
+ builder.Slice(original, {0}, {0});
ComputeAndCompareR1<float>(&builder, {}, {});
}
@@ -64,7 +64,7 @@ XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
ComputationBuilder builder(client_, TestName());
std::vector<float> constant(10, 0.3);
auto original = builder.ConstantR1<float>(constant);
- builder.Slice(original, {7}, {7}, {1});
+ builder.Slice(original, {7}, {7});
ComputeAndCompareR1<float>(&builder, {}, {});
}
@@ -87,7 +87,7 @@ TEST_F(SliceTest, SliceTenToTen) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {0}, {10}, {1});
+ builder.Slice(original, {0}, {10});
ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
}
@@ -98,7 +98,7 @@ TEST_F(SliceTest, SliceLastFourOf1024) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {1024 - 4}, {1024}, {1});
+ builder.Slice(original, {1024 - 4}, {1024});
const std::vector<float> expected = {1020, 1021, 1022, 1023};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
@@ -112,7 +112,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR1<float>(values);
- builder.Slice(original, {7}, {7 + 1024}, {1});
+ builder.Slice(original, {7}, {7 + 1024});
std::vector<float> expected(1024);
std::iota(values.begin(), values.end(), 7.0);
@@ -122,7 +122,7 @@ TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
- builder.Slice(original, {0, 0}, {0, 0}, {1, 1});
+ builder.Slice(original, {0, 0}, {0, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
}
@@ -130,7 +130,7 @@ XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
- builder.Slice(original, {0, 15}, {0, 20}, {1, 1});
+ builder.Slice(original, {0, 15}, {0, 20});
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
}
@@ -138,7 +138,7 @@ XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
- builder.Slice(original, {1, 0}, {3, 0}, {1, 1});
+ builder.Slice(original, {1, 0}, {3, 0});
ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
}
@@ -153,7 +153,7 @@ XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {128, 128}, {256, 256}, {1, 1});
+ builder.Slice(original, {128, 128}, {256, 256});
Array2D<float> expected(128, 128);
for (int row = 0; row < 128; ++row) {
@@ -171,7 +171,7 @@ TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 3072}, {1, 4096}, {1, 1});
+ builder.Slice(original, {0, 3072}, {1, 4096});
Array2D<float> expected(1, 1024);
std::iota(expected.data(), expected.data() + 1024, 3072.0);
@@ -192,7 +192,7 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
}
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR2FromArray2D<float>(values);
- builder.Slice(original, {0, 0}, {16, 2}, {1, 1});
+ builder.Slice(original, {0, 0}, {16, 2});
ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
}
@@ -204,7 +204,7 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR4FromArray4D(values);
- builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
+ builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128});
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}
@@ -213,7 +213,6 @@ struct R2Spec {
int64 input_dim1;
std::array<int64, 2> slice_starts;
std::array<int64, 2> slice_limits;
- std::array<int64, 2> slice_strides;
Layout layout;
};
@@ -229,7 +228,7 @@ TEST_P(SliceR2Test, DoIt) {
ComputationBuilder builder(client_, TestName());
auto a = builder.ConstantR2FromArray2D<int32>(input);
- builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
+ builder.Slice(a, spec.slice_starts, spec.slice_limits);
std::unique_ptr<Array2D<int32>> expected =
ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
@@ -240,23 +239,19 @@ TEST_P(SliceR2Test, DoIt) {
INSTANTIATE_TEST_CASE_P(
SliceR2TestInstantiation, SliceR2Test,
::testing::Values(
- R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
- LayoutUtil::MakeLayout({0, 1})},
- R2Spec {4, 12, {{0, 3}}, {{4, 6}}, {{1, 1}},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {256, 400, {{0, 300}}, {{256, 400}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
- LayoutUtil::MakeLayout({0, 1})},
- R2Spec {16, 4, {{0, 2}}, {{16, 4}}, {{1, 1}},
+ R2Spec {500, 400, {{111, 123}}, {{300, 257}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {256, 400, {{0, 300}}, {{256, 400}}, {{1, 1}},
+ R2Spec {500, 400, {{111, 123}}, {{300, 400}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {500, 400, {{111, 123}}, {{300, 257}}, {{1, 1}},
+ R2Spec {384, 512, {{128, 256}}, {{256, 384}},
LayoutUtil::MakeLayout({1, 0})},
- R2Spec {500, 400, {{111, 123}}, {{300, 400}}, {{1, 1}},
- LayoutUtil::MakeLayout({1, 0})},
- R2Spec {384, 512, {{128, 256}}, {{256, 384}}, {{1, 1}},
- LayoutUtil::MakeLayout({1, 0})},
- R2Spec {357, 512, {{111, 256}}, {{301, 384}}, {{1, 1}},
+ R2Spec {357, 512, {{111, 256}}, {{301, 384}},
LayoutUtil::MakeLayout({1, 0})}
)
);