diff options
author | Yunxing Dai <yunxing@google.com> | 2018-06-21 12:32:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-21 12:35:04 -0700 |
commit | fc4484c359cab66bd5bfdfaab936b1a5128850be (patch) | |
tree | 3037681e280ed6729175c2673e4096449d2027e6 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | |
parent | 1ee5e2ce389a8dbf11db25ff37347715e7dc7efc (diff) |
Enable multioutput fusion operand buffer reuse.
- Enable multioutput fusion operand buffer reuse.
- Fix a bug in heap simulator where a buffer can be reused twice.
- Add unittest.
PiperOrigin-RevId: 201567720
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 80 |
1 files changed, 76 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d020005868..08a705b18d 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -34,6 +34,49 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" namespace xla { +namespace { + +// We have this pattern in dynamic update slice fusion, which should be +// supported: +// +// Parameters: p0, p1 +// Fusion +// ds = DynamicSlice(p0, p1) +// ROOT DynamicUpdateSlice(p0, ds, p1) +// +// In this case, we should be able to reuse p0 and output, although p0 has +// multiple uses. +bool MultiDynamicSliceUseShareSameIndices( + tensorflow::gtl::ArraySlice<HloUse> uses) { + if (uses.empty()) { + return false; + } + const HloInstruction* indices = nullptr; + for (HloUse use : uses) { + auto user = use.instruction; + if (user->opcode() == HloOpcode::kDynamicUpdateSlice) { + if (indices == nullptr) { + indices = user->operand(2); + } else if (indices != user->operand(2)) { + return false; + } + if (use.operand_number != 0) { + return false; + } + } else if (user->opcode() == HloOpcode::kDynamicSlice) { + if (indices == nullptr) { + indices = user->operand(1); + } else if (indices != user->operand(1)) { + return false; + } + } else { + return false; + } + } + return true; +} + +} // namespace using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -45,6 +88,31 @@ HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, bitcast_defines_value_(bitcast_defines_value), call_graph_(CallGraph::Build(&module)) {} +bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple( + const HloInstruction* inst) { + tensorflow::gtl::FlatSet<const HloInstruction*> visited; + tensorflow::gtl::InlinedVector<const HloInstruction*, 4> stack; + stack.push_back(inst); + while (!stack.empty()) { + const HloInstruction* 
current = stack.back(); + stack.pop_back(); + visited.insert(current); + for (const HloInstruction* user : current->users()) { + // Found a user that is non-elementwise on current instruction. + for (const int64 use_index : user->OperandIndices(current)) { + if (!user->IsElementwiseOnOperand(use_index) && + user->opcode() != HloOpcode::kTuple) { + return false; + } + } + if (!visited.count(user)) { + stack.push_back(user); + } + } + } + return true; +} + bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { const HloValueSet& value_set = GetValueSet(instruction, index); @@ -915,6 +983,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( ShapeUtil::GetSubshape(operand->shape(), operand_index); const Shape& user_subshape = ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; @@ -927,11 +996,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( const HloValue& value = GetValueDefinedAt(fusion_param, operand_index); if (value.uses().size() != 1) { + if (MultiDynamicSliceUseShareSameIndices(value.uses())) { + return true; + } return false; } const HloUse& use = value.uses()[0]; - if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop || + user->fusion_kind() == HloInstruction::FusionKind::kInput) { if (user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { // Loop fusion with kDynamicUpdateSlice fused root. @@ -941,6 +1014,8 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // index 0. 
return use.instruction == user->fused_expression_root() && use.operand_number == 0; + } else { + return AreTransitiveUsesElementwiseOrTuple(fusion_param); } } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { @@ -1003,9 +1078,6 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( // Loop fusions that contain transposing copies won't reach here as they have // different layouts, which fails the check in the beginning of this function. - // - // Multi-output fusion will fail the check here as tuples are not considered - // an elementwise operation. return user->IsElementwiseOnOperand(user->operand_index(operand)); } |