diff options
author | 2017-07-13 15:45:44 -0700 | |
---|---|---|
committer | 2017-07-13 15:53:35 -0700 | |
commit | 0355763a32660b2112f155c752f6bf244fcf7007 (patch) | |
tree | 94f0f80166ffe28b3020d3e6fef0324fa38dc37f | |
parent | a24f96a3f70e0cc0792410673f55c4a898dea604 (diff) |
Change ReferenceUtil::Slice{2,3,4}D to accept a strides parameter.
This also fixes their ordering in xla/reference_util.h
This also adds a few stride tests to reference_util_test and
adds LiteralTestUtil::ExpectR4Near().
PiperOrigin-RevId: 161876759
-rw-r--r-- | tensorflow/compiler/xla/reference_util.h | 68 | ||||
-rw-r--r-- | tensorflow/compiler/xla/reference_util_test.cc | 69 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/literal_test_util.h | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/slice_test.cc | 8 |
4 files changed, 133 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 1d326aff5f..41ef1f96ab 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" @@ -301,48 +302,56 @@ class ReferenceUtil { return result; } - // Slices the input array given starting indices in each dimension and limit - // indices in each dimension. + // Slices the input array given starting indices, limit indices, and strides + // in each dimension. template <typename T> static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input, std::array<int64, 2> starts, - std::array<int64, 2> limits) { + std::array<int64, 2> limits, + std::array<int64, 2> strides) { CHECK_LE(starts[0], input.n1()); CHECK_LE(starts[1], input.n2()); CHECK_LE(limits[0], input.n1()); CHECK_LE(limits[1], input.n2()); + CHECK_GE(strides[0], 1); + CHECK_GE(strides[1], 1); auto result = - MakeUnique<Array2D<T>>(limits[0] - starts[0], limits[1] - starts[1]); + MakeUnique<Array2D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { - (*result)(i0, i1) = input(starts[0] + i0, starts[1] + i1); + (*result)(i0, i1) = + input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]); } } return result; } template <typename T> - static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input, - std::array<int64, 4> starts, - std::array<int64, 4> limits) { + static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input, + std::array<int64, 3> starts, + std::array<int64, 3> limits, + std::array<int64, 3> strides) { CHECK_LE(starts[0], input.n1()); CHECK_LE(starts[1], input.n2()); CHECK_LE(starts[2], input.n3()); - CHECK_LE(starts[3], input.n4()); CHECK_LE(limits[0], input.n1()); CHECK_LE(limits[1], input.n2()); CHECK_LE(limits[2], input.n3()); - CHECK_LE(limits[3], input.n4()); + CHECK_GE(strides[0], 1); + CHECK_GE(strides[1], 1); + CHECK_GE(strides[2], 1); auto result = - MakeUnique<Array4D<T>>(limits[0] - starts[0], limits[1] - starts[1], - limits[2] - starts[2], limits[3] - starts[3]); + MakeUnique<Array3D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2])); + for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { - for (int64 i3 = 0; i3 < result->n4(); ++i3) { - (*result)(i0, i1, i2, i3) = input(starts[0] + i0, starts[1] + i1, - starts[2] + i2, starts[3] + i3); - } + (*result)(i0, i1, i2) = + input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1], + starts[2] + i2 * strides[2]); } } } @@ -350,22 +359,35 @@ class ReferenceUtil { } template <typename T> - static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input, - std::array<int64, 3> starts, - std::array<int64, 3> limits) { + static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input, + std::array<int64, 4> starts, + std::array<int64, 4> limits, + std::array<int64, 4> strides) { CHECK_LE(starts[0], input.n1()); CHECK_LE(starts[1], input.n2()); CHECK_LE(starts[2], input.n3()); + CHECK_LE(starts[3], input.n4()); CHECK_LE(limits[0], input.n1()); CHECK_LE(limits[1], input.n2()); CHECK_LE(limits[2], input.n3()); - auto result = MakeUnique<Array3D<T>>( - limits[0] - starts[0], limits[1] - starts[1], limits[2] - starts[2]); + CHECK_LE(limits[3], input.n4()); + CHECK_GE(strides[0], 1); + CHECK_GE(strides[1], 1); + CHECK_GE(strides[2], 1); + CHECK_GE(strides[3], 1); + auto result = + MakeUnique<Array4D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]), + CeilOfRatio(limits[1] - starts[1], strides[1]), + CeilOfRatio(limits[2] - starts[2], strides[2]), + CeilOfRatio(limits[3] - starts[3], strides[3])); for (int64 i0 = 0; i0 < result->n1(); ++i0) { for (int64 i1 = 0; i1 < result->n2(); ++i1) { for (int64 i2 = 0; i2 < result->n3(); ++i2) { - (*result)(i0, i1, i2) = - input(starts[0] + i0, starts[1] + i1, starts[2] + i2); + for (int64 i3 = 0; i3 < result->n4(); ++i3) { + (*result)(i0, i1, i2, i3) = + input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1], + starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]); + } } } } diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 215f220258..680e908d46 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -132,6 +132,75 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, SliceArray2D) { + auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}}); + auto actual_literal = Literal::CreateR2FromArray2D(*result); + + LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceStridedArray2D) { + auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}}); + auto actual_literal = Literal::CreateR2FromArray2D(*result); + + LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceArray3D) { + Array3D<float> input(2, 3, 4); + input.FillIota(0); + + auto result = + ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}}); + auto actual_literal = Literal::CreateR3FromArray3D(*result); + + LiteralTestUtil::ExpectR3Near<float>( + {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal, + ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceStridedArray3D) { + Array3D<float> input(2, 3, 4); + input.FillIota(0); + + auto result = + ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}}); + auto actual_literal = Literal::CreateR3FromArray3D(*result); + + LiteralTestUtil::ExpectR3Near<float>( + {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceArray4D) { + Array4D<float> input(2, 3, 4, 5); + input.FillIota(0); + + auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}}, + {{1, 1, 1, 1}}); + auto actual_literal = Literal::CreateR4FromArray4D(*result); + + LiteralTestUtil::ExpectR4Near<float>( + {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}}, + *actual_literal, ErrorSpec(0.0001)); +} + +TEST_F(ReferenceUtilTest, SliceStridedArray4D) { + Array4D<float> input(2, 3, 4, 5); + input.FillIota(0); + + auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}}, + {{1, 2, 2, 2}}); + auto actual_literal = Literal::CreateR4FromArray4D(*result); + + LiteralTestUtil::ExpectR4Near<float>( + {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}}, + {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}}, + *actual_literal, ErrorSpec(0.0001)); +} + TEST_F(ReferenceUtilTest, ConvWithSamePadding) { Array4D<float> input(1, 1, 4, 4); // clang-format off diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index b7f1e8603f..f645c4e8dc 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -131,6 +131,12 @@ class LiteralTestUtil { std::initializer_list<std::initializer_list<NativeT>>> expected, const Literal& actual, const ErrorSpec& error); + template <typename NativeT> + static void ExpectR4Near( + std::initializer_list<std::initializer_list< + std::initializer_list<std::initializer_list<NativeT>>>> + expected, + const Literal& actual, const ErrorSpec& error); // Asserts the given literal are within the given error bound to the given // array. Only supported for floating point values. @@ -283,6 +289,15 @@ template <typename NativeT> } template <typename NativeT> +/* static */ void LiteralTestUtil::ExpectR4Near( + std::initializer_list<std::initializer_list< + std::initializer_list<std::initializer_list<NativeT>>>> + expected, + const Literal& actual, const ErrorSpec& error) { + ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error); +} + +template <typename NativeT> /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D<NativeT>& expected, const Literal& actual, const ErrorSpec& error) { diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 5e7d475662..35e2f216d7 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -200,8 +200,8 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) { TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) { Array4D<float> values(2, 2, 24, 256); values.FillRandom(3.14f); - auto expected = - ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}); + auto expected = ReferenceUtil::Slice4D( + values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}}); ComputationBuilder builder(client_, TestName()); auto original = builder.ConstantR4FromArray4D(values); builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1}); @@ -231,8 +231,8 @@ TEST_P(SliceR2Test, DoIt) { auto a = builder.ConstantR2FromArray2D<int32>(input); builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides); - std::unique_ptr<Array2D<int32>> expected = - ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits); + std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D( + input, spec.slice_starts, spec.slice_limits, spec.slice_strides); ComputeAndCompareR2<int32>(&builder, *expected, {}); } |