author    Yunxing Dai <yunxing@google.com>  2018-06-07 18:42:30 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-07 18:45:05 -0700
commit    a9ddfe50eee83b2f18293241ab96f0a1e2b4b05b (patch)
tree      849afa106cb882d918cfa0a633134bb89f7f014f /tensorflow/compiler/xla
parent    2f41346cbc0c8ecb915983a1f8711fd0d0ccc50e (diff)
[DataFlowAnalysis] Be less conservative on loop fusion nodes when reusing buffer.
- Previously, we said we could not reuse an operand's buffer for a loop fusion node if any of the fusion's inputs was a broadcast or reshape. That is too conservative: in theory we can still reuse the operand's buffer as long as all the users of that particular operand are elementwise. This CL implements that.
- Also fixed a bug in the previous code where a dynamic-update-slice fusion node that ends with a convert (added for bf16) was not caught by the if condition correctly.

PiperOrigin-RevId: 199731488
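For illustration, here is a minimal standalone sketch of the relaxed rule (hypothetical types, not the XLA API): a loop fusion may reuse an operand's buffer exactly when every use of that operand inside the fused computation is elementwise, regardless of what the other fusion operands are.

#include <vector>

// Hypothetical model of one use of a fusion operand; not an XLA type.
struct Use {
  bool elementwise;  // True if the consuming fused instruction is elementwise.
};

// Reuse is allowed iff no use of this particular operand is non-elementwise;
// a broadcast or reshape feeding a *different* operand no longer blocks reuse.
bool CanReuseOperandBuffer(const std::vector<Use>& uses_of_operand) {
  for (const Use& use : uses_of_operand) {
    if (!use.elementwise) return false;
  }
  return true;
}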
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc       |  31
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc  | 123
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc             |  19
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc        |  17
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc                  |   3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc             |   2
6 files changed, 181 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index cc130a4900..d020005868 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -931,16 +931,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}
const HloUse& use = value.uses()[0];
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
- user->fused_expression_root()->opcode() ==
- HloOpcode::kDynamicUpdateSlice) {
- // Loop fusion with kDynamicUpdateSlice fused root.
- //
- // Returns true iff there is exactly one use of 'operand' at shape index
- // 'operand_index', and this singleton use is the fused root at operand
- // index 0.
- return use.instruction == user->fused_expression_root() &&
- use.operand_number == 0;
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ if (user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == 0;
+ }
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -967,6 +968,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
use.operand_number == other_add_operand_index;
}
}
+
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
user->opcode() == HloOpcode::kWhile) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
@@ -998,8 +1000,13 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
}) != uses.end();
return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
}
- // Check if 'user' is element-wise.
- return user->IsElementwise();
+
+ // Loop fusions that contain transposing copies won't reach here as they have
+ // different layouts, which fails the check in the beginning of this function.
+ //
+ // Multi-output fusion will fail the check here as tuples are not considered
+ // an elementwise operation.
+ return user->IsElementwiseOnOperand(user->operand_index(operand));
}
} // namespace xla
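The new final check delegates to IsElementwiseOnOperand, which inspects the fused computation. A compact model of that walk, assuming a toy acyclic Node graph rather than real HLO instructions:

#include <string>
#include <vector>

// Toy stand-in for a fused instruction; not an XLA type.
struct Node {
  std::string opcode;        // e.g. "negate", "dynamic-update-slice", "tuple"
  bool elementwise;          // Whether this op is elementwise on its inputs.
  std::vector<Node*> users;  // Consumers inside the fused computation.
};

// Reuse of the parameter's buffer is safe only if every transitive user of
// the fusion parameter is elementwise; a reverse, dynamic-update-slice, or
// tuple anywhere on the path rejects the reuse (cf. the tests below).
bool ElementwiseOnParam(const Node& param) {
  for (const Node* user : param.users) {
    if (!user->elementwise || !ElementwiseOnParam(*user)) return false;
  }
  return true;
}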
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 5798326dcb..db1822ec47 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -1974,6 +1974,89 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
dataflow_analysis_->CanShareOperandBufferWithUser(exp, {}, log, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ NonElementwiseLoopFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape, "param0"));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, param0));
+
+ auto reverse = builder.AddInstruction(
+ HloInstruction::CreateReverse(data_shape, neg, {0, 1}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {reverse, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+ MultiOutputFusionCantAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ Shape in_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, in_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, in_shape, "param1"));
+
+ auto copy0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
+ auto copy1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {0}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
+ fusion, {1}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {0}));
+ EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
+ fusion, {1}));
+}
+
+TEST_F(CanShareOperandBufferWithUserTest,
+       ElementwiseLoopFusionCanAliasOperandBuffer) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto operand = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape, one, {1}));
+
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, operand));
+
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, neg));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {exp, neg}, HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(operand, {},
+ fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
auto builder = HloComputation::Builder(TestName());
@@ -2048,6 +2131,46 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
fusion, {}));
}
+TEST_F(CanShareOperandBufferWithUserTest,
+ FusedDynamicUpdateSliceWithConvertCantShare) {
+ auto builder = HloComputation::Builder(TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
+ auto tuple = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
+
+ auto convert1 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape_bf16, gte1));
+
+ // Create a DynamicUpdateSlice instruction of tuple element 1.
+ auto starts = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR1<int32>({2})));
+ auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR1<float>({2.f, 2.f, 2.f})));
+ auto dynamic_update_slice =
+ builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ data_shape_bf16, convert1, update, starts));
+
+ auto convert2 = builder.AddInstruction(
+ HloInstruction::CreateConvert(data_shape, dynamic_update_slice));
+ builder.AddInstruction(HloInstruction::CreateTuple({gte0, convert2}));
+
+ BuildModule(builder.Build());
+ auto fusion = computation_->CreateFusionInstruction(
+ {convert2, dynamic_update_slice, starts, update, convert1},
+ HloInstruction::FusionKind::kLoop);
+ RunAnalysis();
+
+ // The fusion instruction can't share with tuple element 1.
+ EXPECT_FALSE(
+ dataflow_analysis_->CanShareOperandBufferWithUser(gte1, {}, fusion, {}));
+}
+
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index cf1530abe1..570ad5459a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -398,6 +398,11 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
instruction->AppendOperand(operand);
}
instruction->called_computations_.push_back(map_computation);
+ // TODO(b/65689298) Remove code below once Map is generalized to accept
+ // arbitrary map dimensions.
+ instruction->dimensions_.resize(ShapeUtil::Rank(shape));
+ std::iota(instruction->dimensions_.begin(), instruction->dimensions_.end(),
+ 0);
return instruction;
}
@@ -1603,7 +1608,7 @@ bool HloInstruction::HasLiteral() const { return literal_ != nullptr; }
bool HloInstruction::CanHaveDimensionsField() const {
return (opcode() == HloOpcode::kReverse ||
- opcode() == HloOpcode::kConcatenate ||
+ opcode() == HloOpcode::kConcatenate || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
opcode() == HloOpcode::kTranspose);
}
@@ -3151,7 +3156,19 @@ bool HloInstruction::IsElementwise() const {
// Other operations.
case HloOpcode::kRng:
+ return true;
case HloOpcode::kMap:
+ if (!dimensions().empty()) {
+ // Check that the map is executed in elementwise compatible dimensions.
+ if (dimensions().size() != operand(0)->shape().dimensions_size()) {
+ return false;
+ }
+ for (int i = 0; i < dimensions().size(); ++i) {
+ if (dimensions()[i] != i) {
+ return false;
+ }
+ }
+ }
return true;
case HloOpcode::kFusion:
if (fusion_kind() != FusionKind::kLoop) {
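The kMap case above treats a map as elementwise only when its dimensions attribute, if present, is the identity over the operand's rank. A standalone sketch of that predicate, using plain ints in place of the instruction's dimensions field:

#include <vector>

// Empty dimensions are treated as elementwise; otherwise the list must be
// exactly {0, 1, ..., rank-1}, i.e. the map runs over all dimensions in order.
bool MapDimensionsAreElementwise(const std::vector<int>& dimensions,
                                 int operand_rank) {
  if (dimensions.empty()) return true;
  if (static_cast<int>(dimensions.size()) != operand_rank) return false;
  for (int i = 0; i < static_cast<int>(dimensions.size()); ++i) {
    if (dimensions[i] != i) return false;  // Not the identity permutation.
  }
  return true;
}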
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 313033ddad..76349c4099 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -980,6 +980,23 @@ TEST_F(HloInstructionTest, FullyElementwise) {
}
}
+TEST_F(HloInstructionTest, MapIsElementwise) {
+ auto module = CreateNewModule();
+ const Shape r2f32 = ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});
+ HloComputation::Builder builder(TestName());
+ HloComputation::Builder map_builder("id");
+ map_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "x"));
+ auto map = builder.AddInstruction(
+ HloInstruction::CreateMap(r2f32, {x}, map_computation));
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(map->IsElementwise());
+}
+
TEST_F(HloInstructionTest, PartiallyElementwise) {
const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 3eadedfe1f..a1bc269400 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -777,6 +777,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
optional<HloComputation*> to_apply;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
+ optional<std::vector<tensorflow::int64>> dimensions;
+ attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List,
+ &dimensions};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 08068dc504..1c5a47c875 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -765,7 +765,7 @@ add_F32.v3 {
ENTRY MapBinaryAdder.v3 {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
- ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3
+ ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
}
)"