aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-09-18 23:34:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-18 23:37:38 -0700
commit23da21150d988f7cf5780488f24adbb116675586 (patch)
tree916db234da382bf5cf2815f913ee5c358c720be5 /tensorflow/compiler/xla
parentf08ec5722b68da94cf7bea1186337f82620fe60e (diff)
Add liveness_util functions which use dataflow analysis. Also make the analysis argument (TuplePointsToAnalysis or HloDataflowAnalysis) non-optional as all callers were passing in the analysis.
PiperOrigin-RevId: 169200824
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc2
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc22
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc35
-rw-r--r--tensorflow/compiler/xla/service/liveness_util.cc116
-rw-r--r--tensorflow/compiler/xla/service/liveness_util.h22
-rw-r--r--tensorflow/compiler/xla/service/liveness_util_test.cc97
13 files changed, 254 insertions, 68 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 8361212337..f23fa22107 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -720,6 +720,7 @@ cc_library(
hdrs = ["liveness_util.h"],
deps = [
":hlo",
+ ":hlo_dataflow_analysis",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -838,6 +839,7 @@ cc_library(
deps = [
":call_graph",
":hlo",
+ ":hlo_dataflow_analysis",
":hlo_proto",
":hlo_value",
":liveness_util",
@@ -1391,9 +1393,7 @@ cc_library(
deps = [
":call_graph",
":hlo",
- ":hlo_ordering",
":hlo_value",
- ":liveness_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",
@@ -1412,6 +1412,7 @@ cc_test(
":hlo_dataflow_analysis",
":hlo_graph_dumper",
":hlo_matchers",
+ ":hlo_ordering",
":instruction_fusion",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@@ -1470,6 +1471,7 @@ cc_test(
":hlo_alias_analysis",
":hlo_graph_dumper",
":hlo_matchers",
+ ":hlo_ordering",
":instruction_fusion",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index f085ffa6bc..8610080203 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -123,7 +123,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
b.instruction(), b.index(),
- &points_to_analysis())) {
+ points_to_analysis())) {
return false;
}
}
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index c85e97b691..34e2f7ee20 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -204,7 +204,7 @@ Status HeapSimulator::RunComputation(
buffer->instruction()->opcode() != HloOpcode::kCopy &&
CanShareOperandBufferWithUser(
operand_buffer->instruction(), operand_buffer->index(),
- buffer->instruction(), buffer->index(), &points_to_analysis)) {
+ buffer->instruction(), buffer->index(), points_to_analysis)) {
ShareBuffer(buffer, operand_buffer, instruction);
shared = true;
break;
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index e7ff9e7cf3..a275628779 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -93,7 +94,8 @@ class HloAliasAnalysisTest : public HloTestBase {
for (const HloValue* value_a : buffer.values()) {
for (const HloValue* value_b : buffer.values()) {
if (*value_a != *value_b &&
- ordering.MayInterfere(*value_a, *value_b)) {
+ ordering.MayInterfere(*value_a, *value_b,
+ analysis_->dataflow_analysis())) {
VLOG(1) << *value_a << " interferes with " << *value_b
<< " in buffer: " << buffer;
return true;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 2be1645f1b..213ff07b07 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index aae257dd09..207e553bf7 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -28,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status.h"
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4939335e2f..4b8eb237a6 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -73,7 +74,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
- analysis_->GetValueDefinedAt(b));
+ analysis_->GetValueDefinedAt(b), *analysis_);
}
std::unique_ptr<HloModule> module_;
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 08f572bb2a..3612c51ee8 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -123,8 +123,9 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
}
/* static */
-bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
- const HloValue& value) const {
+bool HloOrdering::UseIsBeforeValueDefinition(
+ const HloUse& use, const HloValue& value,
+ const HloDataflowAnalysis& dataflow) const {
VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
<< ", value=" << value.ToShortString() << ")";
if (ExecutesBefore(use.instruction, value.defining_instruction())) {
@@ -139,7 +140,7 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
CanShareOperandBufferWithUser(
use.instruction->mutable_operand(use.operand_number),
use.operand_index, value.defining_instruction(),
- value.defining_index())) {
+ value.defining_index(), dataflow)) {
VLOG(4) << " use is value def, and instruction can share use buffer";
return true;
}
@@ -172,12 +173,13 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use,
return true;
}
}
- VLOG(4) << " use is not before while";
+ VLOG(4) << " use is not before value";
return false;
}
-bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
- const HloValue& b) const {
+bool HloOrdering::LiveRangeStrictlyBefore(
+ const HloValue& a, const HloValue& b,
+ const HloDataflowAnalysis& dataflow) const {
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
<< ", b = " << b.ToShortString() << ")";
if (!IsDefinedBefore(a, b)) {
@@ -204,7 +206,7 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
- if (!UseIsBeforeValueDefinition(use, b)) {
+ if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
VLOG(4) << "use of a (" << use << ") not before b is defined";
return false;
}
@@ -213,9 +215,11 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a,
return true;
}
-bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b) const {
+bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b,
+ const HloDataflowAnalysis& dataflow) const {
// Buffers without disjoint liveness may interfere.
- return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a);
+ return !LiveRangeStrictlyBefore(a, b, dataflow) &&
+ !LiveRangeStrictlyBefore(b, a, dataflow);
}
HloOrderingProto HloOrdering::ToProto() const {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index e0c23a3a08..ee526d8dd7 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
@@ -48,15 +49,17 @@ class HloOrdering {
// Returns whether the given use is before the given value definition under
// the given ordering.
- bool UseIsBeforeValueDefinition(const HloUse& use,
- const HloValue& value) const;
+ bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
+ const HloDataflowAnalysis& dataflow) const;
// Returns whether the given values interfere. Two values interfere if they
// may both be simultaneously live.
- bool MayInterfere(const HloValue& a, const HloValue& b) const;
+ bool MayInterfere(const HloValue& a, const HloValue& b,
+ const HloDataflowAnalysis& dataflow) const;
// Returns true if the live range of the given value 'a' is strictly before
// the live range of value 'b' using the given HLO ordering.
- bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b) const;
+ bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
+ const HloDataflowAnalysis& dataflow) const;
// Returns the sequential instruction order for the given computation, or
// nullptr if the computation does not have a sequential ordering.
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index dbd63eceed..33bafd05c1 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -269,29 +269,32 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// while because of the use of the init value in the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
dataflow->GetValueDefinedAt(xla_while)));
- EXPECT_FALSE(
- ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
- dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while), *dataflow));
EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
- dataflow->GetValueDefinedAt(xla_while)));
+ dataflow->GetValueDefinedAt(xla_while),
+ *dataflow));
// Any value defined in the body or condition is defined before the while, and
// has a live range strictly before the while.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
dataflow->GetValueDefinedAt(xla_while)));
- EXPECT_TRUE(
- ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
- dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while), *dataflow));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
- dataflow->GetValueDefinedAt(xla_while)));
+ dataflow->GetValueDefinedAt(xla_while),
+ *dataflow));
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
dataflow->GetValueDefinedAt(xla_while)));
- EXPECT_TRUE(
- ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
- dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while), *dataflow));
EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
- dataflow->GetValueDefinedAt(xla_while)));
+ dataflow->GetValueDefinedAt(xla_while),
+ *dataflow));
// The live range of the while should be before the add.
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
@@ -301,10 +304,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
EXPECT_EQ(while_use.instruction, add);
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
- while_use, dataflow->GetValueDefinedAt(add)));
- EXPECT_TRUE(
- ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
- dataflow->GetValueDefinedAt(add)));
+ while_use, dataflow->GetValueDefinedAt(add), *dataflow));
+ EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
+ dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add),
+ *dataflow));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
index 317271dfdd..c27a8956a7 100644
--- a/tensorflow/compiler/xla/service/liveness_util.cc
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -69,6 +69,36 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand,
return false;
}
+bool DoesNotUseOperandBuffer(const HloInstruction* operand,
+ const ShapeIndex& index,
+ const HloInstruction* user,
+ const HloDataflowAnalysis& dataflow) {
+ CHECK(user->IsUserOf(operand))
+ << "user: " << user->ToString() << " operand: " << operand->ToString();
+ if (user->opcode() == HloOpcode::kFusion &&
+ user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ // Find the fused parameter that corresponds to 'operand' in the loop fusion.
+ HloInstruction* fusion_param =
+ user->fused_parameter(user->operand_index(operand));
+ // The operand buffer is unused by the fusion iff the value defined at the
+ // fusion parameter (at 'index') has no uses: true if empty, false otherwise.
+ const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index);
+ return value.uses().empty();
+ } else {
+ // Return false if any value at 'operand' and 'index' is used at 'user'.
+ for (const HloValue* value :
+ dataflow.GetValueSet(operand, index).values()) {
+ for (const HloUse& use : value->uses()) {
+ if (use.instruction == user) {
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+}
+
namespace {
// Returns all uses of all aliases of 'instruction' at 'index' in 'uses'.
@@ -153,7 +183,7 @@ bool HasUniqueFusedUseOfOperandAt(
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
- const TuplePointsToAnalysis* points_to_analysis) {
+ const TuplePointsToAnalysis& points_to_analysis) {
CHECK(user->IsUserOf(operand))
<< "user: " << user->ToString() << " operand: " << operand->ToString();
const Shape& operand_subshape =
@@ -164,7 +194,7 @@ bool CanShareOperandBufferWithUser(
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
return false;
}
- if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) {
+ if (user->opcode() == HloOpcode::kFusion) {
if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
@@ -174,7 +204,7 @@ bool CanShareOperandBufferWithUser(
// 'operand_index', and this singleton use is the fused root at operand
// index 0.
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0,
- *points_to_analysis);
+ points_to_analysis);
} else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
// Output fusion with kAdd fused root.
@@ -202,7 +232,85 @@ bool CanShareOperandBufferWithUser(
// index 'other_add_operand_index').
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user,
other_add_operand_index,
- *points_to_analysis);
+ points_to_analysis);
+ }
+ }
+ if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
+ user->opcode() == HloOpcode::kWhile) {
+ // We eliminated other users in BufferLiveness::live_range_strictly_before,
+ // so here we just need to check that the use is at operand index 0.
+ std::vector<int64> operand_indices = user->OperandIndices(operand);
+ return operand_indices.size() == 1 && operand_indices[0] == 0;
+ }
+ // Check if 'user' is element-wise.
+ return user->IsElementwise();
+}
+
+bool CanShareOperandBufferWithUser(HloInstruction* operand,
+ const ShapeIndex& operand_index,
+ HloInstruction* user,
+ const ShapeIndex& user_index,
+ const HloDataflowAnalysis& dataflow) {
+ CHECK(user->IsUserOf(operand))
+ << "user: " << user->ToString() << " operand: " << operand->ToString();
+ const Shape& operand_subshape =
+ ShapeUtil::GetSubshape(operand->shape(), operand_index);
+ const Shape& user_subshape =
+ ShapeUtil::GetSubshape(user->shape(), user_index);
+ // Check that operand and user emit the same shape and layout.
+ if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
+ return false;
+ }
+
+ if (user->opcode() == HloOpcode::kFusion) {
+ // Get the parameter associated with 'operand'.
+ HloInstruction* fusion_param =
+ user->fused_parameter(user->operand_index(operand));
+
+ const HloValue& value =
+ dataflow.GetValueDefinedAt(fusion_param, operand_index);
+ if (value.uses().size() != 1) {
+ return false;
+ }
+ const HloUse& use = value.uses()[0];
+
+ if (user->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ user->fused_expression_root()->opcode() ==
+ HloOpcode::kDynamicUpdateSlice) {
+ // Loop fusion with kDynamicUpdateSlice fused root.
+ //
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root at operand
+ // index 0.
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == 0;
+ } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
+ // Output fusion with kAdd fused root.
+
+ // Check if one operand of kAdd fused root is either kDot, or nested
+ // kFusion of kind kTransposeDot.
+ auto* add = user->fused_expression_root();
+ auto add_operand_it =
+ std::find_if(add->operands().begin(), add->operands().end(),
+ [&](HloInstruction* operand) {
+ return operand->opcode() == HloOpcode::kDot ||
+ (operand->opcode() == HloOpcode::kFusion &&
+ operand->fusion_kind() ==
+ HloInstruction::FusionKind::kTransposeDot);
+ });
+ if (add_operand_it == add->operands().end()) {
+ return false;
+ }
+ auto* matched_add_operand = *add_operand_it;
+ // Calculate operand index of 'add' operand which was not matched above.
+ const int64 other_add_operand_index =
+ matched_add_operand == add->operand(0) ? 1 : 0;
+ // Returns true iff there is exactly one use of 'operand' at shape index
+ // 'operand_index', and this singleton use is the fused root (at operand
+ // index 'other_add_operand_index').
+ return use.instruction == user->fused_expression_root() &&
+ use.operand_number == other_add_operand_index;
}
}
if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h
index c7799e5ab5..28ef991880 100644
--- a/tensorflow/compiler/xla/service/liveness_util.h
+++ b/tensorflow/compiler/xla/service/liveness_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -29,21 +30,34 @@ namespace xla {
// 'operand'. Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
+//
+// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have
+// moved over to the dataflow overload.
bool DoesNotUseOperandBuffer(const HloInstruction* operand,
const ShapeIndex& index,
const HloInstruction* user,
const TuplePointsToAnalysis& points_to_analysis);
+bool DoesNotUseOperandBuffer(const HloInstruction* operand,
+ const ShapeIndex& index,
+ const HloInstruction* user,
+ const HloDataflowAnalysis& dataflow);
// Returns true if 'user' (at 'user_index') can share a buffer with its operand
-// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a
-// points-to analysis argument. Without the analysis, the result is more
-// conservative (returns false more often).
+// 'operand' (at 'operand_index'). Returns false otherwise.
//
// REQUIRES: 'operand' is an operand of 'user'.
+//
+// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have
+// moved over to the dataflow overload.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
- const TuplePointsToAnalysis* points_to_analysis = nullptr);
+ const TuplePointsToAnalysis& points_to_analysis);
+bool CanShareOperandBufferWithUser(HloInstruction* operand,
+ const ShapeIndex& operand_index,
+ HloInstruction* user,
+ const ShapeIndex& user_index,
+ const HloDataflowAnalysis& dataflow);
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
index d89dab4a82..b5e15906d3 100644
--- a/tensorflow/compiler/xla/service/liveness_util_test.cc
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -35,6 +35,8 @@ class PointsToAnalysisTestBase : public HloTestBase {
CHECK_NOTNULL(module_.get());
points_to_analysis_ =
TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+ dataflow_analysis_ =
+ HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
@@ -45,6 +47,7 @@ class PointsToAnalysisTestBase : public HloTestBase {
std::unique_ptr<HloModule> module_;
HloComputation* computation_ = nullptr;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+ std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
};
class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {};
@@ -70,6 +73,11 @@ TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_));
EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_));
+
+ EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_));
+ EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_));
+ EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_));
+ EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *dataflow_analysis_));
}
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
@@ -105,6 +113,10 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_));
EXPECT_FALSE(
DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_));
+
+ EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_));
+ EXPECT_FALSE(
+ DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_));
}
class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {};
@@ -122,10 +134,15 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
BuildModuleAndRunAnalysis(builder.Build());
- EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {},
- points_to_analysis_.get()));
- EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {},
- points_to_analysis_.get()));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_));
+
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
@@ -143,9 +160,14 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
BuildModuleAndRunAnalysis(builder.Build());
EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
- points_to_analysis_.get()));
+ *points_to_analysis_));
+ EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
+ *points_to_analysis_));
+
+ EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {},
+ *dataflow_analysis_));
EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {},
- points_to_analysis_.get()));
+ *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
@@ -161,10 +183,15 @@ TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
BuildModuleAndRunAnalysis(builder.Build());
- EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {},
- points_to_analysis_.get()));
- EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {},
- points_to_analysis_.get()));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_));
+
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
@@ -197,9 +224,14 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
// The fusion instruction can share with tuple element 1.
EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
- points_to_analysis_.get()));
+ *points_to_analysis_));
EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
- points_to_analysis_.get()));
+ *points_to_analysis_));
+
+ EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {},
+ *dataflow_analysis_));
+ EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {},
+ *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
@@ -221,12 +253,19 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
// The DynamicUpdateSlice instruction can share with the data operand, but not
// with update or starts.
- EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {},
- points_to_analysis_.get()));
- EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {},
- points_to_analysis_.get()));
- EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {},
- points_to_analysis_.get()));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_));
+ EXPECT_FALSE(
+ CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_));
+ EXPECT_FALSE(
+ CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_));
+
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_));
+ EXPECT_FALSE(
+ CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_));
+ EXPECT_FALSE(
+ CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
@@ -256,7 +295,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
// Output fused dot add should be able to share buffer with 'add_operand'.
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
- points_to_analysis_.get()));
+ *points_to_analysis_));
+
+ EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
+ *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
@@ -292,7 +334,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
// Output fused transpose-dot-add should be share buffer with 'add_operand'.
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
- points_to_analysis_.get()));
+ *points_to_analysis_));
+
+ EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {},
+ *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
@@ -320,7 +365,10 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
// Output fused operand->reverse->add cannot alias operand buffer 'operand'.
EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
- points_to_analysis_.get()));
+ *points_to_analysis_));
+
+ EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {},
+ *dataflow_analysis_));
}
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
@@ -360,8 +408,11 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
RunAnalysis();
// The While instruction can share with the data operand.
- EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {},
- points_to_analysis_.get()));
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_));
+
+ EXPECT_TRUE(
+ CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_));
}
} // namespace