author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-13 16:21:09 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-13 16:23:43 -0700
commit     11e1a45229b9f758a143b5fcf121ba689eca74e2
tree       cc0ada62345fccebbc1251faa6c51d74382b3879
parent     b74197c6cba3e11deaff553c280933afa3e5a075
Automated g4 rollback of changelist 200309129
PiperOrigin-RevId: 200472722
-rw-r--r--  tensorflow/compiler/xla/service/BUILD                          |  3
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.cc              | 68
-rw-r--r--  tensorflow/compiler/xla/service/copy_insertion.h               |  7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h              | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_ordering.cc                |  5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc       | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.h        | 11
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization_test.cc  |  2
8 files changed, 96 insertions(+), 34 deletions(-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 1154eef80e..cb2e159a38 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -2123,6 +2123,7 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
+ ":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
@@ -2130,6 +2131,7 @@ cc_library(
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
+ ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -2143,6 +2145,7 @@ tf_cc_test(
name = "hlo_rematerialization_test",
srcs = ["hlo_rematerialization_test.cc"],
deps = [
+ ":flatten_call_graph",
":hlo",
":hlo_matchers",
":hlo_ordering",
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 33d8338809..3625891b4f 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -613,7 +613,10 @@ class CopyRemover {
VLOG(2) << copy->name() << " is not removable";
return false;
}
-
+ if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
+ VLOG(2) << copy->name() << " is not removable (shape mismatch)";
+ return false;
+ }
const CopyNodes& copy_node = copy_map_.at(copy);
ValueNode* src = copy_node.src;
ValueNode* dest = copy_node.dest;
@@ -947,28 +950,6 @@ class CopyRemover {
BufferValueTracker buffer_value_tracker_;
};
-// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- CopyRemover copy_remover(*alias_analysis, ordering, module);
- XLA_VLOG_LINES(3, copy_remover.ToString());
-
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id())) {
- TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
- }
- }
- }
- return Status::OK();
-}
-
// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
@@ -1065,13 +1046,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
HloInstruction* instruction = pair.first;
const ShapeTree<bool>& indices_to_copy = pair.second;
+ ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
std::vector<HloInstruction*> users = instruction->users();
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
instruction->parent()->DeepCopyInstruction(
- instruction, &indices_to_copy));
+ instruction, &indices_to_copy, &copies_added));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
+ // Special case copies are not eligible for later copy elision passes.
+ indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
+ if (has_copy) {
+ HloInstruction* copy = *copies_added.mutable_element(index);
+ if (copy != nullptr) {
+ copy->SetCopyElisionAllowed(false);
+ }
+ }
+ });
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@@ -1097,6 +1088,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering,
+ const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
+ MaybeDumpModule("after adding copies to resolve interference", *module);
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
+ HloAliasAnalysis::Run(module));
+ CopyRemover copy_remover(*alias_analysis, ordering, module);
+ XLA_VLOG_LINES(3, copy_remover.ToString());
+
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCopy &&
+ !ContainsKey(copies_to_exclude, instruction->unique_id()) &&
+ instruction->CopyElisionAllowed()) {
+ TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
+ }
+ }
+ }
+ MaybeDumpModule("after removing unnecessary copies", *module);
+
+ return Status::OK();
+}
+
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Copy insertion is performed in three steps:
//
@@ -1158,14 +1174,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
- MaybeDumpModule("after adding copies to resolve interference", *module);
-
DependencyHloOrdering ordering(module);
TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(ordering, existing_copies, module));
- MaybeDumpModule("after removing unnecessary copies", *module);
-
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
MaybeDumpModule("after adding special-case copies", *module);
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index 65e3d31e34..0d7b3c20f9 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -64,6 +64,13 @@ class CopyInsertion : public HloPassInterface {
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
};
+// Try to remove as many copies from the module as possible without introducing
+// live range interference. Copy instructions (identified by their unique id) in
+// the set copies_to_exclude are not considered for removal.
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering,
+ const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
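
The declaration above makes RemoveUnnecessaryCopies callable from passes other than CopyInsertion. A minimal sketch of such a caller, mirroring how CopyInsertion::Run drives it in this change; ElideCopiesExample and the MustKeep predicate are hypothetical names, not part of the change:

#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"

namespace xla {

// Sketch: elide copies under a conservative dependency-based ordering, while
// excluding (by unique id) copies the caller wants to preserve.
Status ElideCopiesExample(HloModule* module) {
  tensorflow::gtl::FlatSet<int> copies_to_exclude;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          MustKeep(instruction)) {  // MustKeep: placeholder predicate.
        copies_to_exclude.insert(instruction->unique_id());
      }
    }
  }
  DependencyHloOrdering ordering(module);
  return RemoveUnnecessaryCopies(ordering, copies_to_exclude, module);
}

}  // namespace xla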
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 2816a3b708..2a38e2b063 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -1046,6 +1046,19 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
+ // TODO(b/80249101): Remove these methods once HLO scheduling and copy
+ // insertion are integrated, and we don't need to run a separate pass
+ // of copy elision anymore.
+ bool CopyElisionAllowed() const {
+ CHECK_EQ(HloOpcode::kCopy, opcode_);
+ return copy_elision_allowed_;
+ }
+
+ void SetCopyElisionAllowed(bool value) {
+ CHECK_EQ(HloOpcode::kCopy, opcode_);
+ copy_elision_allowed_ = value;
+ }
+
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@@ -1568,6 +1581,9 @@ class HloInstruction {
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;
+ // Used to tag kCopy instructions that are eligible for copy elision.
+ bool copy_elision_allowed_ = true;
+
// Describes the [start, start + size) range size for a dynamic slice
// ('start' is specified dynamically in the second operand of the operation).
std::vector<int64> dynamic_slice_sizes_;
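
The copy_elision_allowed_ bit is only meaningful on kCopy instructions; both accessors CHECK the opcode. A small sketch of the tag-then-skip pattern this enables, mirroring AddSpecialCaseCopies and RemoveUnnecessaryCopies in this change (TagAndScanExample is a hypothetical name):

// Sketch: mark a freshly inserted copy as ineligible for elision, then skip
// it when scanning for removable copies.
void TagAndScanExample(HloModule* module, HloInstruction* new_copy) {
  new_copy->SetCopyElisionAllowed(false);  // CHECK-fails unless kCopy.

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          instruction->CopyElisionAllowed()) {
        // Only copies still tagged as eligible are candidates for removal.
      }
    }
  }
}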
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index dcd4725fe7..6c1e015f77 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -232,6 +232,11 @@ bool HloOrdering::UseIsBeforeValueDefinition(
<< " and def is in FALSE computation";
return true;
}
+ if (value.defining_instruction() == use.instruction) {
+ VLOG(4) << " use is conditional " << use << " and def is "
+ << value.ToShortString();
+ return true;
+ }
}
VLOG(4) << " use is not before value";
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 9c7bc7a5ea..62c07d7fac 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -1201,7 +1202,8 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit_bytes, RematerializationSizes* sizes) {
+ int64 memory_limit_bytes, RematerializationSizes* sizes,
+ bool run_copy_elision) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
@@ -1236,6 +1238,15 @@ StatusOr<bool> HloRematerialization::Run(
return size_function_(buffer.shape());
},
scheduler_algorithm_));
+ if (run_copy_elision) {
+ // We run a separate pass of copy elision here because the sequential
+ // ordering from the HLO schedule allows for more copies to be eliminated.
+ // TODO(b/80249101): Instead of a separate copy elision pass, use the
+ // ordering from the HLO schedule directly for copy insertion.
+ SequentialHloOrdering ordering(module, *sequence);
+ TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
+ }
+
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
@@ -1338,9 +1349,10 @@ StatusOr<bool> HloRematerialization::Run(
int64 memory_limit_bytes, HloModule* hlo_module,
MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes) {
+ RematerializationSizes* sizes, bool run_copy_elision) {
HloRematerialization remat(scheduler_algorithm, size_function);
- return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
+ return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
+ run_copy_elision);
}
} // namespace xla
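
With this change the elision entry point runs in two settings: from CopyInsertion::Run with a DependencyHloOrdering (a partial order from data and control dependencies), and here with a SequentialHloOrdering built from the schedule, whose total order lets more copies be proven removable. A sketch contrasting the two call sites; ElideWithBothOrderingsExample is a hypothetical wrapper and module/sequence are assumed to be already populated:

Status ElideWithBothOrderingsExample(
    HloModule* module,
    const SequentialHloOrdering::HloModuleSequence& sequence) {
  // Before scheduling: only dependency information is available.
  DependencyHloOrdering dep_ordering(module);
  TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, {}, module));

  // After scheduling: the total order exposes additional removable copies.
  SequentialHloOrdering seq_ordering(module, sequence);
  return RemoveUnnecessaryCopies(seq_ordering, {}, module);
}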
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 2ee2dd0571..59b4cf5dcc 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -57,6 +57,12 @@ class HloRematerialization {
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
+ // run_copy_elision: Enable copy elision. This pass is used to eliminate
+ // copies that were inserted before HLO scheduling.
+ //
+ // TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
+ // insertion is integrated with HLO scheduling.
+ //
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
// false is returned.
@@ -68,7 +74,7 @@ class HloRematerialization {
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
- RematerializationSizes* sizes = nullptr);
+ RematerializationSizes* sizes, bool run_copy_elision = true);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@@ -83,7 +89,8 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
- int64 memory_limit, RematerializationSizes* sizes);
+ int64 memory_limit, RematerializationSizes* sizes,
+ bool run_copy_elision);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the
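
run_copy_elision defaults to true, so existing callers keep the post-scheduling elision pass; callers that handle copy elision themselves, or tests exercising rematerialization in isolation as below, pass false. A sketch of a call site mirroring the updated test; ByteSizeOf and memory_limit_bytes stand in for a real size function and budget:

SequentialHloOrdering::HloModuleSequence sequence;
StatusOr<bool> changed = HloRematerialization::RematerializeAndSchedule(
    ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler, &sequence,
    /*sizes=*/nullptr, /*run_copy_elision=*/false);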
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
index e81334d5a8..7a46da6efe 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc
@@ -147,7 +147,7 @@ class HloRematerializationTest : public HloTestBase {
TF_EXPECT_OK(verifier().Run(module).status());
return HloRematerialization::RematerializeAndSchedule(
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
- sequence);
+ sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false);
}
// Various shapes used in the canned computations.