aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Benjamin Kramer <kramerb@google.com> 2018-10-09 15:47:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-09 15:51:30 -0700
commit 69c4a426fc4a3afd83c8190467b07c17b8b2ed60 (patch)
tree 1c81b33b71efc63bad8519a77026ac96b805be9e
parent 771955e2b8be98a0b38fada41bd67f663397c87d (diff)
[XLA] Allow scatter to share the operand buffer with the output
This avoids a copy.

PiperOrigin-RevId: 216437329
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc | 38
-rw-r--r-- tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | 1
-rw-r--r-- tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc | 38
4 files changed, 78 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index c22adcdd8d..71122e73b1 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -1048,6 +1048,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
+ user->opcode() == HloOpcode::kScatter ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 510d6360a1..d27786d160 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2283,6 +2283,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {  // Checks CanShareOperandBufferWithUser for each of the three inputs of a scatter.
+ const char* hlo_text = R"(
+ HloModule TensorFlowScatterV1
+
+ update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+ }
+
+ ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));  // Parse the HLO text above into module_.
+ computation_ = module_->entry_computation();  // The ENTRY computation ("main" in the HLO text).
+ RunAnalysis();  // Fixture helper defined outside this hunk; presumably builds dataflow_analysis_ -- confirm.
+
+ HloInstruction* operand_param = computation_->parameter_instruction(0);  // `operand`, parameter(0) in the HLO text.
+ HloInstruction* indices_param = computation_->parameter_instruction(1);  // `indices`, parameter(1) in the HLO text.
+ HloInstruction* updates_param = computation_->parameter_instruction(2);  // `updates`, parameter(2) in the HLO text.
+ HloInstruction* scatter = computation_->root_instruction();  // The ROOT `scatter` instruction.
+
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(  // Scatter input 0 (operand) is expected to be shareable.
+ operand_param, {}, scatter, {}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(  // Scatter input 1 (indices) is expected not to be shareable.
+ indices_param, {}, scatter, {}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(  // Scatter input 2 (updates) is expected not to be shareable.
+ updates_param, {}, scatter, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index 811ac55e2d..ef4e69180d 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -756,6 +756,7 @@ bool TuplePointsToAnalysis::CanShareOperandBufferWithUser(
}
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
+ user->opcode() == HloOpcode::kScatter ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index e9a07b14ed..a571bd571b 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1010,6 +1010,44 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
points_to_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {  // Checks CanShareOperandBufferWithUser for each of the three inputs of a scatter.
+ const char* hlo_text = R"(
+ HloModule TensorFlowScatterV1
+
+ update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
+ lhs = s32[] parameter(0)
+ ROOT rhs = s32[] parameter(1)
+ }
+
+ ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[2] parameter(1)
+ updates = s32[2,3] parameter(2)
+ ROOT scatter = s32[3,3] scatter(operand, indices, updates),
+ to_apply=update_s32,
+ update_window_dims={1},
+ inserted_window_dims={0},
+ scatter_dims_to_operand_dims={0},
+ index_vector_dim=1
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));  // Parse the HLO text above into module_.
+ computation_ = module_->entry_computation();  // The ENTRY computation ("main" in the HLO text).
+ RunAnalysis();  // Fixture helper defined outside this hunk; presumably builds points_to_analysis_ -- confirm.
+
+ HloInstruction* operand_param = computation_->parameter_instruction(0);  // `operand`, parameter(0) in the HLO text.
+ HloInstruction* indices_param = computation_->parameter_instruction(1);  // `indices`, parameter(1) in the HLO text.
+ HloInstruction* updates_param = computation_->parameter_instruction(2);  // `updates`, parameter(2) in the HLO text.
+ HloInstruction* scatter = computation_->root_instruction();  // The ROOT `scatter` instruction.
+
+ EXPECT_TRUE(points_to_analysis_->CanShareOperandBufferWithUser(  // Scatter input 0 (operand) is expected to be shareable.
+ operand_param, {}, scatter, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(  // Scatter input 1 (indices) is expected not to be shareable.
+ indices_param, {}, scatter, {}));
+ EXPECT_FALSE(points_to_analysis_->CanShareOperandBufferWithUser(  // Scatter input 2 (updates) is expected not to be shareable.
+ updates_param, {}, scatter, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());