aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yunxing Dai <yunxing@google.com>2018-01-24 17:24:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-24 17:28:18 -0800
commit0412e0946bdd2765d5c3dba0cc9b12b8650f564a (patch)
tree7a1d932129332f1dff3a3e3fa8eab2cbf2907af1
parent7c4f482a851d12a3b0187bbf31db65ff6b7a7ad3 (diff)
Add R3 and R5 tests to select and scatter.
- Added evaluator support so that we can test S&S in arbitrary dimensions. RELNOTES: n/a PiperOrigin-RevId: 183168473
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc189
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1
-rw-r--r--tensorflow/compiler/xla/tests/select_and_scatter_test.cc189
3 files changed, 276 insertions, 103 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 2112cf57c7..e3f5c17e35 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -43,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -166,6 +167,34 @@ StatusOr<std::unique_ptr<Literal>> ElementWiseUnaryOpImpl(
return std::move(result);
}
+// For one particular placement of a window in a base shape (the placement is
+// represented as `window_count_index`), iterates inside the window. Translates
+// the window index into base index. If the base index is within bounds, call `f`
+// with the base index.
+void IterateThroughWindow(
+ const Shape& window_shape, const Window& window, const Shape& base_shape,
+ const tensorflow::gtl::ArraySlice<int64>& window_count_index,
+ const std::function<void(const std::vector<int64>&)>& f) {
+ const int64 rank = ShapeUtil::Rank(base_shape);
+ DimensionVector window_index(rank);
+ std::fill(window_index.begin(), window_index.end(), 0);
+ do {
+ std::vector<int64> base_index(rank);
+ bool out_of_bound = false;
+ for (int64 i = 0; i < rank; ++i) {
+ base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] - window.dimensions(i).padding_low();
+ if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
+ out_of_bound = true;
+ break;
+ }
+ }
+ if (!out_of_bound) {
+ f(base_index);
+ }
+ } while (IndexUtil::BumpIndices(window_shape, &window_index));
+}
+
} // namespace
template <typename ReturnT, typename ElementwiseT>
@@ -1420,6 +1449,111 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
return Status::OK();
}
+ Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
+ auto operand = select_and_scatter->operand(0);
+ auto source = select_and_scatter->operand(1);
+ const Window& window = select_and_scatter->window();
+
+ const Literal& init_literal =
+ parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2));
+ TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape()));
+ auto init_scalar = init_literal.Get<ReturnT>({});
+
+ auto result = Literal::CreateFromShape(select_and_scatter->shape());
+
+ // Initialize result array with the init value.
+ TF_RETURN_IF_ERROR(result->Populate<ReturnT>(
+ [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+ return init_scalar;
+ }));
+
+ std::vector<int64> window_dimension_sizes;
+ for (const auto& window_dimension : window.dimensions()) {
+ window_dimension_sizes.push_back(window_dimension.size());
+ }
+ const Shape window_shape = ShapeUtil::MakeShape(
+ operand->shape().element_type(), window_dimension_sizes);
+
+ HloComputation* select = select_and_scatter->select();
+ HloComputation* scatter = select_and_scatter->scatter();
+
+ const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand);
+ const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source);
+
+ int64 rank = ShapeUtil::Rank(operand_literal.shape());
+
+ HloEvaluator embedded_evaluator;
+ DimensionVector source_index(rank);
+
+ std::fill(source_index.begin(), source_index.end(), 0);
+ do {
+ // For each element in `source`, we place a window in `operand`. For each
+ // window placement, we iterate inside the window twice:
+ //
+ // 1. Find the selected index by applying `select` function to all
+ // elements. E.g., If the `select` function is GreaterEqual, the first
+ // iteration through the window finds the biggest value and returns its
+ // index.
+ //
+ // 2. Using the selected index, scatter value from `source` to result. We
+ // do this by iterating through the window, and compare each index with
+ // the selected index.
+ tensorflow::gtl::optional<ReturnT> selected_val;
+ tensorflow::gtl::optional<std::vector<int64>> selected_index;
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), source_index,
+ [&](const std::vector<int64>& operand_index) {
+ auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+ if (!selected_val) {
+ selected_val = curr_val;
+ selected_index = operand_index;
+ }
+ const auto curr_val_literal = Literal::CreateR0<ReturnT>(curr_val);
+ const auto selected_val_literal =
+ Literal::CreateR0<ReturnT>(*selected_val);
+
+ const std::vector<const Literal*> args = {
+ curr_val_literal.get(), selected_val_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*select, args)
+ .ConsumeValueOrDie();
+ bool selected = computed_result->Get<bool>({});
+ if (selected) {
+ selected_val = curr_val;
+ selected_index = operand_index;
+ }
+ embedded_evaluator.ResetVisitStates();
+ });
+
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), source_index,
+ [&](const std::vector<int64>& operand_index) {
+ if (std::equal(operand_index.begin(), operand_index.end(),
+ selected_index->begin())) {
+ auto source = source_literal.Get<ReturnT>(source_index);
+ auto scattered = result->Get<ReturnT>(operand_index);
+ const auto source_literal = Literal::CreateR0<ReturnT>(source);
+ const auto scattered_literal =
+ Literal::CreateR0<ReturnT>(scattered);
+
+ const std::vector<const Literal*> args = {
+ source_literal.get(), scattered_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*scatter, args)
+ .ConsumeValueOrDie();
+ result->Set(operand_index, computed_result->Get<ReturnT>({}));
+              // Clear visit states so that we can use the evaluator again
+ // on the same computation.
+ embedded_evaluator.ResetVisitStates();
+ }
+ });
+ } while (IndexUtil::BumpIndices(source->shape(), &source_index));
+
+ parent_->evaluated_[select_and_scatter] = std::move(result);
+ return Status::OK();
+ }
+
Status HandleReduceWindow(HloInstruction* reduce_window) override {
auto operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
@@ -1468,39 +1602,28 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
std::fill(window_index.begin(), window_index.end(), 0);
std::fill(operand_index.begin(), operand_index.end(), 0);
- do {
- bool out_of_bound = false;
- for (int i = 0; i < operand_index.size(); ++i) {
- operand_index[i] =
- output_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
- if (operand_index[i] < 0 ||
- operand_index[i] >= operand_literal.shape().dimensions(i)) {
- out_of_bound = true;
- break;
- }
- }
- if (!out_of_bound) {
- auto curr_val = operand_literal.Get<ReturnT>(operand_index);
-
- // Evaluate computation with specified literal operands.
- const auto curr_val_literal =
- Literal::CreateR0<ReturnT>(curr_val);
- const auto result_val_literal =
- Literal::CreateR0<ReturnT>(result_val);
- const std::vector<const Literal*> args = {
- curr_val_literal.get(), result_val_literal.get()};
- std::unique_ptr<Literal> computed_result =
- embedded_evaluator.Evaluate<const Literal*>(*function, args)
- .ConsumeValueOrDie();
-
- // Clear visit states so that the we can use the evaluate again on
- // the same computation.
- embedded_evaluator.ResetVisitStates();
-
- result_val = computed_result->Get<ReturnT>({});
- }
- } while (IndexUtil::BumpIndices(window_shape, &window_index));
+ IterateThroughWindow(
+ window_shape, window, operand_literal.shape(), output_index,
+ [&](const std::vector<int64>& operand_index) {
+ auto curr_val = operand_literal.Get<ReturnT>(operand_index);
+
+ // Evaluate computation with specified literal operands.
+ const auto curr_val_literal =
+ Literal::CreateR0<ReturnT>(curr_val);
+ const auto result_val_literal =
+ Literal::CreateR0<ReturnT>(result_val);
+ const std::vector<const Literal*> args = {
+ curr_val_literal.get(), result_val_literal.get()};
+ std::unique_ptr<Literal> computed_result =
+ embedded_evaluator.Evaluate<const Literal*>(*function, args)
+ .ConsumeValueOrDie();
+
+                  // Clear visit states so that we can use the evaluator again
+ // on the same computation.
+ embedded_evaluator.ResetVisitStates();
+
+ result_val = computed_result->Get<ReturnT>({});
+ });
return result_val;
}));
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index bc15bd9593..3afd52b6b2 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1034,6 +1034,7 @@ xla_test(
name = "select_and_scatter_test",
timeout = "long",
srcs = ["select_and_scatter_test.cc"],
+ tags = ["enable_for_xla_interpreter"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal_util",
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
index 62ff349e9c..9ee94b8571 100644
--- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -39,8 +39,8 @@ namespace xla {
namespace {
struct SelectAndScatterTestParam {
- Array4D<float> operand_shape;
- Array4D<float> source_shape;
+ std::vector<int64> operand_shape;
+ std::vector<int64> source_shape;
Padding padding_type;
tensorflow::gtl::ArraySlice<int64> window_dimensions;
tensorflow::gtl::ArraySlice<int64> window_strides;
@@ -69,83 +69,132 @@ class SelectAndScatterTest
Computation min_f32_;
};
-XLA_TEST_P(SelectAndScatterTest, R4Randomized) {
- Array4D<float> o(GetParam().operand_shape);
+XLA_TEST_P(SelectAndScatterTest, ParamTest) {
+ auto operand_shape = GetParam().operand_shape;
+ Array<float> o(operand_shape);
o.FillRandom(1.5f);
- auto operand = builder_.ConstantR4FromArray4D(o);
+ auto operand = builder_.ConstantFromArray(o);
- Array4D<float> s(GetParam().source_shape);
+ auto source_shape = GetParam().source_shape;
+ Array<float> s(source_shape);
s.FillRandom(12.0f);
- auto source = builder_.ConstantR4FromArray4D(s);
-
- builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions,
- GetParam().window_strides, GetParam().padding_type,
- source, builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto source = builder_.ConstantFromArray(s);
- auto e = ReferenceUtil::SelectAndScatter4DGePlus(
- o, s, 0.0f, GetParam().window_dimensions, GetParam().window_strides,
- GetParam().padding_type == Padding::kSame);
+ auto select_and_scatter = builder_.SelectAndScatter(
+ operand, ge_f32_, GetParam().window_dimensions, GetParam().window_strides,
+ GetParam().padding_type, source, builder_.ConstantR0<float>(0.0f),
+ add_f32_);
- ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-5));
+ ComputeAndCompare(&builder_, select_and_scatter, {}, ErrorSpec(1e-5));
}
INSTANTIATE_TEST_CASE_P(
SelectAndScatterTest_Instantiation, SelectAndScatterTest,
- ::testing::Values(SelectAndScatterTestParam{{6, 6, 256, 128},
- {3, 3, 256, 128},
- Padding::kSame,
- {3, 3, 1, 1},
- {2, 2, 1, 1}},
- SelectAndScatterTestParam{{7, 7, 256, 128},
- {3, 3, 256, 128},
- Padding::kValid,
- {3, 3, 1, 1},
- {2, 2, 1, 1}},
- SelectAndScatterTestParam{{6, 7, 256, 128},
- {3, 3, 256, 128},
- Padding::kValid,
- {2, 3, 1, 1},
- {2, 2, 1, 1}},
- SelectAndScatterTestParam{{6, 7, 256, 128},
- {2, 3, 256, 128},
- Padding::kValid,
- {2, 3, 1, 1},
- {3, 2, 1, 1}},
- SelectAndScatterTestParam{{9, 9, 16, 128},
- {3, 3, 16, 128},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{3, 3, 4, 4},
- {1, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{3, 3, 4, 4},
- {1, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{9, 3, 4, 4},
- {3, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{7, 3, 4, 4},
- {3, 1, 4, 4},
- Padding::kValid,
- {3, 3, 1, 1},
- {2, 3, 1, 1}},
- SelectAndScatterTestParam{{1, 1, 5, 5},
- {1, 1, 5, 5},
- Padding::kSame,
- {3, 3, 1, 1},
- {3, 3, 1, 1}},
- SelectAndScatterTestParam{{7, 7, 8, 256},
- {4, 4, 8, 256},
- Padding::kSame,
- {2, 2, 1, 1},
- {2, 2, 1, 1}}));
+ ::testing::Values(
+ SelectAndScatterTestParam{{6, 6, 6, 4, 4},
+ {3, 3, 3, 4, 4},
+ Padding::kSame,
+ {3, 3, 3, 1, 1},
+ {2, 2, 2, 1, 1}},
+ SelectAndScatterTestParam{{7, 7, 7, 4, 4},
+ {3, 3, 3, 4, 4},
+ Padding::kValid,
+ {3, 3, 3, 1, 1},
+ {2, 2, 2, 1, 1}},
+
+ SelectAndScatterTestParam{{8, 8, 8, 4, 4},
+ {1, 3, 3, 4, 4},
+ Padding::kValid,
+ {8, 4, 4, 1, 1},
+ {1, 2, 2, 1, 1}},
+ SelectAndScatterTestParam{{6, 6, 256, 128},
+ {3, 3, 256, 128},
+ Padding::kSame,
+ {3, 3, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{{7, 7, 256, 128},
+ {3, 3, 256, 128},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{{6, 7, 256, 128},
+ {3, 3, 256, 128},
+ Padding::kValid,
+ {2, 3, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{{6, 7, 256, 128},
+ {2, 3, 256, 128},
+ Padding::kValid,
+ {2, 3, 1, 1},
+ {3, 2, 1, 1}},
+ SelectAndScatterTestParam{{9, 9, 16, 128},
+ {3, 3, 16, 128},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{3, 3, 4, 4},
+ {1, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{3, 3, 4, 4},
+ {1, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{9, 3, 4, 4},
+ {3, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{7, 3, 4, 4},
+ {3, 1, 4, 4},
+ Padding::kValid,
+ {3, 3, 1, 1},
+ {2, 3, 1, 1}},
+ SelectAndScatterTestParam{{1, 1, 5, 5},
+ {1, 1, 5, 5},
+ Padding::kSame,
+ {3, 3, 1, 1},
+ {3, 3, 1, 1}},
+ SelectAndScatterTestParam{{7, 7, 8, 256},
+ {4, 4, 8, 256},
+ Padding::kSame,
+ {2, 2, 1, 1},
+ {2, 2, 1, 1}},
+ SelectAndScatterTestParam{
+ {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
+ SelectAndScatterTestParam{
+ {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}},
+ SelectAndScatterTestParam{{7, 256, 128},
+ {3, 256, 128},
+ Padding::kValid,
+ {3, 1, 1},
+ {2, 1, 1}},
+ SelectAndScatterTestParam{{6, 256, 128},
+ {3, 256, 128},
+ Padding::kValid,
+ {2, 1, 1},
+ {2, 1, 1}},
+ SelectAndScatterTestParam{{6, 256, 128},
+ {2, 256, 128},
+ Padding::kValid,
+ {2, 1, 1},
+ {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}},
+ SelectAndScatterTestParam{
+ {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}},
+ SelectAndScatterTestParam{
+ {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}}));
// Test for F32 1D array, with a zero-element input.
XLA_TEST_F(SelectAndScatterTest, R1S0F32) {