aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2017-08-29 13:20:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-29 13:24:00 -0700
commit66cc9e5474c38b1fddb4a192748f4a3248364c49 (patch)
tree2ca58e1adc76a8d70a6ab986650958afdcbb1b9b
parent28ea83ce387f3794e22745112fe9bdedb9e5993b (diff)
[XLA] Make RunDFSMemoryScheduler more accurate
Skip parameters and constants; they are long-lived in ways that are not relevant to the analysis. Also, remove superfluous bitcasts to make our ordering more accurate. PiperOrigin-RevId: 166892521
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc16
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc17
2 files changed, 30 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index d3649a1ed1..2dcba6996c 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -121,6 +121,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
HloInstruction* rhs) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+
Status HandleBroadcast(HloInstruction* broadcast) override;
Status HandleConcatenate(
@@ -340,6 +342,20 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
return Status::OK();
}
+Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) {
+ // If a bitcast feeds a bitcast, make it a single bitcast (same shape as this one, reading the inner bitcast's operand).
+ if (bitcast->operand(0)->opcode() == HloOpcode::kBitcast) {
+ return ReplaceWithNewInstruction(
+ bitcast, HloInstruction::CreateUnary(
+ bitcast->shape(), HloOpcode::kBitcast,
+ bitcast->mutable_operand(0)->mutable_operand(0)));
+ }
+ // All bitcasts can be eliminated (assuming layout constraints are
+ // satisfied). The helper's result is not inspected; OK is returned either way.
+ ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0));
+ return Status::OK();
+}
+
Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) {
// If a copy feeds a copy, make it a single copy.
if (copy->operand(0)->opcode() == HloOpcode::kCopy) {
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
index 3df760d159..25be448c8d 100644
--- a/tensorflow/compiler/xla/service/hlo_scheduling.cc
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -72,6 +72,13 @@ class ListScheduler {
return scheduler.CreateSchedule();
}
+ // Returns whether the memory used by the given HLO should be ignored by the
+ // scheduling heuristic: true exactly for kParameter and kConstant opcodes.
+ static bool IgnoreInstruction(const HloInstruction& instruction) {
+ return instruction.opcode() == HloOpcode::kParameter ||
+ instruction.opcode() == HloOpcode::kConstant;
+ }
+
private:
// The scheduling priority of an instruction is first the number of bytes
// freed by scheduling the instruction, and second (tie-breaker) by the number
@@ -127,9 +134,8 @@ class ListScheduler {
// Returns whether the memory used by the given buffer should be ignored by
// the scheduling heuristic.
- bool IgnoreBuffer(const LogicalBuffer& buffer) {
- return buffer.instruction()->opcode() == HloOpcode::kParameter ||
- buffer.instruction()->opcode() == HloOpcode::kConstant;
+ static bool IgnoreBuffer(const LogicalBuffer& buffer) {
+ return IgnoreInstruction(*buffer.instruction());  // Same rule as for the instruction returned by buffer.instruction().
 }
// An entry in the worklist used by CreateSchedule. Corresponds to one
@@ -306,6 +312,11 @@ StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
+ if (ListScheduler::IgnoreInstruction(*hlo)) {
+ extra_users[hlo] = 0;
+ total_sizes[hlo] = 0;
+ continue;
+ }
extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
total_sizes[hlo] = SumLogicalBufferSizes(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);