aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-12 18:48:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-12 18:53:46 -0700
commit3cd6bdef5fa44efbf2b16eeb5fe026be839e6898 (patch)
tree6ade03925520c58532541f02abd211772f380d4c
parent46a81b5c3490c8ff21521d7541a860491d27baa8 (diff)
Added test cases on R4 slice.
PiperOrigin-RevId: 168482049
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc186
1 files changed, 168 insertions, 18 deletions
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 5da6104cfa..3bf0f411a8 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -25,12 +25,16 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace {
+using ::tensorflow::str_util::Join;
+using ::tensorflow::strings::StrCat;
+
class SliceTest : public ClientLibraryTestBase {};
TEST_F(SliceTest, Slice3x3x3_To_3x3x1_F32) {
@@ -161,6 +165,20 @@ TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
}
+XLA_TEST_F(SliceTest, StridedSliceR4WithOutputLayout) {
+ Array4D<float> values(2, 4, 6, 8);
+ values.FillRandom(3.14f);
+ auto expected = ReferenceUtil::Slice4D(values, {{0, 0, 0, 0}}, {{2, 4, 6, 8}},
+ /*strides=*/{{1, 1, 2, 1}});
+ auto expected_literal = Literal::CreateR4FromArray4DWithLayout(
+ *expected, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR4FromArray4D(values);
+ builder.Slice(original, {0, 0, 0, 0}, {2, 4, 6, 8}, {1, 1, 2, 1});
+ ComputeAndCompareLiteral(&builder, *expected_literal, {}, ErrorSpec(0.000001),
+ &expected_literal->shape());
+}
+
struct R1Spec {
int64 input_dim0;
int64 slice_start;
@@ -193,29 +211,17 @@ class SliceR1Test : public ClientLibraryTestBase,
}
};
-XLA_TEST_P(SliceR1Test, DoIt_F32) {
- Run<float>(GetParam());
-}
+XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }
-XLA_TEST_P(SliceR1Test, DoIt_F64) {
- Run<double>(GetParam());
-}
+XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); }
-XLA_TEST_P(SliceR1Test, DoIt_U32) {
- Run<uint32>(GetParam());
-}
+XLA_TEST_P(SliceR1Test, DoIt_U32) { Run<uint32>(GetParam()); }
-XLA_TEST_P(SliceR1Test, DoIt_S32) {
- Run<int32>(GetParam());
-}
+XLA_TEST_P(SliceR1Test, DoIt_S32) { Run<int32>(GetParam()); }
-XLA_TEST_P(SliceR1Test, DoIt_U64) {
- Run<uint64>(GetParam());
-}
+XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
-XLA_TEST_P(SliceR1Test, DoIt_S64) {
- Run<int64>(GetParam());
-}
+XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
INSTANTIATE_TEST_CASE_P( //
SliceR1TestInstantiation, //
@@ -306,5 +312,149 @@ INSTANTIATE_TEST_CASE_P(
);
// clang-format on
+struct R4Spec {
+ std::array<int64, 4> input_dims;
+ std::array<int64, 4> input_layout; // minor-to-major
+ std::array<int64, 4> slice_starts;
+ std::array<int64, 4> slice_limits;
+ std::array<int64, 4> slice_strides;
+};
+
+string R4SpecToString(const ::testing::TestParamInfo<R4Spec>& data) {
+ const R4Spec& spec = data.param;
+ return StrCat( //
+ "input_", Join(spec.input_dims, "x"), //
+ "__layout_", Join(spec.input_layout, ""), //
+ "__starts_", Join(spec.slice_starts, "x"), //
+ "__limits_", Join(spec.slice_limits, "x"), //
+ "__strides_", Join(spec.slice_strides, "x") //
+ );
+}
+
+class SliceR4Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<R4Spec> {
+ protected:
+ void Run(const R4Spec& spec) {
+ Array4D<float> values(spec.input_dims[0], spec.input_dims[1],
+ spec.input_dims[2], spec.input_dims[3]);
+ values.FillRandom(3.14f);
+ auto expected = ReferenceUtil::Slice4D(
+ values, spec.slice_starts, spec.slice_limits, spec.slice_strides);
+ ComputationBuilder builder(client_, TestName());
+ auto literal = Literal::CreateR4FromArray4DWithLayout(
+ values, LayoutUtil::MakeLayout(spec.input_layout));
+ auto parameter = builder.Parameter(0, literal->shape(), "p0");
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> arg,
+ client_->TransferToServer(*literal));
+ builder.Slice(parameter, spec.slice_starts, spec.slice_limits,
+ spec.slice_strides);
+ ComputeAndCompareR4(&builder, *expected, {arg.get()}, ErrorSpec(0.000001));
+ }
+};
+
+XLA_TEST_P(SliceR4Test, DoIt) { Run(GetParam()); }
+
+const R4Spec kR4SpecValues[] = {
+ R4Spec{{{2, 2, 2, 2}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 0, 0}},
+ {{0, 0, 0, 0}},
+ {{1, 1, 1, 1}}}, //
+ R4Spec{{{3, 3, 4, 4}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 0, 0}},
+ {{3, 3, 4, 4}},
+ {{1, 1, 2, 1}}}, //
+ R4Spec{{{2, 3, 16, 4}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 0, 0}},
+ {{2, 3, 16, 4}},
+ {{1, 1, 3, 1}}}, //
+ // stride > 1 should be on the second-to-last dimension.
+ R4Spec{{{4, 16, 3, 2}},
+ {{0, 1, 2, 3}},
+ {{1, 4, 1, 1}},
+ {{3, 12, 3, 2}},
+ {{1, 1, 3, 1}}}, //
+ R4Spec{{{2, 2, 257, 129}},
+ {{3, 2, 1, 0}},
+ {{1, 1, 62, 64}},
+ {{2, 2, 195, 129}},
+ {{1, 1, 3, 1}}}, //
+ R4Spec{{{3, 5, 257, 129}},
+ {{3, 2, 1, 0}},
+ {{1, 2, 61, 64}},
+ {{3, 5, 199, 129}},
+ {{1, 1, 3, 1}}}, //
+ R4Spec{{{5, 8, 257, 129}},
+ {{3, 2, 1, 0}},
+ {{2, 3, 60, 64}},
+ {{3, 5, 200, 68}},
+ {{1, 1, 1, 1}}}, //
+ R4Spec{{{2, 2, 256, 130}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 60, 127}},
+ {{2, 2, 166, 129}},
+ {{1, 1, 3, 1}}}, //
+ R4Spec{{{2, 4, 8, 4}},
+ {{3, 2, 1, 0}},
+ {{1, 2, 0, 1}},
+ {{2, 4, 8, 3}},
+ {{1, 1, 7, 1}}}, //
+ R4Spec{{{2, 4, 256, 130}},
+ {{3, 2, 1, 0}},
+ {{1, 2, 9, 127}},
+ {{2, 4, 82, 129}},
+ {{1, 1, 7, 1}}}, //
+ R4Spec{{{2, 4, 256, 130}},
+ {{3, 2, 1, 0}},
+ {{1, 2, 19, 127}},
+ {{2, 4, 89, 129}},
+ {{1, 1, 7, 1}}}, //
+ R4Spec{{{2, 4, 256, 130}},
+ {{3, 2, 1, 0}},
+ {{1, 2, 29, 127}},
+ {{2, 4, 159, 129}},
+ {{1, 1, 7, 1}}}, //
+ R4Spec{{{2, 4, 256, 130}},
+ {{3, 2, 1, 0}},
+ {{1, 2, 39, 127}},
+ {{2, 4, 158, 129}},
+ {{1, 1, 7, 1}}}, //
+ R4Spec{{{1, 1, 5, 512}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 0, 0}},
+ {{1, 1, 5, 512}},
+ {{1, 1, 4, 1}}}, //
+ R4Spec{{{1, 1, 513, 512}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 0, 0}},
+ {{1, 1, 513, 512}},
+ {{1, 1, 512, 1}}}, //
+ R4Spec{{{1, 1, 1024, 4}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 15, 0}},
+ {{1, 1, 1022, 4}},
+ {{1, 1, 23, 1}}}, //
+ R4Spec{{{1, 1, 1024, 4}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 14, 0}},
+ {{1, 1, 1023, 4}},
+ {{1, 1, 101, 1}}}, //
+ R4Spec{{{2, 2, 512, 1024}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 0, 0}},
+ {{2, 2, 512, 1024}},
+ {{1, 1, 2, 1}}}, //
+ R4Spec{{{1, 1, 14, 2048}},
+ {{3, 2, 1, 0}},
+ {{0, 0, 2, 0}},
+ {{1, 1, 14, 2}},
+ {{1, 1, 1, 1}}}, //
+};
+
+INSTANTIATE_TEST_CASE_P(SliceR4TestInstantiation, SliceR4Test,
+ ::testing::ValuesIn(kR4SpecValues), R4SpecToString);
+
} // namespace
} // namespace xla