aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-13 15:45:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-13 15:53:35 -0700
commit0355763a32660b2112f155c752f6bf244fcf7007 (patch)
tree94f0f80166ffe28b3020d3e6fef0324fa38dc37f
parenta24f96a3f70e0cc0792410673f55c4a898dea604 (diff)
Change ReferenceUtil::Slice{2,3,4}D to accept a strides parameter.
This also fixes their ordering in xla/reference_util.h This also adds a few stride tests to reference_util_test and adds LiteralTestUtil::ExpectR4Near(). PiperOrigin-RevId: 161876759
-rw-r--r--tensorflow/compiler/xla/reference_util.h68
-rw-r--r--tensorflow/compiler/xla/reference_util_test.cc69
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h15
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc8
4 files changed, 133 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 1d326aff5f..41ef1f96ab 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
@@ -301,48 +302,56 @@ class ReferenceUtil {
return result;
}
- // Slices the input array given starting indices in each dimension and limit
- // indices in each dimension.
+ // Slices the input array given starting indices, limit indices, and strides
+ // in each dimension.
template <typename T>
static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
std::array<int64, 2> starts,
- std::array<int64, 2> limits) {
+ std::array<int64, 2> limits,
+ std::array<int64, 2> strides) {
CHECK_LE(starts[0], input.n1());
CHECK_LE(starts[1], input.n2());
CHECK_LE(limits[0], input.n1());
CHECK_LE(limits[1], input.n2());
+ CHECK_GE(strides[0], 1);
+ CHECK_GE(strides[1], 1);
auto result =
- MakeUnique<Array2D<T>>(limits[0] - starts[0], limits[1] - starts[1]);
+ MakeUnique<Array2D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
- (*result)(i0, i1) = input(starts[0] + i0, starts[1] + i1);
+ (*result)(i0, i1) =
+ input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]);
}
}
return result;
}
template <typename T>
- static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
- std::array<int64, 4> starts,
- std::array<int64, 4> limits) {
+ static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
+ std::array<int64, 3> starts,
+ std::array<int64, 3> limits,
+ std::array<int64, 3> strides) {
CHECK_LE(starts[0], input.n1());
CHECK_LE(starts[1], input.n2());
CHECK_LE(starts[2], input.n3());
- CHECK_LE(starts[3], input.n4());
CHECK_LE(limits[0], input.n1());
CHECK_LE(limits[1], input.n2());
CHECK_LE(limits[2], input.n3());
- CHECK_LE(limits[3], input.n4());
+ CHECK_GE(strides[0], 1);
+ CHECK_GE(strides[1], 1);
+ CHECK_GE(strides[2], 1);
auto result =
- MakeUnique<Array4D<T>>(limits[0] - starts[0], limits[1] - starts[1],
- limits[2] - starts[2], limits[3] - starts[3]);
+ MakeUnique<Array3D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]),
+ CeilOfRatio(limits[2] - starts[2], strides[2]));
+
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
- for (int64 i3 = 0; i3 < result->n4(); ++i3) {
- (*result)(i0, i1, i2, i3) = input(starts[0] + i0, starts[1] + i1,
- starts[2] + i2, starts[3] + i3);
- }
+ (*result)(i0, i1, i2) =
+ input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
+ starts[2] + i2 * strides[2]);
}
}
}
@@ -350,22 +359,35 @@ class ReferenceUtil {
}
template <typename T>
- static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
- std::array<int64, 3> starts,
- std::array<int64, 3> limits) {
+ static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
+ std::array<int64, 4> starts,
+ std::array<int64, 4> limits,
+ std::array<int64, 4> strides) {
CHECK_LE(starts[0], input.n1());
CHECK_LE(starts[1], input.n2());
CHECK_LE(starts[2], input.n3());
+ CHECK_LE(starts[3], input.n4());
CHECK_LE(limits[0], input.n1());
CHECK_LE(limits[1], input.n2());
CHECK_LE(limits[2], input.n3());
- auto result = MakeUnique<Array3D<T>>(
- limits[0] - starts[0], limits[1] - starts[1], limits[2] - starts[2]);
+ CHECK_LE(limits[3], input.n4());
+ CHECK_GE(strides[0], 1);
+ CHECK_GE(strides[1], 1);
+ CHECK_GE(strides[2], 1);
+ CHECK_GE(strides[3], 1);
+ auto result =
+ MakeUnique<Array4D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
+ CeilOfRatio(limits[1] - starts[1], strides[1]),
+ CeilOfRatio(limits[2] - starts[2], strides[2]),
+ CeilOfRatio(limits[3] - starts[3], strides[3]));
for (int64 i0 = 0; i0 < result->n1(); ++i0) {
for (int64 i1 = 0; i1 < result->n2(); ++i1) {
for (int64 i2 = 0; i2 < result->n3(); ++i2) {
- (*result)(i0, i1, i2) =
- input(starts[0] + i0, starts[1] + i1, starts[2] + i2);
+ for (int64 i3 = 0; i3 < result->n4(); ++i3) {
+ (*result)(i0, i1, i2, i3) =
+ input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
+ starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]);
+ }
}
}
}
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
index 215f220258..680e908d46 100644
--- a/tensorflow/compiler/xla/reference_util_test.cc
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -132,6 +132,75 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
ErrorSpec(0.0001));
}
+TEST_F(ReferenceUtilTest, SliceArray2D) {
+ auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 2}}, {{1, 1}});
+ auto actual_literal = Literal::CreateR2FromArray2D(*result);
+
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 2.f}, {4.f, 5.f}},
+ *actual_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, SliceStridedArray2D) {
+ auto result = ReferenceUtil::Slice2D(*matrix_, {{0, 0}}, {{2, 3}}, {{1, 2}});
+ auto actual_literal = Literal::CreateR2FromArray2D(*result);
+
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f}, {4.f, 6.f}},
+ *actual_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, SliceArray3D) {
+ Array3D<float> input(2, 3, 4);
+ input.FillIota(0);
+
+ auto result =
+ ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 2, 2}}, {{1, 1, 1}});
+ auto actual_literal = Literal::CreateR3FromArray3D(*result);
+
+ LiteralTestUtil::ExpectR3Near<float>(
+ {{{0.f, 1.f}, {4.f, 5.f}}, {{12.f, 13.f}, {16.f, 17.f}}}, *actual_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, SliceStridedArray3D) {
+ Array3D<float> input(2, 3, 4);
+ input.FillIota(0);
+
+ auto result =
+ ReferenceUtil::Slice3D(input, {{0, 0, 0}}, {{2, 3, 4}}, {{1, 2, 2}});
+ auto actual_literal = Literal::CreateR3FromArray3D(*result);
+
+ LiteralTestUtil::ExpectR3Near<float>(
+ {{{0.f, 2.f}, {8.f, 10.f}}, {{12.f, 14.f}, {20.f, 22.f}}},
+ *actual_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, SliceArray4D) {
+ Array4D<float> input(2, 3, 4, 5);
+ input.FillIota(0);
+
+ auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 2, 2, 2}},
+ {{1, 1, 1, 1}});
+ auto actual_literal = Literal::CreateR4FromArray4D(*result);
+
+ LiteralTestUtil::ExpectR4Near<float>(
+ {{{{60.f, 61.f}, {65.f, 66.f}}, {{80.f, 81.f}, {85.f, 86.f}}}},
+ *actual_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, SliceStridedArray4D) {
+ Array4D<float> input(2, 3, 4, 5);
+ input.FillIota(0);
+
+ auto result = ReferenceUtil::Slice4D(input, {{1, 0, 0, 0}}, {{2, 3, 4, 5}},
+ {{1, 2, 2, 2}});
+ auto actual_literal = Literal::CreateR4FromArray4D(*result);
+
+ LiteralTestUtil::ExpectR4Near<float>(
+ {{{{60.f, 62.f, 64.f}, {70.f, 72.f, 74.f}},
+ {{100.f, 102.f, 104.f}, {110.f, 112.f, 114.f}}}},
+ *actual_literal, ErrorSpec(0.0001));
+}
+
TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
Array4D<float> input(1, 1, 4, 4);
// clang-format off
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
index b7f1e8603f..f645c4e8dc 100644
--- a/tensorflow/compiler/xla/tests/literal_test_util.h
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -131,6 +131,12 @@ class LiteralTestUtil {
std::initializer_list<std::initializer_list<NativeT>>>
expected,
const Literal& actual, const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR4Near(
+ std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>>
+ expected,
+ const Literal& actual, const ErrorSpec& error);
// Asserts the given literal are within the given error bound to the given
// array. Only supported for floating point values.
@@ -283,6 +289,15 @@ template <typename NativeT>
}
template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR4Near(
+ std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>>
+ expected,
+ const Literal& actual, const ErrorSpec& error) {
+ ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
const Array2D<NativeT>& expected, const Literal& actual,
const ErrorSpec& error) {
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index 5e7d475662..35e2f216d7 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -200,8 +200,8 @@ TEST_F(SliceTest, Slice_16x4_To_16x2) {
TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
Array4D<float> values(2, 2, 24, 256);
values.FillRandom(3.14f);
- auto expected =
- ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
+ auto expected = ReferenceUtil::Slice4D(
+ values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}}, /*strides=*/{{1, 1, 1, 1}});
ComputationBuilder builder(client_, TestName());
auto original = builder.ConstantR4FromArray4D(values);
builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128}, {1, 1, 1, 1});
@@ -231,8 +231,8 @@ TEST_P(SliceR2Test, DoIt) {
auto a = builder.ConstantR2FromArray2D<int32>(input);
builder.Slice(a, spec.slice_starts, spec.slice_limits, spec.slice_strides);
- std::unique_ptr<Array2D<int32>> expected =
- ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
+ std::unique_ptr<Array2D<int32>> expected = ReferenceUtil::Slice2D(
+ input, spec.slice_starts, spec.slice_limits, spec.slice_strides);
ComputeAndCompareR2<int32>(&builder, *expected, {});
}