aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc46
1 files changed, 44 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 37bc2d2c9d..4755c4a0cf 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2232,6 +2232,48 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
dataflow_analysis_->CanShareOperandBufferWithUser(starts, {}, dus, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto sort =
+ builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape values_shape = ShapeUtil::MakeShape(F32, {8});
+ auto keys = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, keys_shape, "keys"));
+ auto values = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values"));
+ auto sort = builder.AddInstruction(HloInstruction::CreateSort(
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ // The buffer for the keys can be shared with the first tuple entry.
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {0}));
+ // The buffer for the values can be shared with the second tuple entry.
+ EXPECT_TRUE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {1}));
+ // Verify that the buffers are not shared with the "wrong" tuple entry.
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {1}));
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(values, {}, sort, {0}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
auto builder = HloComputation::Builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
@@ -2323,7 +2365,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
- auto make_cond = [this, &data_shape]() {
+ auto make_cond = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Cond");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));
@@ -2332,7 +2374,7 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
return builder.Build();
};
- auto make_body = [this, &data_shape]() {
+ auto make_body = [&data_shape]() {
auto builder = HloComputation::Builder(TestName() + ".Body");
auto data = builder.AddInstruction(
HloInstruction::CreateParameter(0, data_shape, "data"));