aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index e9a07b14ed..a571bd571b 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1010,6 +1010,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
+ const char* hlo_text = R"(
+ HloModule TensorFlowScatterV1
+
+ update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+ }
+
+ ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
+ computation_ = module_->entry_computation();
+ RunAnalysis();
+
+ HloInstruction* operand_param = computation_->parameter_instruction(0);
+ HloInstruction* indices_param = computation_->parameter_instruction(1);
+ HloInstruction* updates_param = computation_->parameter_instruction(2);
+ HloInstruction* scatter = computation_->root_instruction();
+
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(
+ operand_param, {}, scatter, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(
+ indices_param, {}, scatter, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(
+ updates_param, {}, scatter, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());