aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/scatter_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/scatter_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/scatter_test.cc615
1 files changed, 615 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/tests/scatter_test.cc b/tensorflow/compiler/xla/tests/scatter_test.cc
new file mode 100644
index 0000000000..922d70b752
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/scatter_test.cc
@@ -0,0 +1,615 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::gtl::nullopt;
+
+class ScatterTest : public HloTestBase {
+ protected:
+ void RunTest(const string& hlo_text, Literal* operand,
+ Literal* scatter_indices, Literal* updates) {
+ RunTest(hlo_text, {operand, scatter_indices, updates});
+ }
+
+ void RunTest(const string& hlo_text,
+ tensorflow::gtl::ArraySlice<Literal*> args) {
+ HloModuleConfig config;
+ config.set_debug_options(GetDebugOptionsForTest());
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text, config));
+ EXPECT_TRUE(RunAndCompare(std::move(module), args, nullopt));
+ }
+};
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterV1_Update) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatterV1
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterV2_Update) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterV2
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[3,2] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={0},
+ inserted_window_dims={1},
+ scatter_dims_to_operand_dims={1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 30}, {40, 60}, {70, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Add) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_Add
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_Mul) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_Mul
+
+mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=mul_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_F32) {
+ const string hlo_text = R"(
+HloModule TensorFlowScatter_F32
+
+add_f32 (lhs: f32[], rhs: f32[]) -> f32[] {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(f32[] lhs, f32[] rhs)
+}
+
+ENTRY main {
+ operand = f32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = f32[2,3] parameter(2)
+ ROOT scatter = f32[3,3] scatter(operand, indices, updates),
+ to_apply=add_f32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<float>(
+ {{1.1, 2.2, 3.3}, {4.4, 5.5, 6.6}, {7.7, 8.8, 9.9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({2, 1});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<float>({{0.4, 1.1, 0.7}, {2.3, 3.1, 1.6}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_RepeatedIndices) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20, 30}, {70, 80, 90}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatter_MultipleBatchDims) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterMultipleBatchDims
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,3,2] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={1},
+ inserted_window_dims={1},
+ scatter_dims_to_operand_dims={1},
+ index_vector_dim=2
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 2}, {2, 1}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10, 30}, {40, 60}, {70, 90}}, {{5, 5}, {5, 5}, {5, 5}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterNd) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNd
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0,1},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{-10, 10}, {-40, 40}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, TensorFlowScatterNd_NonDefaultIndexVectorDim) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatterNdNonDefaultIndexVectorDim
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3,2] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3,3,2] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0,1},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR3<int32>({{{-1, 1}, {-2, 2}, {-3, 3}}, //
+ {{-4, 4}, {-5, 5}, {-6, 6}}, //
+ {{-7, 7}, {-8, 8}, {-9, 9}}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{0, 0}, {1, 0}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{-10, 10}, {-20, 20}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, DynamicUpdateSlice) {
+ const char* hlo_text = R"(
+HloModule DynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[1,1] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={0,1},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({1, 1});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{10}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, BatchDynamicUpdateSlice) {
+ const char* hlo_text = R"(
+HloModule BatchDynamicUpdateSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2,2] parameter(1)
+ updates = s32[2,1,1] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR2<int32>({{2, 1}, {1, 1}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR3<int32>({{{10}}, {{20}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, ZeroDimBounds) {
+ const char* hlo_text = R"(
+HloModule TensorFlowScatter_ZeroDimBounds
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,0] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,0] parameter(2)
+ ROOT scatter = s32[3,0] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR2<int32>({{}, {}, {}});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR1<int32>({0, 2});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR2<int32>({{}, {}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, NoUpdateWindowDims) {
+ const string hlo_text = R"(
+HloModule Scatter_NoUpdateWindowDims
+
+add_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ rhs = s32[] parameter(1)
+ ROOT add = s32[] add(s32[] lhs, s32[] rhs)
+}
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[2,2,1] parameter(1)
+ updates = s32[2,2] parameter(2)
+ ROOT scatter = s32[3] scatter(operand, indices, updates),
+ to_apply=add_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=2
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({0, 1, 2});
+ std::unique_ptr<Literal> scatter_indices =
+ LiteralUtil::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR2<int32>({{10, 20}, {30, 40}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OutOfBoundsIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483647, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OutOfBoundsUnsignedIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = u32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<uint32>(
+ {{2, 7}, {2, 1}, {1, 1}, {5, 1}, {2147483648u, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, NegativeIndex) {
+ const string hlo_text = R"(
+HloModule BatchDynamicSlice
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3,3]{1,0} parameter(0)
+ indices = s32[6,2]{1,0} parameter(1)
+ updates = s32[6,1,1]{2,1,0} parameter(2)
+ ROOT scatter = s32[3,3]{1,0} scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0,1},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand =
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR2<int32>(
+ {{2, 7}, {2, 1}, {1, 1}, {-500, 1}, {-2147483648, 1}, {1, 2}});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR3<int32>(
+ {{{10}}, {{20}}, {{30}}, {{40}}, {{50}}, {{60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, OneScalarIndex) {
+ const char* hlo_text = R"(
+HloModule OneScalarIndex
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[2,3,2]{2,1,0} parameter(0)
+ index = s32[] parameter(1)
+ updates = s32[1,3,2]{2,1,0} parameter(2)
+ ROOT scatter = s32[2,3,2]{2,1,0} scatter(operand, index, updates),
+ to_apply=update_s32,
+ update_window_dims={0,1,2},
+ inserted_window_dims={},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR3<int32>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ std::unique_ptr<Literal> updates =
+ LiteralUtil::CreateR3<int32>({{{10, 20}, {30, 40}, {50, 60}}});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, ScalarUpdate) {
+ const char* hlo_text = R"(
+HloModule ScalarUpdate
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[4]{0} parameter(0)
+ index = s32[] parameter(1)
+ updates = s32[] parameter(2)
+ ROOT scatter = s32[4]{0} scatter(operand, index, updates),
+ to_apply=update_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=0
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3, 4});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR0<int32>(1);
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR0<int32>(25);
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+XLA_TEST_F(ScatterTest, EmptyIndices) {
+ const string hlo_text = R"(
+HloModule EmptyIndices
+
+update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+}
+
+ENTRY main {
+ operand = s32[3] parameter(0)
+ indices = s32[0] parameter(1)
+ updates = s32[0] parameter(2)
+ ROOT scatter = s32[3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+}
+)";
+ std::unique_ptr<Literal> operand = LiteralUtil::CreateR1<int32>({1, 2, 3});
+ std::unique_ptr<Literal> scatter_indices = LiteralUtil::CreateR1<int32>({});
+ std::unique_ptr<Literal> updates = LiteralUtil::CreateR1<int32>({});
+ RunTest(hlo_text, operand.get(), scatter_indices.get(), updates.get());
+}
+
+} // namespace
+} // namespace xla