path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
author    Yunxing Dai <yunxing@google.com>  2018-06-21 12:32:01 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-21 12:35:04 -0700
commit fc4484c359cab66bd5bfdfaab936b1a5128850be (patch)
tree   3037681e280ed6729175c2673e4096449d2027e6 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parent 1ee5e2ce389a8dbf11db25ff37347715e7dc7efc (diff)
Enable multi-output fusion operand buffer reuse.
- Enable multi-output fusion operand buffer reuse.
- Fix a bug in the heap simulator where a buffer could be reused twice.
- Add a unit test.
PiperOrigin-RevId: 201567720
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 80
1 file changed, 76 insertions(+), 4 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index d020005868..08a705b18d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -34,6 +34,49 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace xla {
+namespace {
+
+// We have this pattern in dynamic update slice fusion, which should be
+// supported:
+//
+// Parameters: p0, p1
+// Fusion
+// ds = DynamicSlice(p0, p1)
+// ROOT DynamicUpdateSlice(p0, ds, p1)
+//
+// In this case, the output can share p0's buffer, even though p0 has
+// multiple uses.
+bool MultiDynamicSliceUseShareSameIndices(
+ tensorflow::gtl::ArraySlice<HloUse> uses) {
+ if (uses.empty()) {
+ return false;
+ }
+ const HloInstruction* indices = nullptr;
+  for (const HloUse& use : uses) {
+ auto user = use.instruction;
+ if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
+ if (indices == nullptr) {
+ indices = user->operand(2);
+ } else if (indices != user->operand(2)) {
+ return false;
+ }
+ if (use.operand_number != 0) {
+ return false;
+ }
+ } else if (user->opcode() == HloOpcode::kDynamicSlice) {
+ if (indices == nullptr) {
+ indices = user->operand(1);
+ } else if (indices != user->operand(1)) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
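For intuition, here is a minimal self-contained sketch of the index-matching rule implemented by MultiDynamicSliceUseShareSameIndices above. Op, Instr, and Use below are simplified hypothetical stand-ins, not the real XLA types:

    // Minimal sketch, assuming simplified stand-ins for the XLA types.
    #include <cstdint>
    #include <vector>

    enum class Op { kDynamicSlice, kDynamicUpdateSlice, kOther };

    struct Instr {
      Op opcode;
      std::vector<const Instr*> operands;
    };

    struct Use {
      const Instr* instruction;  // the user
      int64_t operand_number;    // which operand slot consumes the value
    };

    // True iff every use is a DynamicSlice, or a DynamicUpdateSlice used at
    // operand 0, and all users address memory through one shared indices
    // operand (operand 1 of DS, operand 2 of DUS).
    bool SharesSameIndices(const std::vector<Use>& uses) {
      if (uses.empty()) return false;
      const Instr* indices = nullptr;
      for (const Use& use : uses) {
        const Instr* user = use.instruction;
        const Instr* idx = nullptr;
        if (user->opcode == Op::kDynamicUpdateSlice) {
          if (use.operand_number != 0) return false;  // must be the updated buffer
          idx = user->operands[2];
        } else if (user->opcode == Op::kDynamicSlice) {
          idx = user->operands[1];
        } else {
          return false;  // any other user defeats in-place reuse
        }
        if (indices == nullptr) {
          indices = idx;  // first indices operand seen
        } else if (indices != idx) {
          return false;   // users disagree on indices
        }
      }
      return true;
    }

The intuition, as the comment in the patch suggests, is that when every slicing user addresses the same window, the in-place update cannot clobber data outside the region the slices already read; with differing indices, the read and write regions could overlap unpredictably.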
@@ -45,6 +88,31 @@ HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bitcast_defines_value_(bitcast_defines_value),
call_graph_(CallGraph::Build(&module)) {}
+bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
+ const HloInstruction* inst) {
+ tensorflow::gtl::FlatSet<const HloInstruction*> visited;
+ tensorflow::gtl::InlinedVector<const HloInstruction*, 4> stack;
+ stack.push_back(inst);
+ while (!stack.empty()) {
+ const HloInstruction* current = stack.back();
+ stack.pop_back();
+ visited.insert(current);
+ for (const HloInstruction* user : current->users()) {
+      // If the user is non-elementwise on any operand slot that consumes
+      // `current` (and the user is not a tuple), sharing is not safe.
+ for (const int64 use_index : user->OperandIndices(current)) {
+ if (!user->IsElementwiseOnOperand(use_index) &&
+ user->opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+ }
+ if (!visited.count(user)) {
+ stack.push_back(user);
+ }
+ }
+ }
+ return true;
+}
+
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
const HloValueSet& value_set = GetValueSet(instruction, index);
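The new AreTransitiveUsesElementwiseOrTuple helper above is an iterative depth-first walk over the user graph. A minimal sketch of the same traversal, assuming a hypothetical Node type that folds the per-operand IsElementwiseOnOperand query down to a single flag:

    #include <unordered_set>
    #include <vector>

    struct Node {
      bool elementwise;               // stand-in for IsElementwiseOnOperand(i)
      bool is_tuple;                  // stand-in for opcode() == kTuple
      std::vector<const Node*> users;
    };

    // True iff every transitive user of `root` is elementwise or a tuple.
    bool AllTransitiveUsesElementwiseOrTuple(const Node* root) {
      std::unordered_set<const Node*> visited;
      std::vector<const Node*> stack = {root};
      while (!stack.empty()) {
        const Node* current = stack.back();
        stack.pop_back();
        visited.insert(current);
        for (const Node* user : current->users) {
          // A single non-elementwise, non-tuple user makes reuse unsafe.
          if (!user->elementwise && !user->is_tuple) return false;
          if (!visited.count(user)) stack.push_back(user);
        }
      }
      return true;
    }

The real helper checks elementwise-ness per operand index, since an instruction can be elementwise on some of its operands and not others.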
@@ -915,6 +983,7 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
+
// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
@@ -927,11 +996,15 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
const HloValue& value = GetValueDefinedAt(fusion_param, operand_index);
if (value.uses().size() != 1) {
+ if (MultiDynamicSliceUseShareSameIndices(value.uses())) {
+ return true;
+ }
return false;
}
const HloUse& use = value.uses()[0];
- if (user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
+ user->fusion_kind() == HloInstruction::FusionKind::kInput) {
if (user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
// Loop fusion with kDynamicUpdateSlice fused root.
@@ -941,6 +1014,8 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// index 0.
return use.instruction == user->fused_expression_root() &&
use.operand_number == 0;
+ } else {
+ return AreTransitiveUsesElementwiseOrTuple(fusion_param);
}
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
@@ -1003,9 +1078,6 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
// Loop fusions that contain transposing copies won't reach here as they have
// different layouts, which fails the check in the beginning of this function.
- //
- // Multi-output fusion will fail the check here as tuples are not considered
- // an elementwise operation.
return user->IsElementwiseOnOperand(user->operand_index(operand));
}
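Putting the pieces together, the fusion-parameter branch of CanShareOperandBufferWithUser now decides roughly as follows. This is a condensed, hypothetical summary; the booleans stand in for the real queries and this helper is not part of XLA:

    // Condensed decision flow for a fusion parameter, after this change.
    bool CanShareFusionParamBuffer(bool multi_use,
                                   bool multi_dus_share_same_indices,
                                   bool loop_or_input_fusion,
                                   bool root_is_dynamic_update_slice,
                                   bool single_use_is_root_operand0,
                                   bool transitive_uses_elementwise_or_tuple) {
      if (multi_use) {
        // New: multiple uses are allowed when they are DynamicSlice /
        // DynamicUpdateSlice users sharing the same indices.
        return multi_dus_share_same_indices;
      }
      if (loop_or_input_fusion) {  // kInput fusion is newly accepted here.
        if (root_is_dynamic_update_slice) {
          return single_use_is_root_operand0;
        }
        // New: otherwise fall back to the transitive elementwise/tuple walk,
        // which admits multi-output (tuple-rooted) fusions.
        return transitive_uses_elementwise_or_tuple;
      }
      return false;  // other fusion kinds are handled by separate branches
    }

This also explains the comment deleted in the last hunk: tuples previously failed the plain elementwise check, so multi-output fusions could never reuse an operand buffer; the transitive walk now treats kTuple users as acceptable.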