diff options
author | Mark Heffernan <meheff@google.com> | 2017-11-03 13:26:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-03 13:31:12 -0700 |
commit | 456929281592f14d50443cfbdaa2f6b36167a134 (patch) | |
tree | 8a1a18245a6ef9baed8bd9a9f35b7c250ab64901 | |
parent | 5b166f495ae79b6e8144bbd3a1109f4b8d9fb1aa (diff) |
Rollback copy insertion change because it results in a DCHECK failure with an internal model.
END_PUBLIC
BEGIN_PUBLIC
Automated g4 rollback of changelist 174423881
PiperOrigin-RevId: 174505237
25 files changed, 879 insertions, 2200 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7fe06655cf..c6f6c6c38b 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1644,14 +1644,10 @@ cc_library( deps = [ ":buffer_liveness", ":hlo", - ":hlo_alias_analysis", - ":hlo_dce", - ":hlo_graph_dumper", - ":hlo_ordering", ":hlo_pass", ":liveness_util", ":logical_buffer", - ":tuple_simplifier", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1666,17 +1662,15 @@ tf_cc_test( deps = [ ":copy_insertion", ":hlo", - ":hlo_graph_dumper", ":hlo_matchers", + ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 5c9714d7ea..8536429846 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1235,6 +1235,7 @@ const LogicalBuffer* AddBufferToColocatedSet( // CopyInsertion ensures root points-to set is unambiguous and distinct. 
const auto& points_to = points_to_analysis.GetPointsToSet(instruction); DCHECK(!points_to.IsAmbiguous()); + DCHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); return colocated_set->back(); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 4d4c5b953e..89410f42bd 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1538,6 +1538,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1554,8 +1556,10 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto body1 = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output1})); auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); RunCopyInsertion(module.get()); @@ -1672,37 +1676,34 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), 
HloOpcode::kAdd, gte0, gte1)); - + while0->shape(), HloOpcode::kAdd, while0, while1)); module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + { FlattenCallGraph flatten; TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } - RunCopyInsertion(module.get()); - auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + std::vector<const HloInstruction*> sequence_for_buffer_assigment = { + input1, weights1, one, output1, tuple1, while1, input0, + weights0, zero, output0, tuple0, while0, root_add}; // If this ASSERT_TRUE fails, we constructed a bogus sequence above // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); + + sequence[module->entry_computation()] = + std::move(sequence_for_buffer_assigment); auto assignment = BufferAssigner::Run( @@ -1714,6 +1715,55 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } +// Test buffer assignment for while nodes with multiple uses. +// TODO(b/37245345): Fix buffer assignment for this case. 
+TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { + auto module = MakeUnique<HloModule>(TestName()); + auto builder = HloComputation::Builder(TestName()); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0)); + + auto get0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); + auto get1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1)); + module->AddEntryComputation(builder.Build()); + + RunCopyInsertion(module.get()); + + { + FlattenCallGraph flatten; + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + EXPECT_TRUE(result); + } + + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); +} + TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { auto module = MakeUnique<HloModule>(TestName()); auto builder = HloComputation::Builder("entry"); diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc 
b/tensorflow/compiler/xla/service/copy_insertion.cc index 8f50b29dad..0453a698a0 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,17 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" -#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" +#include <memory> + #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -33,1113 +31,597 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { -bool IsEntryParameterValue(const HloValue& value) { - const HloComputation* computation = value.defining_instruction()->parent(); - return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); -} - -bool IsConstantValue(const HloValue& value) { - return value.defining_instruction()->opcode() == HloOpcode::kConstant; -} - -bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsEntryParameterValue(value); -} +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; -// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in -// 'indices_to_copy'. Add control edges from the respective kCopy instructions -// in deep copy of 'from' to the respective kCopy instruction in the deep copy -// of 'to'. -// -// Requirements: 'from' and 'to' must have compatible shapes. +// InstructionCopier encapsulates indices at which to copy 'instruction'. +// All 'instruction' users in 'copy_users' are updated to use the copy. // -// For example, suppose 'from' and 'to' are two-element tuples where index 0 is -// the only index to copy. Prior to deep-copying we have: +// Instruction copies are generated in two phases: +// 1) Recording buffer indices at which 'instruction' requires copies (i.e. +// setting 'indices_to_copy_[index]'=true). +// 2) Inserting kCopy instructions based on indices recorded in phase 1). +// *) Array instructions are copied by inserting a single kCopy instruction. 
+// *) Tuple-shaped instructions are copied by recursively expanding tuples +// (and tuple-shaped elements), and inserting kCopy instructions for any +// tuple elements which require a copy. As the recursion unwinds, new tuple +// instructions are added to gather the copied (and uncopied) references +// into the output tuple (i.e. the copy of the tuple-shaped instruction). // +// Example two-element tuple with one element that needs a copy: // -// 'from' -// | -// ... -// | -// 'to' +// original-instruction +// / \ +// GTE(0) GTE(1) +// | | +// Copy | +// \ / +// Tuple // copied-instruction // -// DeepCopyAndAddControlEdges produces: -// -// 'from' -// / \ -// GTE GTE -// | | -// Copy | -// / \ / -// | Tuple -// | | -// ctrl ... -// edge | -// | | -// | 'to' -// | / \ -// | GTE GTE -// \ | | -// Copy | -// \ / -// Tuple -// -StatusOr<std::pair<HloInstruction*, HloInstruction*>> -DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, - const ShapeTree<bool>& indices_to_copy) { - DCHECK(ShapeUtil::Compatible(from->shape(), to->shape())); - // to/from_copy_tree hold the kCopy instruction produces by the deep - // copies. Elements which are not copied (indices_to_copy.element(index) == - // false) have nullptr at that index. - ShapeTree<HloInstruction*> from_copy_tree(from->shape(), - /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, - from->parent()->DeepCopyInstruction( - from, &indices_to_copy, &from_copy_tree)); - - ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN( - HloInstruction * to_deep_copy, - to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree)); - - // Add control edges between the respective kCopy instructions. 
- for (const auto& pair : from_copy_tree) { - const ShapeIndex& index = pair.first; - HloInstruction* from_copy = pair.second; - HloInstruction* to_copy = to_copy_tree.element(index); - if (from_copy == nullptr) { - TF_RET_CHECK(to_copy == nullptr); - continue; - } - TF_RET_CHECK(to_copy != nullptr); - TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); +// As an optimization, if the original instruction is itself a Tuple +// instruction, we elide the unnecessary extra GTE and Tuple instructions, +// and just insert the copy into a new Tuple instruction, with control +// dependencies to ensure the copy occurs after any possible interference. +class InstructionCopier { + public: + InstructionCopier(HloInstruction* instruction, + const std::vector<HloInstruction*>& copy_users) + : instruction_(instruction), + copy_users_(copy_users), + indices_to_copy_(instruction->shape()), + control_predecessors_(instruction->shape()) {} + + // Sets indices that are read-only, and thus do not need to be copied. + void SetReadOnlyIndices(const ShapeTree<bool>& read_only_indices) { + read_only_indices_ = read_only_indices; } - return std::make_pair(from_deep_copy, to_deep_copy); -} - -// Compute the indices of the loop state which need copies in order to avoid -// live range interference. Generally, an element in the loop state does not -// need to be copied if the element is passed through transparently through the -// body. -// -// Returns whether any indices need to be copied. -bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, - const HloInstruction* xla_while, - ShapeTree<bool>* indices_to_copy) { - DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape())); - - bool any_copies = false; - const HloInstruction* init = xla_while->operand(0); - for (auto& pair : *indices_to_copy) { - const ShapeIndex& index = pair.first; - bool& should_copy = pair.second; - // If there is any ambiguity, then loop state must be copied. 
- if (dataflow.GetValueSet(init, index).values().size() > 1 || - dataflow.GetValueSet(xla_while, index).values().size() > 1) { - should_copy = true; - } else { - // If the output of the while instruction is not the same as the init - // value of the while, then this element is not passed through the body - // transparently and must be copied. - should_copy = dataflow.GetUniqueValueAt(xla_while, index) != - dataflow.GetUniqueValueAt(init, index); - } - any_copies |= should_copy; + // Sets copy overrides, which are copy instructions to use at each index. This + // is used to share a single copy of read-only entry parameters and constants + // between multiple While loops. + void SetCopyOverrides(const ShapeTree<HloInstruction*>& copy_overrides) { + copy_overrides_ = copy_overrides; } - return any_copies; -} - -// Add kCopy instructions around the given kWhile instruction to eliminate any -// possible live range interference of HLO values assuming a dependency-based -// ordering (HloDependencyOrdering). Copies are added conservatively. There -// likely are copies which are not strictly necessary, but there are removed -// later in the pass via CopyRemover. -// -// -// Elements (each ShapeIndex) in the loop state are considered independently. A -// copy is added to each element of the loop state which is modified in the -// while body. For each such element, a total of three kCopy instructions are -// added at following locations: -// -// (1) The init value is copied before the kWhile instruction. Before: -// -// (Init) -// | -// kWhile -// | -// ... -// -// After: -// -// (Init) -// | -// kCopy -// | -// kWhile -// | -// ... -// -// This copy is necessary in case the init value is simultaneously live -// with the kWhile. -// -// (2) Copies are added to the parameter and root of the while body -// computation. Before: -// -// kParameter -// | -// ... -// | -// (body root) -// -// After: -// -// kParameter -// | -// kCopy ----------+ -// | | -// ... 
ctrl -// | edge -// (body root) | -// | | -// kCopy <---------+ -// -// The root kCopy becomes the new root of the computation. Both copies are -// necessary to any potential interference between the parameter value and -// the root value. The control edge prevents potential interference -// between the copies themselves. -// -// If the loop state is a tuple then the above kCopy instructions are a deep -// copy constructed of kCopy, KGetTupleElement, and kTuple instruction as -// constructed by HloInstruction::DeepCopyInstruction. -Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, - HloInstruction* xla_while) { - VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name(); - TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile); - ShapeTree<bool> indices_to_copy(xla_while->shape()); - if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while, - &indices_to_copy)) { - VLOG(2) << "No copies necessary for kWhile instruction " - << xla_while->name(); - return Status::OK(); - } + // Returns true if all recorded indices are false (returns true otherwise). + bool HasAllIndicesFalse() const; - VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:"; - for (auto& pair : indices_to_copy) { - if (pair.second) { - VLOG(2) << " " << pair.first; - } - } + // Records instruction buffer indices which point-to a Parameter or Constant. + Status RecordIndicesWhichPointToParamOrConstant( + const TuplePointsToAnalysis& points_to_analysis); - // Deep copy init. - HloInstruction* while_init = xla_while->mutable_operand(0); - TF_ASSIGN_OR_RETURN( - HloInstruction * while_init_copy, - xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy)); - TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); + // Records instruction buffer indices to copy which are necessary to ensure: + // *) PointsToSet of 'instruction_' is unambiguous and distinct. 
+ // *) No liveness interference between 'instruction_' and 'other_instruction'. + // + // If 'read_only_indices_out' is non-null, read-only indices are set to true. + Status RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out); - // Deep copy the parameter and the root. Extend a control edge from the copy - // of the parameter value to the corresponding copy value of the root. - HloComputation* body = xla_while->while_body(); - HloInstruction* param = body->parameter_instruction(0); - HloInstruction* root = body->root_instruction(); + // Records control predecessors to add for inserted copy instructions. + // 'parameter' must have the same shape as the instruction that will be + // copied, and must define all buffers in the shape. Control predecessors are + // only recorded for indices that have already been marked for copying. + Status RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter); - // If param is the root then all indices should have been passed through the - // while body and we should have returned early above. - TF_RET_CHECK(param != root); + // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', + // and replaces all uses for instructions in 'copy_users_' with copy. + // Returns the instruction which is a copy 'instruction'. + HloInstruction* Copy(); - // Copy users before making a deep copy of the parameter as the deep copy - // will create new users of the parameter (eg, the GTE instructions of the - // deep copy). 
- std::vector<HloInstruction*> param_users = param->users(); + HloInstruction* instruction() { return instruction_; } - ShapeIndex current_index; - TF_ASSIGN_OR_RETURN(auto pair, - DeepCopyAndAddControlEdges(param, root, indices_to_copy)); + const std::vector<HloInstruction*>& copy_users() const { return copy_users_; } - HloInstruction* param_copy = pair.first; - HloInstruction* root_copy = pair.second; + private: + // Does the given index represent a read-only buffer? + bool IsReadOnlyIndex(const ShapeIndex& index) const { + return !ShapeUtil::IsNil(read_only_indices_.shape()) && + read_only_indices_.element(index); + } - for (HloInstruction* user : param_users) { - TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); + // Returns the copy override at the given index, or nullptr. + HloInstruction* GetCopyOverride(const ShapeIndex& index) const { + return ShapeUtil::IsNil(copy_overrides_.shape()) + ? nullptr + : copy_overrides_.element(index); } - body->set_root_instruction(root_copy); + // Records instruction buffer indices which have ambiguous or non-distinct + // points-to sets. + Status RecordAmbiguousOrNonDistinctIndices( + const TuplePointsToAnalysis& points_to_analysis); - return Status::OK(); -} + // Records instruction buffer indices which have interfering live ranges + // with 'other_instruction' buffers at same index. + Status RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out); -// Removes any control dependencies to or from the given instruction. -Status StripControlDependenciesFrom(HloInstruction* instruction) { - while (!instruction->control_successors().empty()) { - TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( - instruction->control_successors().front())); - } + // Recursively inserts copies of 'instruction' tuple elements at indices + // specified in 'indices_to_copy', and returns the copy of 'instruction'. 
+ HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index); - while (!instruction->control_predecessors().empty()) { - TF_RETURN_IF_ERROR( - instruction->control_predecessors().front()->RemoveControlDependencyTo( - instruction)); + void RecordIndex(const ShapeIndex& index) { + *indices_to_copy_.mutable_element(index) = true; } - return Status::OK(); -} - -// Add kCopy instructions to the given module to guarantee there is no -// live-range interference. Generally interference can only occur around kWhile -// instructions which have update-in-place semantics. -Status AddCopiesToResolveInterference(HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); + HloInstruction* instruction_; + const std::vector<HloInstruction*> copy_users_; + ShapeTree<bool> indices_to_copy_; + ShapeTree<std::vector<HloInstruction*>> control_predecessors_; + ShapeTree<bool> read_only_indices_; + ShapeTree<HloInstruction*> copy_overrides_; +}; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); - } - } - } - return Status::OK(); +bool InstructionCopier::HasAllIndicesFalse() const { + bool all_indices_false = true; + indices_to_copy_.ForEachElement( + [&all_indices_false](const ShapeIndex& /*index*/, bool data) { + if (data) { + all_indices_false = false; + } + }); + return all_indices_false; } -// Class for removing unnecessary copies from the module. -// -// kCopy instructions are added conservatively to guarantee no live range -// interference between HLO values. This class uses a more fine-grained analysis -// to remove some of these added copies which are not strictly necessary. 
-class CopyRemover { - public: - CopyRemover(const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering, HloModule* module) - : module_(module), - alias_analysis_(alias_analysis), - ordering_(ordering), - buffer_value_tracker_(*module, alias_analysis, ordering) {} - - // Try to elide the given copy. The copy is elided if the instruction is not - // necessary to prevent live-range interference of HLO values. Returns true if - // copy was elided. - // - // The copy instruction is not actually removed here. Instead it is left for - // dead in the graph. Later calls to DCE will remove the instruction. - StatusOr<bool> TryElideCopy(HloInstruction* copy) { - if (buffer_value_tracker_.TryElideCopy(copy)) { - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); - TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); - return true; - } - return false; +Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( + const TuplePointsToAnalysis& points_to_analysis) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction_); + // Shallow copy the instruction if the points-to set of the top-level + // buffer is ambiguous. This is necessary because the backends must know + // statically what the top-level buffer of the result is. + if (points_to.element(/*index=*/{}).size() > 1) { + RecordIndex({}); } - string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); - StrAppend(&out, " Buffer values, in dependency order:\n"); - for (const HloBuffer& buffer : alias_analysis_.buffers()) { - StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); + // Multiple buffers within a parameter/constant may be live out, so collect + // a set of indices at which to copy first. 
+ points_to.ForEachElement([this](const ShapeIndex& index, + const PointsToSet::BufferList& buffers) { + if (IsReadOnlyIndex(index)) { + return; } - return out; - } - - private: - // Class which tracks the HLO values within each HLO buffer in the module - // during copy removal. - // - // The values are held in a linked list where there is one list for each - // buffer. Removing a copy instruction merges together the values in the - // source buffer of the copy to the destination buffer of the copy. This class - // tracks these value lists as copies are removed from the graph (and value - // lists are merged). - // - // The BufferValueTracker object is initialized to match the state of - // HloAliasAnalysis. However, as copies are removed this state diverges. The - // values-to-buffer mapping is maintained outside of HloAliasAnalysis because - // a fully updatable alias analysis is very slow. - class BufferValueTracker { - public: - // The values held in a single HLO buffer are represented using a linked - // list. An element type in this list is ValueNode. - // - // This linked list is hand-rolled to enable efficient splicing of lists - // using only references to list elements without knowing which lists are - // being spliced. std::list requires a reference to the list object to - // splice. - struct ValueNode { - explicit ValueNode(const HloValue* v) : value(v) {} - - const HloValue* value; - - // The uses are maintained outside of HloValue::uses() because - // HloValue::uses() is not updatable (a fully updatable dataflow analysis - // is slow). - std::vector<const HloUse*> uses; - - // next/prev elements in the linked list. The list is circularly linked so - // these values are never null for elements in the list. 
- ValueNode* prev = nullptr; - ValueNode* next = nullptr; - }; - - BufferValueTracker(const HloModule& module, - const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering) - : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { - // Construct a list for each HLO buffer in the alias analysis. Maintain a - // map from HloValue to the respective list element representing that - // value. The map is used to construct the copy info map below. - tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node; - for (const HloBuffer& buffer : alias_analysis.buffers()) { - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - for (const HloValue* value_b : buffer.values()) { - if (value_a != value_b) { - DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, - dataflow_) || - ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, - dataflow_)) - << value_a->ToShortString() << " and " - << value_b->ToShortString() << " are not ordered"; - } - } - } - - std::vector<const HloValue*> values = buffer.values(); - std::sort(values.begin(), values.end(), - [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); - - // Create a list containing all of the values in the buffer. - AddValueList(values, &value_to_node); + for (const LogicalBuffer* buffer : buffers) { + // pointee is the HloInstruction producing the buffer which may be + // liveout. 
+ HloInstruction* pointee = buffer->instruction(); + if (pointee->opcode() == HloOpcode::kParameter || + pointee->opcode() == HloOpcode::kConstant) { + VLOG(2) << "Parameter or constant buffer " << buffer->ToString() + << " index: " << tensorflow::str_util::Join(index, ",") + << " may be live out of computation: " << pointee->ToString(); + RecordIndex(index); + break; } - - // Create copy_map_ which contains the source and destination values - // of all copies. - CreateCopyMap(module, value_to_node); - - XLA_VLOG_LINES(3, ToString()); - TF_DCHECK_OK(Verify()); } + }); + return Status::OK(); +} - // Add a list containing the given values to BufferValueTracker. This - // represents the values contained in a single buffer. For each value in - // 'values' an entry is created in value_to_node which indicates the - // respective ValueNode representing that value. - void AddValueList( - tensorflow::gtl::ArraySlice<const HloValue*> values, - tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) { - ValueNode* tail = nullptr; - ValueNode* head = nullptr; - for (const HloValue* value : values) { - auto new_node = new ValueNode(value); - (*value_to_node)[value] = new_node; - - // Copy the HLO values's uses into the ValueNode for the value. These - // uses in ValueNode are updated as copies are removed. - new_node->uses.reserve(value->uses().size()); - for (const HloUse& use : value->uses()) { - new_node->uses.push_back(&use); - } - - // Connect the new node into the linked list. - if (tail == nullptr) { - head = new_node; - } else { - tail->next = new_node; - new_node->prev = tail; - } - tail = new_node; - } - - // The linked list is circular so connect the head and tail. 
- tail->next = head; - head->prev = tail; - value_lists_.insert(head); - } +Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out) { + TF_RETURN_IF_ERROR( + RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); + TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( + liveness, other_instruction, read_only_indices_out)); + return Status::OK(); +} - // This method also fills in copy_map_ which indicates which nodes - // in the value lists corresponding to the source and destination values of - // kCopy instructions. value_to_node should map each HloValue to its - // respective ValueNode. - void CreateCopyMap( - const HloModule& module, - const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>& - value_to_node) { - for (HloComputation* computation : module.computations()) { - for (HloInstruction* instruction : computation->instructions()) { - // Add copies with unambiguous source values to the map. Copies with - // ambiguous sources are not removable. - if (instruction->opcode() == HloOpcode::kCopy) { - const HloValueSet& src_value_set = - dataflow_.GetValueSet(instruction->operand(0)); - if (src_value_set.values().size() == 1) { - CopyNodes& copy_node = copy_map_[instruction]; - copy_node.dest = - value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); - copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); - } +Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( + const TuplePointsToAnalysis& points_to_analysis) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction_); + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). 
+ FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>> + buffer_to_source_indices; + points_to.ForEachElement( + [this, &buffer_to_source_indices]( + const ShapeIndex& index, const PointsToSet::BufferList& buffers) { + if (buffers.size() > 1) { + // Record ambiguous points-to set at 'index'. + if (!indices_to_copy_.element(index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " with ambiguous points-to set."; + RecordIndex(index); } } - } - } + // For each 'buffer': record a mapping from 'buffer' to 'index'. + for (const LogicalBuffer* buffer : buffers) { + buffer_to_source_indices[buffer].push_back(index); + } + }); - ~BufferValueTracker() { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - const ValueNode* tmp = p->next; - delete p; - p = tmp; - } while (p != head); - } + // Record all non-distinct indices detected in 'buffer_to_source_indices'. + for (const auto& buff_to_src : buffer_to_source_indices) { + if (buff_to_src.second.size() == 1) { + continue; } - - // Verify invariants within the linked lists. - Status Verify() const { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - // Verify links between elements are consistent. - TF_RET_CHECK(p->prev->next == p); - TF_RET_CHECK(p->next->prev == p); - - const HloInstruction* def = p->value->defining_instruction(); - if (def->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, def)) { - TF_RET_CHECK(copy_map_.at(def).dest == p); - } - for (const HloUse* use : p->uses) { - if (use->instruction->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, use->instruction)) { - TF_RET_CHECK(copy_map_.at(use->instruction).src == p); - } - } - - p = p->next; - } while (p != head); + for (const ShapeIndex& src_index : buff_to_src.second) { + // Record non-distinct points-to set at 'src_index'. 
+ if (!indices_to_copy_.element(src_index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(src_index, ",") + << " because of non-distinct points-to set."; + RecordIndex(src_index); } - return Status::OK(); } + } + return Status::OK(); +} - // Try to elide the given copy. Elision of a copy is possible only if no - // live range interference is introduced by the copy's elimination. If - // elision is possible, then the internal state (value lists) are updated, - // and true is returned. Returns false otherwise. - bool TryElideCopy(const HloInstruction* copy) { - VLOG(2) << "Trying to remove " << copy->name(); - - if (!ContainsKey(copy_map_, copy)) { - VLOG(2) << copy->name() << " is not removable"; - return false; - } - - const CopyNodes& copy_node = copy_map_.at(copy); - ValueNode* src = copy_node.src; - ValueNode* dest = copy_node.dest; - DCHECK(src != nullptr); - DCHECK(dest != nullptr); - - auto is_live_range_before = [this](const ValueNode& a, - const ValueNode& b) { - if (LiveRangeBefore(a, b)) { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is before " << b.value->ToShortString(); - return true; - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); - return false; +Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out) { + // Record all buffer indices for 'instruction_', which interfere with + // 'other_instruction' at the same index. 
+ ShapeUtil::ForEachSubshape( + instruction_->shape(), + [this, &liveness, other_instruction, read_only_indices_out]( + const Shape& /*subshape*/, const ShapeIndex& index) { + if (IsReadOnlyIndex(index)) { + return; } - }; - - // A kCopy instruction copies an HLO value from a source buffer and - // defines an HLO value in a destination buffer. Most generally, the - // source and destination buffers may each hold more than one value at - // different points in the computation so we define the following: - // - // Values in source buffer: {s_0, ..., s_n} - // Values in destination buffer: {d_0, ..., d_m} - // - // A kCopy instruction between these buffers copies a value s_x in the - // source buffer and defines a value d_y in the destination buffer. The - // elision of a copy merges the source and destination buffers together, - // so the list of values for the source and destination buffers are - // merged. - // - // We handle two different cases for copy elision: - // - // (1) the kCopy defines the first value in the destination buffer (d_0). - // - // (2) the kCopy copies the last value in the source buffer (s_n). - // - // For the remaining case where the kCopy copies a not-last value from the - // source buffer to a not-first value of the destination buffer, the kCopy - // instruction cannot be removed. This case is generated, for example, if - // the kCopy copies a while body parameter of the loop state at one tuple - // index to a different tuple index in the while body root. Removal of the - // copy necessarily results in live range interference of values in the - // loop state at the two different tuple indices. - // - // We can only perform copy elision if the resulting merged values have - // totally ordered live ranges; otherwise the merged buffer would have - // live range interference. - if (IsHead(*dest)) { - // The copy copies an arbitrary value in the source buffer (call it s_x) - // and defines d_0, the first value in the destination buffer. 
After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} - // - // Removing the copy eliminates d_0, and uses of d_0 become uses of - // s_x. In the above ordering, the live range of d_m must be ordered - // before the live range of s_{x+1} and the definition and all uses of - // s_x must be ordered before the definition of d_1. These conditions - // are checked below prior to elision. - // - // ** Technically it might be possible to have a non-interfering - // non-trivial interleaving of the values of the source and - // destination buffers in the resulting order. However, this case is - // slow and complicated to check and likely not worth it. So instead - // we simply check for the case where *all* values of the destination - // buffer (d_1 through d_m) are spliced into the point where the copy - // used to be. - VLOG(2) << copy->name() << " defines the first value in its buffer"; - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); - if (!is_live_range_before(*src, *next_dest)) { - return false; - } + if (indices_to_copy_.element(index)) { + // Return if previous pass already set index. + return; } - ValueNode* next_src = Next(*src); - - if (next_src != nullptr) { - // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. - ValueNode* last_dest = dest->prev; - DCHECK(IsTail(*last_dest)); - if (!is_live_range_before(*last_dest, *next_src)) { - return false; + const auto& points_to_analysis = liveness.points_to_analysis(); + // Lookup buffers for 'instruction_' and 'other_instruction'. 
+ const auto instruction_buffers = + points_to_analysis.GetPointsToSet(instruction_).element(index); + // If 'instruction_' has ambiguous points-to-set at 'index', it would + // have been recorded in a previous pass (and we would have returned + // early at the entry to this function). As a result, here we know that + // 'instruction_' has just one buffer in its points-to-set. + CHECK_EQ(1, instruction_buffers.size()); + const LogicalBuffer* instruction_buffer = instruction_buffers[0]; + + const auto other_instruction_buffers = + points_to_analysis.GetPointsToSet(other_instruction).element(index); + // Do not insert a copy if both instructions point at the same buffer. + // This eliminates unnecessary copies of read-only tuple elements. + // If 'instruction_' and 'other_instruction' point to the same buffer, + // then that buffer is not updated on the path between the two + // instructions. Therefore, any other (possibly interference-causing) + // users of that buffer from 'other_instruction' will see the same data, + // irrespective of whether we insert a copy of this buffer at + // 'instruction_' or not. + if (other_instruction_buffers.size() == 1 && + other_instruction_buffers[0]->id() == instruction_buffer->id()) { + if (read_only_indices_out != nullptr) { + *read_only_indices_out->mutable_element(index) = true; } + return; } - - // Splice in destination buffer values list right after 'src'. - SpliceAfter(dest, src); - } else if (IsTail(*src)) { - // The copy copies the last value in the source buffer, s_n, and defines - // an arbitrary value in the destination buffer, d_y. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} - // - // Removing the copy eliminates d_y, and uses of d_y become uses of - // s_n. 
To enforce the above order, the live range of d_{y-1} must be - // before the live range of s_0, and the live range of s_n must be - // before the live range of d_{y+1}. - // - // ** See comment above in the code handling Case (1). - VLOG(2) << copy->name() << " copies the last value (" - << src->value->ToShortString() << ") in its buffer"; - - ValueNode* prev_dest = Prev(*dest); - // nullptr condition handled above in the first 'if' case. - DCHECK(prev_dest != nullptr); - ValueNode* first_src = src->next; - DCHECK(IsHead(*first_src)); - if (!is_live_range_before(*prev_dest, *first_src)) { - // Live range of value d_{y-1} is not before s_0. - return false; - } - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - if (!is_live_range_before(*src, *next_dest)) { - // Live range of value s_n is not before d_{y+1}. - return false; + // We can't say anything about the ambiguity of 'other_instruction' at + // this point, so we need to check interference between the single + // buffer in the points-to set of 'instruction_' and all buffers in + // 'other_instruction_buffers'. + for (const LogicalBuffer* other_buffer : other_instruction_buffers) { + if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " instruction_buffer: " << instruction_buffer->ToString() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " because of interference with buffer: " + << other_buffer->ToString(); + RecordIndex(index); + break; } } + }); + return Status::OK(); +} - // Splice source buffer values list right after 'prev_dest'. 
- SpliceAfter(first_src, prev_dest); - } else { - VLOG(2) - << copy->name() - << " copies value in middle of source buffer to value in middle " - "of destination buffer"; - return false; - } - - RemoveCopyValue(dest); - - XLA_VLOG_LINES(4, ToString()); - TF_DCHECK_OK(Verify()); - - return true; - } - - // Delete the given ValueNode associated with a elided kCopy - // instruction. This should be called after splicing the value lists of the - // source and destination buffers together. - void RemoveCopyValue(ValueNode* copy_value_node) { - CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), - HloOpcode::kCopy); - ValueNode* operand_node = copy_value_node->prev; - CHECK(operand_node != copy_value_node); - - VLOG(2) << "Removing copy " << operand_node->value->ToShortString() - << " => " << copy_value_node->value->ToShortString(); - - // Splice out the copy value node. - operand_node->next = copy_value_node->next; - copy_value_node->next->prev = operand_node; - - // Patch up uses. Remove use of copy from operand_node uses. - auto it = - std::find_if(operand_node->uses.begin(), operand_node->uses.end(), - [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); - CHECK(it != operand_node->uses.end()); - operand_node->uses.erase(it); - - // If the elided copy has any uses which are themselves kCopy instructions - // then patch up the copy info to reflect the that this kCopy instruction - // has a different operand (the operand of the elided copy). - for (const HloUse* copy_use : copy_value_node->uses) { - operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy) { - copy_map_.at(copy_use->instruction).src = operand_node; - } - } - - // Delete the copy info and the value node. 
- copy_map_.erase(copy_value_node->value->defining_instruction()); - delete copy_value_node; - } - - // Returns true if the live range of given value 'a' is before the live - // range of 'b'. - // - // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not - // updated as copies are removed. - bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - if (a.uses.empty()) { - VLOG(2) << "Empty uses"; - return ordering_.IsDefinedBefore(*a.value, *b.value); - } - for (const HloUse* use : a.uses) { - VLOG(2) << "use: " << *use; - VLOG(2) << "is before:" << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Not before"; - return false; +// This is called when 'instruction_' is a while body root, and 'parameter' is +// the while body parameter. We record all users of all aliases of 'parameter' +// as control predecessors, so that when we add a copy of 'instruction_', we can +// mark the control dependencies. This is necessary because points-to and +// liveness analysis doesn't know about the aliasing between the while body root +// and param. Without these control dependencies, the copy might get scheduled +// to run at a point that interferes with users of the buffer. 
+Status InstructionCopier::RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter) { + return indices_to_copy_.ForEachElementWithStatus( + [this, &points_to_analysis, parameter](const ShapeIndex& index, + bool will_copy) { + if (will_copy) { + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + points_to_analysis.GetBufferDefinedAt(parameter, index)); + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + user, points_to_analysis)) { + continue; + } + + if (user != instruction_) { + control_predecessors_.mutable_element(index)->push_back(user); + } + } + } } - } - return true; - } - - // Returns whether 'node' is the last node in its list. - bool IsTail(const ValueNode& node) const { - return ContainsKey(value_lists_, node.next); - } - - // Returns whether 'node' is the first node in its list. - bool IsHead(const ValueNode& node) const { - return ContainsKey(value_lists_, &node); - } - - // Returns the next node in the list after 'node'. If 'node' is the - // tail, then nullptr is returned. - ValueNode* Next(const ValueNode& node) const { - if (IsTail(node)) { - return nullptr; - } else { - return node.next; - } - } - - // Returns the previous node in the list before 'node'. If 'node' - // is the head, then nullptr is returned. - ValueNode* Prev(const ValueNode& node) const { - if (IsHead(node)) { - return nullptr; - } else { - return node.prev; - } - } - - // Splices the entire linked list with 'head' as its head right after the - // node 'insert_after' in another linked list. 
- void SpliceAfter(ValueNode* head, ValueNode* insert_after) { - DCHECK(IsHead(*head)); - value_lists_.erase(head); - - ValueNode* tail = head->prev; - tail->next = insert_after->next; - insert_after->next->prev = tail; - - insert_after->next = head; - head->prev = insert_after; - } - - string ToString() const { - string out = StrCat("BufferValueTracker:\n"); - StrAppend(&out, " Def-use chains in each buffer:\n"); - for (const ValueNode* head : value_lists_) { - StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), - ":\n"); - const ValueNode* p = head; - do { - StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), - "\n"); - - p = p->next; - } while (p != head); - } - StrAppend(&out, " Potentially removable copies:\n"); - for (const auto& pair : copy_map_) { - const HloInstruction* copy = pair.first; - const CopyNodes& copy_info = pair.second; + return Status::OK(); + }); +} - StrAppend(&out, " ", copy->name(), " : ", - copy_info.src->value->ToShortString(), " => ", - copy_info.dest->value->ToShortString(), "\n"); - } - return out; +// Recursively inserts copies of 'instruction' tuple element buffers at +// indices in 'indices_to_copy_', expanding tuples as needed. +HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, + ShapeIndex* index) { + const int64 num_tuple_elements = + ShapeUtil::TupleElementCount(instruction->shape()); + std::vector<HloInstruction*> elem_copies(num_tuple_elements); + for (int64 i = 0; i < num_tuple_elements; ++i) { + HloInstruction* elem; + if (instruction->opcode() == HloOpcode::kTuple) { + // If the instruction is already a Tuple instruction, we know that the + // element buffers are aliased, so we can just grab the operand directly. + elem = instruction->mutable_operand(i); + } else { + // Otherwise we need to add a GTE to unpack the element out of the tuple. 
+ elem = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); } - - private: - const HloDataflowAnalysis& dataflow_; - const HloOrdering& ordering_; - - // The heads of all the value lists. Each value list represents the HLO - // values contained in a particular HLO buffer. The values in the list are - // in dependency order. - tensorflow::gtl::FlatSet<const ValueNode*> value_lists_; - - // Copy removal requires fast access to the value list elements - // corresponding to the source and destination values of the kCopy - // instruction. This data structure holds pointers to these elements for - // each kCopy instruction in the graph. - struct CopyNodes { - // The source and destinations values of the kCopy instruction. - ValueNode* src = nullptr; - ValueNode* dest = nullptr; - }; - tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_; - }; - - HloModule* module_; - const HloAliasAnalysis& alias_analysis_; - const HloOrdering& ordering_; - - // Object tracking the HLO values contained in each HLO buffer. - BufferValueTracker buffer_value_tracker_; -}; - -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. 
-Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet<HloInstruction::Id>& copies_to_exclude, - HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); - - tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + index->push_back(i); + if (ShapeUtil::IsTuple(elem->shape())) { + elem_copies[i] = CopyTuple(elem, index); + } else if (!indices_to_copy_.element(*index)) { + elem_copies[i] = elem; + } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { + elem_copies[i] = copy_override; + } else { + HloInstruction* elem_copy = elem->parent()->AddInstruction( + HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); + for (HloInstruction* control_predecessor : + control_predecessors_.element(*index)) { + VLOG(2) << "Adding control dependency from " + << control_predecessor->ToString() << " to " + << elem_copy->ToString(); + TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); } + elem_copies[i] = elem_copy; } + index->pop_back(); } - - return Status::OK(); + return instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(elem_copies)); } -// Add copies to address special constraints on the roots of computations not -// related to live range interference: -// -// (1) Entry computation root must be unambiguous and distinct. -// -// (2) Any computation called by a kCall instruction must have an -// unambiguous root. 
-// -// (3) Constants and parameters cannot be live out of the entry computation -// -Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - - // Identify which shape indices of which instructions need to be copied. Store - // these results in 'instructions_to_copy'. - std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy; - auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, - const ShapeIndex& index) { - auto it = instructions_to_copy.find(instruction); - if (it == instructions_to_copy.end()) { - auto it_added = instructions_to_copy.emplace( - std::piecewise_construct, std::forward_as_tuple(instruction), - std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); - it = it_added.first; - } - *it->second.mutable_element(index) = true; - }; - - // Iterate through values of all constants and entry parameters. These values - // are special because they are held in read-only buffers. If any of these - // values share a buffer with other values (for example, the init value of a - // while is a constant) then copy the value at its definition and replace all - // its uses with the copy. - for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { - if (ValueIsReadOnly(*value) && - alias_analysis->GetBufferContainingValue(*value).values().size() > 1) { - VLOG(2) << "Value " << value->ToShortString() - << " is read only, but its buffer contains more than one value. " - "Copying."; - add_index_to_copy(value->defining_instruction(), value->defining_index()); - } +// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. 
+HloInstruction* InstructionCopier::Copy() { + ShapeIndex index; + HloInstruction* copy; + if (ShapeUtil::IsTuple(instruction_->shape())) { + copy = CopyTuple(instruction_, &index); + } else { + copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction_->shape(), HloOpcode::kCopy, instruction_)); } - - // Identify copies which must be added at root instructions - for (HloComputation* computation : module->computations()) { - const CallGraphNode& node = call_graph.GetNode(computation); - if (node.context() == CallContext::kParallel) { - continue; - } - TF_RET_CHECK(node.context() == CallContext::kSequential); - - const bool is_entry = computation == module->entry_computation(); - HloInstruction* root = computation->root_instruction(); - - // Mark nondistinct/ambiguous indices. - tensorflow::gtl::FlatSet<const HloBuffer*> seen; - ShapeUtil::ForEachSubshape( - root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { - std::vector<const HloBuffer*> buffers_at_index = - alias_analysis->ComputeBuffersAt(root, index); - bool buffer_seen_before = false; - for (const HloBuffer* buffer : buffers_at_index) { - buffer_seen_before |= !seen.insert(buffer).second; - } - if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { - VLOG(2) << "Index " << index << " of root of computation " - << computation->name() << " (" << root->name() - << ") has ambiguous or non-distinct buffer. Copying."; - add_index_to_copy(root, index); - } - }); - - // For entry instructions, mark any parameter or constant values. - if (is_entry) { - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (ValueIsReadOnly(*value)) { - VLOG(2) << "Root of entry computation (" << root->name() - << ") has constant or entry parameter value at index " - << index << ". 
Copying."; - add_index_to_copy(root, index); - } - } - } - } + for (HloInstruction* user : copy_users_) { + VLOG(2) << "Adding copy between instruction: " << instruction_->name() + << " and user: " << user->name(); + TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy)); } + return copy; +} - // TODO(b/62548313): Buffer assignment uses TuplePointsToAnalysis which is - // computation-scoped. This means the analysis doesn't have visibility to - // constants and entry parameters that cross computation boundaries. This can - // cause invalid buffer assignments so additional conservative copies are - // added to handle these cases. Remove this whole loop when buffer assignment - // uses alias analysis. - for (HloComputation* computation : module->computations()) { - const CallGraphNode& node = call_graph.GetNode(computation); - - bool is_while_body = false; - if (node.context() == CallContext::kSequential && - !node.caller_callsites().empty()) { - CHECK_EQ(node.caller_callsites().size(), 1); - const HloInstruction* calling_instruction = - node.caller_callsites()[0].instruction(); - is_while_body = calling_instruction->opcode() == HloOpcode::kWhile && - calling_instruction->while_body() == node.computation(); +// The 'read_only_indices' are initialized based on points-to analysis on the +// while body corresponding to 'while_hlo'. If the init buffer corresponding to +// a read-only index aliases with a constant, it cannot be considered read-only, +// and must be copied. This is necessary because BufferAssignment does not +// currently assign an allocation for constants (b/32248867). +// This function performs this fix-up of 'read_only_indices'. +// +// Returns a ShapeTree of copy_overrides, which implements an optimization to +// allow multiple while loops that share the same read-only constants to +// share a single copy. 
+StatusOr<ShapeTree<HloInstruction*>> RevertReadOnlyIndicesForConstants( + const HloInstruction* while_hlo, + const TuplePointsToAnalysis& points_to_analysis, + ShapeTree<bool>* read_only_indices, + FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) { + const HloInstruction* init_hlo = while_hlo->operand(0); + const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); + + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). + FlatSet<const LogicalBuffer*> buffer_set; + + ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape()); + points_to.ForEachElement([init_hlo, read_only_indices, shared_copies, + &buffer_set, ©_overrides]( + const ShapeIndex& index, + const PointsToSet::BufferList& buffers) { + // Look for read-only entry parameters. + if (!read_only_indices->element(index)) { + return; } - VLOG(2) << computation->name() << " is_while_body: " << is_while_body; - HloInstruction* root = computation->root_instruction(); + for (const LogicalBuffer* buffer : buffers) { + HloInstruction* pointee = buffer->instruction(); + const bool is_constant = pointee->opcode() == HloOpcode::kConstant; + if (!is_constant) { + continue; + } - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (IsConstantValue(*value) && !is_while_body) { - VLOG(2) << "Root of computation (" << root->name() - << ") is constant at index " << index << ". Copying."; - add_index_to_copy(root, index); + // We have found an constant that is read-only in + // the while body. These buffers are managed by the caller, and cannot + // be aliased with HLO buffers. Revert this read-only index, + // to allow it to be copied. 
+ *read_only_indices->mutable_element(index) = false; + + // Optimization to allow multiple while loops that share the same + // read-only entry constants to share a single copy. + // Only unambiguous and distinct array-shaped buffers are allowed, to + // reduce code complexity. The shape of the entry parameter must be + // identical to the shape of the init_hlo at this index, to ensure + // there were no intervening bitcast or GTE instructions, which are + // also hard to handle. + const Shape& pointee_shape = pointee->shape(); + const Shape& init_shape = + ShapeUtil::GetSubshape(init_hlo->shape(), index); + if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && + ShapeUtil::Equal(pointee_shape, init_shape) && + buffer_set.count(buffer) < 1) { + HloInstruction** copy = &(*shared_copies)[pointee]; + if (*copy == nullptr) { + *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary( + pointee_shape, HloOpcode::kCopy, pointee)); } + // Add the copy as an override. + *copy_overrides.mutable_element(index) = *copy; } - } - } - // Add copy instructions indicated in 'instructions_to_copy' to the module. - for (const auto& pair : instructions_to_copy) { - HloInstruction* instruction = pair.first; - const ShapeTree<bool>& indices_to_copy = pair.second; + // Tracks whether this current buffer is distinct. + buffer_set.insert(buffer); - std::vector<HloInstruction*> users = instruction->users(); - TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, - instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); - for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); + // We've already reverted the read-only index and handled the + // single-copy optimization above, so there's nothing more to do. 
+ break; } - if (instruction == instruction->parent()->root_instruction()) { - instruction->parent()->set_root_instruction(deep_copy); - } - } - - return Status::OK(); -} - -Status VerifyNoLiveRangeInterference(HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - DependencyHloOrdering ordering(module); - TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); - return Status::OK(); + }); + return copy_overrides; } -void MaybeDumpModule(const string& message, const HloModule& module) { - if (VLOG_IS_ON(3)) { - VLOG(3) << message; - XLA_VLOG_LINES(3, module.ToString()); - hlo_graph_dumper::MaybeDumpHloModule(module, message); +} // anonymous namespace + +// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the +// base class, since the regular CopyInsertion logic above selectively copies +// tuple elements, while this method assumes all buffers need to be deep copied. +StatusOr<HloInstruction*> CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { + auto copy_it = inserted_copies_.find(hlo); + if (copy_it == inserted_copies_.end()) { + HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); + inserted_copies_.insert({hlo, copy}); + return copy; + } else { + return copy_it->second; } } -} // namespace - StatusOr<bool> CopyInsertion::Run(HloModule* module) { - // Copy insertion is performed in three steps: - // - // (1) Add copies conservatively to guarantee that there is no live-range - // interference. This is done simplistically and usually results in more - // copies than is strictly necessary. - // - // (2) Using a more fine-grained analysis, remove as many copies that were - // added in (1) as possible while ensuring no live-range interference. - // - // (3) Add copies to resolve issues not related to live range interference - // such as parameters and constants live out of the entry computation. 
- // - // We add copies then remove them (step (1) then (2)) rather than simply - // adding only the copies that are necessary because, in general, it is - // difficult to figure out the minimal set of copies to add once there is - // interference. On the other hand, it is easy to determine if removing a copy - // will introduce interference. - // - // The final copy insertion in (3) is done separately to simplify the - // implementation of copy removal in (2) which is the most complicated part of - // the pass. As is, copy removal only has to reason about live range - // interference. If all copies were added in step (1) then copy removal would - // also have to reason about things like constants and parameters live out of - // the computation. - MaybeDumpModule("before copy insertion", *module); - - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); - if (!call_graph->IsFlattened()) { - return FailedPrecondition( - "Call graph must be flattened before copy insertion."); - } + bool changed = false; + VLOG(2) << "CopyInsertion for module " << module->name(); - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies; - for (HloComputation* computation : module->computations()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr<BufferLiveness> liveness, + BufferLiveness::Run(module, MakeUnique<DependencyHloOrdering>(module))); + const auto& points_to_analysis = liveness->points_to_analysis(); + XLA_VLOG_LINES(2, points_to_analysis.ToString()); + XLA_VLOG_LINES(2, module->ToString()); + + // Gather all while body computations and while instructions. 
+ FlatSet<const HloComputation*> while_body_computations; + std::vector<HloInstruction*> while_instructions; + for (auto* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + if (instruction->opcode() == HloOpcode::kWhile) { + while_body_computations.insert(instruction->while_body()); + while_instructions.push_back(instruction); } } } - TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module)); - - // Simplify the tuple structures introduced by the deep copies. This should be - // done before removing copies (RemoveUnnecessaryCopies) because tuple - // simplification changes dependencies in the graph which changes live range - // interference in the graph. Also run DCE to remove the dead Tuple/GTE - // instructions introduced by tuple simplification. - TupleSimplifier tuple_simplifier; - HloDCE dce; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - TF_RETURN_IF_ERROR(dce.Run(module).status()); - - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - - MaybeDumpModule("after adding copies to resolve interference", *module); - - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); - - MaybeDumpModule("after removing unnecessary copies", *module); - - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); - - MaybeDumpModule("after adding special-case copies", *module); + // Collect instruction buffer indices to copy in 'instructions_to_copy'. + std::vector<InstructionCopier> instructions_to_copy; + + // Add copies of computation root instructions, if needed. 
+ FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices; + for (auto* computation : module->MakeNonfusionComputations()) { + VLOG(2) << "computation " << computation->name(); + InstructionCopier root_copier(computation->root_instruction(), + /*copy_users=*/{}); + if (while_body_computations.count(computation) > 0) { + // Record root indices to copy for while body sub-computations. We do not + // need to call RecordIndicesWhichPointToParamOrConstant for the while + // body root instruction here, because any necessary copies needed to + // avoid constants or parameters in the output are handled by while.init + // operand copy insertion below (which will share an allocation). + HloInstruction* while_body_param = computation->parameter_instruction(0); + ShapeTree<bool> read_only_indices(while_body_param->shape()); + TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( + *liveness, while_body_param, &read_only_indices)); + while_body_read_only_indices[computation] = read_only_indices; + + // Mark control predecessors, based on the body param, for any copies + // we'll be inserting. This ensures the copy doesn't run too early. + TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( + points_to_analysis, while_body_param)); + } else { + // Record root indices to copy for general computations. + TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( + points_to_analysis)); + } + instructions_to_copy.push_back(root_copier); + } - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + // Add copies of while 'init' operand instructions, if needed. 'shared_copies' + // is used to ensure that multiple while loops can share a single copy of the + // same entry parameter or constant, if all loops use it read-only. + // + // TODO(b/33301720) Remove redundant while instruction copies. 
+ FlatMap<const HloInstruction*, HloInstruction*> shared_copies; + for (HloInstruction* while_hlo : while_instructions) { + // Fix read_only_indices to account for entry constants. Also + // initialize copy_overrides, which ensures a single copy for each read-only + // constant that is used in multiple while loops. + ShapeTree<bool>* read_only_indices = + &while_body_read_only_indices[while_hlo->while_body()]; + TF_ASSIGN_OR_RETURN( + const ShapeTree<HloInstruction*> copy_overrides, + RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis, + read_only_indices, &shared_copies)); + // Create InstructionCopier for init operand of while instruction. + HloInstruction* init_hlo = while_hlo->mutable_operand(0); + InstructionCopier init_copier(init_hlo, {while_hlo}); + init_copier.SetReadOnlyIndices(*read_only_indices); + init_copier.SetCopyOverrides(copy_overrides); + // Record 'init' buffer indices which point-to a Constant or Parameter. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( + points_to_analysis)); + // Record indices necessary to colocate while and init operand buffers. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( + *liveness, while_hlo, /*read_only_indices_out=*/nullptr)); + instructions_to_copy.push_back(init_copier); + } - MaybeDumpModule("after copy insertion", *module); + for (InstructionCopier& to_copy : instructions_to_copy) { + if (to_copy.HasAllIndicesFalse()) { + continue; + } + changed = true; - if (VLOG_IS_ON(1)) { - int64 num_total_copies = 0; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - num_total_copies++; - } - } + // Copy instruction at recorded buffer indices. 
+ HloComputation* computation = to_copy.instruction()->parent(); + HloInstruction* copy = to_copy.Copy(); + if (to_copy.instruction() == computation->root_instruction()) { + computation->set_root_instruction(copy); } - VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); - VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } - return true; + VLOG(3) << "After copy insertion for module " << module->name(); + XLA_VLOG_LINES(3, module->ToString()); + + return changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index ea3c36b5c7..28bb62e40c 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -25,25 +25,12 @@ limitations under the License. namespace xla { -// Copy insertion is a legalization HLO pass which inserts copies (kCopy -// instructions) to eliminate several kinds of problems in the HLO module. -// -// (1) Entry parameter or a constant live out of the entry computation. Entry -// computation arguments and constants have different lifetimes than the -// computation result and cannot share the same allocation. Parameters and -// constants live out of non-entry computations do not need copies. -// -// (2) Different values which are simultaneously live and which must be held -// in the same buffer. This can occur in while bodies. Specifically, the -// while loop state (the arguments to the while instruction) is updated -// in-place and the update may clobber the value from the previous -// iteration before the previous value is dead. Computations called from -// kCall instructions do not need such copies because kCall has no update -// in-place semantics. -// -// (3) The buffer set of the root instruction of the entry computation must be -// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and -// InstructionAliasSet::IsDistinct return true. 
+// HLO pass which inserts a copy of the root instruction (creating a new root) +// if the root is or points-to any constant or parameter instruction. +// If the root instruction is a Tuple, only tuple elements which point to +// constant or parameter instructions will be copied. +// Copy insertion is necessary because constant and parameter arrays have +// different lifetimes than computation results. class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } @@ -51,6 +38,15 @@ class CopyInsertion : public HloPassInterface { // Run the pass on the given module. Returns whether the module was changed // (copies were inserted). StatusOr<bool> Run(HloModule* module) override; + + protected: + // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // duplicate copies. + StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo); + + // A map containing all copies inserted during the copy insertion pass. The + // key is the copied instruction and the value is the copy. + tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 8807c6480b..a2eacc5c7d 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,19 +17,18 @@ limitations under the License. 
#include <set> -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/test_benchmark.h" namespace op = xla::testing::opcode_matchers; @@ -38,53 +37,35 @@ namespace { using ::testing::UnorderedElementsAre; -int64 CountCopies(const HloComputation& computation) { - int64 count = 0; - for (const auto& instruction : computation.instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - count++; - } - } - return count; -} - -int64 CountCopies(const HloModule& module) { - int64 count = 0; - for (const auto& computation : module.computations()) { - count += CountCopies(*computation); - } - return count; -} - -int64 CountControlEdges(const HloComputation& computation) { - int64 count = 0; - for (const auto& instruction : computation.instructions()) { - count += instruction->control_successors().size(); - } - return count; -} - -int64 CountControlEdges(const HloModule& module) { - int64 count = 0; - for (const auto& computation : module.computations()) { - count += CountControlEdges(*computation); - } - return count; -} - class CopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CopyInsertion copy_insertion; - ASSERT_IS_OK(copy_insertion.Run(module).status()); + EXPECT_IS_OK(copy_insertion.Run(module).status()); + + // Verify 
the points to set of the root of the computation after copy + // insertion contains no constants or parameters, and is distinct and + // non-ambiguous. + auto points_to_analysis = + TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); + const auto& points_to = points_to_analysis->GetPointsToSet( + module->entry_computation()->root_instruction()); + EXPECT_TRUE(points_to.IsDistinct()); + EXPECT_TRUE(!points_to.IsAmbiguous()); + + auto maybe_live_out_buffers = + points_to_analysis + ->GetPointsToSet(module->entry_computation()->root_instruction()) + .CreateFlattenedSet(); + + for (const LogicalBuffer* buffer : maybe_live_out_buffers) { + EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); + EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); + } } - - const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); }; TEST_F(CopyInsertionTest, SingleParameter) { - // Computation is a single parameter passed into a tuple. The parameter should - // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -96,15 +77,14 @@ TEST_F(CopyInsertionTest, SingleParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); + HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(x))); + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, SingleConstant) { - // Computation is a single constant passed into a tuple. The parameter should - // be copied before entering the tuple. 
auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); @@ -116,42 +96,11 @@ TEST_F(CopyInsertionTest, SingleConstant) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); + HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(constant))); -} - -TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { - // Verify that an kCopy instructions which exist in the pass before - // copy-insertion remain in the graph after copy-insertion. - auto module = CreateNewModule(); - - auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); - HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kCopy, constant)); - HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( - constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); - HloInstruction* add_copy = builder.AddInstruction( - HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); - - module->AddEntryComputation(builder.Build()); - - EXPECT_EQ(CountCopies(*module), 3); - - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 3); - - EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -178,12 +127,12 @@ 
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); + HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(old_root->operand(0)), + op::Copy(old_root->operand(1)), old_root->operand(2))); } TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { @@ -216,7 +165,6 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 2); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Copy(op::GetTupleElement(old_root)), @@ -239,7 +187,6 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -261,7 +208,6 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -281,11 +227,11 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); + HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(bitcast))); + op::Tuple(op::Copy(old_root->operand(0)))); } TEST_F(CopyInsertionTest, 
NestedTupleParameter) { @@ -311,8 +257,6 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 3); - HloInstruction* new_root = module->entry_computation()->root_instruction(); EXPECT_NE(old_root, new_root); @@ -349,13 +293,12 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { EXPECT_EQ(gte, module->entry_computation()->root_instruction()); + HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT( - module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), - op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(old_root)), + op::Copy(op::GetTupleElement(old_root)))); } TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { @@ -388,7 +331,6 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); - EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -404,10 +346,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // The parameter 'nested' specifies the loop state shape from which to // read the induction variable. std::unique_ptr<HloComputation> BuildConditionComputation( - const Shape& loop_state_shape) { + bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<int32>(10))); + const Shape& loop_state_shape = + nested ? 
nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -638,7 +582,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, inner_init})); auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_init->shape(), condition, body, loop_state_init)); + loop_state_shape_, condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -714,28 +658,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. - auto add = builder.AddInstruction(HloInstruction::CreateBinary( + builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); - auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, - data_init, &builder); - - // Add an additional binary operation operating on the while and the - // interfering add so that neither operation is dead. 
- auto gte = xla_while->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); - auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( - data_shape_, HloOpcode::kSubtract, add, gte)); - auto gte0 = xla_while->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); - auto tuple = xla_while->parent()->AddInstruction( - HloInstruction::CreateTuple({gte0, sub})); - - xla_while->parent()->set_root_instruction(tuple); - - return xla_while; + return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, + &builder); } HloInstruction* BuildWhileInstructionWithCustomInit( @@ -745,8 +672,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<int32>(0))); - auto condition = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape)); + auto condition = + module_->AddEmbeddedComputation(BuildConditionComputation(nested)); auto body = module_->AddEmbeddedComputation( BuildIndependentBodyComputation(nested)); auto loop_state_init = builder->AddInstruction( @@ -779,21 +706,23 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // CopyInsertion pass should not generate any copies. 
// TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { - auto condition = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); + auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); auto body = module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); + HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); + HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); - // Body should have no copies as the adds can be done inplace. - EXPECT_EQ(CountCopies(*body), 0); - EXPECT_EQ(CountControlEdges(*module_), 0); + // No copies should be inserted so root should not be updated. + EXPECT_EQ(old_root, new_root); - // Both init indices need copies as they are constants. - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); + // Both init indices need copies. 
+ EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while body computation with dependent tuple elements: @@ -808,33 +737,20 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { // Tuple(Copy(out0), out1) // TEST_F(WhileCopyInsertionTest, DependentTupleElements) { - auto condition = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); + auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); + HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); + HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); - EXPECT_EQ(CountCopies(*body), 1); - EXPECT_EQ(CountControlEdges(*body), 0); - - EXPECT_THAT( - body->root_instruction(), - op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); - - auto add = body->root_instruction()->operand(0); - auto bcast = body->root_instruction()->operand(1)->operand(1); - ASSERT_EQ(add->opcode(), HloOpcode::kAdd); - ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); - - EXPECT_THAT( - while_hlo->while_body()->root_instruction(), - op::Tuple(op::Add(op::Copy(), op::Constant()), - op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); - - // Both init indices need copies as they are constants. 
- EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); + EXPECT_THAT(new_root, + op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1))); + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while body computation with read-only tuple element 0: @@ -852,26 +768,33 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { // // CopyInsertion pass should not generate any copies for the while body. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { - auto condition = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); + auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); auto body = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); - BuildWhileInstruction(condition, body); + auto while_hlo = BuildWhileInstruction(condition, body); + const HloInstruction* old_init = while_hlo->operand(0); + HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); + HloInstruction* new_root = body->root_instruction(); + const HloInstruction* new_init = while_hlo->operand(0); + + // No copies should be inserted in the body, so root should not be updated. + EXPECT_EQ(old_root, new_root); - // No copies or control edges should be inserted. The body is legal as is. - EXPECT_EQ(CountCopies(*body), 0); - EXPECT_EQ(CountControlEdges(*body), 0); + // Both indices need copies, even though Index 0 is read-only, since both are + // constants, which must be copied. + EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Same as above, but with two while loops, sharing entry parameters. 
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { - auto condition1 = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); - auto condition2 = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); + auto condition1 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = + module_->AddEmbeddedComputation(BuildConditionComputation()); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -889,46 +812,30 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - - // Add a couple elements from each of the while so both whiles are live. - auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); - auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); - - auto entry = module_->AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // Neither body should have any copies or control edges in them. - EXPECT_EQ(CountCopies(*body1), 0); - EXPECT_EQ(CountCopies(*body2), 0); - EXPECT_EQ(CountControlEdges(*body1), 0); - EXPECT_EQ(CountControlEdges(*body2), 0); + // Both while loops alias iter_param, since index 0 is read-only in the body. + EXPECT_EQ(while_hlo1->operand(0)->operand(0), + while_hlo2->operand(0)->operand(0)); + EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_param); - // Only two copies should be necessary. 
Each of the whiles should have - // a copy of tuple element 1 (init value is a parameter, and the element is - // not non-read-only) so each of the while bodies gets its own buffer to write - // element 1 into. - EXPECT_EQ(CountCopies(*entry), 2); - - EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); - EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); - - // The two copies of element 1 should be different. + // Each while loop gets its own copy of data_param, since index 1 is not + // read-only in the body. EXPECT_NE(while_hlo1->operand(0)->operand(1), while_hlo2->operand(0)->operand(1)); + EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param)); + EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param)); } // Same as above, but with two while loops, sharing non-parameters. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { - auto condition1 = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); - auto condition2 = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape_)); + auto condition1 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = + module_->AddEmbeddedComputation(BuildConditionComputation()); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -951,28 +858,21 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - - // Add a couple elements from each of the while so both whiles are not dead. 
- auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); - auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); - auto entry = module_->AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // Ideally only one copy should be necessary. One of the whiles should - // have a copy of tuple element 1 (the non-read-only element) so each of the - // while bodies gets its own buffer to write element 1 into. However, the - // analysis isn't perfect and adds an additional copy of element 0. - EXPECT_EQ(CountCopies(*entry), 2); + // No copies of iter_value are necessary, since index 0 is read-only in both + // while bodies. + EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value); + EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value); - EXPECT_THAT(while_hlo1->operand(0), - op::Tuple(op::Exp(), op::Copy(op::Exp()))); - EXPECT_THAT(while_hlo2->operand(0), - op::Tuple(op::Exp(), op::Copy(op::Exp()))); + // Each while loop gets its own copy of data_value, since index 1 is not + // read-only in the body. 
+ EXPECT_NE(while_hlo1->operand(0)->operand(1), + while_hlo2->operand(0)->operand(1)); + EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value)); + EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value)); } // Tests while body computation with nested tuple elements: @@ -1005,34 +905,18 @@ TEST_F(WhileCopyInsertionTest, // Tuple // new root // TEST_F(WhileCopyInsertionTest, NestedTupleElements) { - auto condition = module_->AddEmbeddedComputation( - BuildConditionComputation(nested_loop_state_shape_)); + auto condition = + module_->AddEmbeddedComputation(BuildConditionComputation(true)); auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); BuildWhileInstruction(condition, body, true); - // HloInstruction* old_root = body->root_instruction(); + HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - // The only copy necessary is for the kReverse as it cannot be done - // in-place (instruction can share buffer with operand). The other elements of - // the loop state are kAdd instructions which can be done in-place. - EXPECT_EQ(CountCopies(*body), 1); - - // Each element of the init needs a copy as all are constants. - EXPECT_EQ(CountCopies(*module_), 4); - - // Either the kReverse itself must be copied or the operand of the kReverse - // must be copied. - if (body->root_instruction()->operand(1)->operand(1)->opcode() == - HloOpcode::kCopy) { - EXPECT_THAT( - body->root_instruction(), - op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); - } else { - EXPECT_THAT( - body->root_instruction(), - op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); - } + EXPECT_THAT(body->root_instruction(), + op::Tuple(old_root->operand(0), + op::Tuple(old_root->operand(1)->operand(0), + op::Copy(old_root->operand(1)->operand(1))))); } // Tests while init instruction which points-to a constant. 
@@ -1043,13 +927,11 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) { // TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); - + auto old_init = while_hlo->operand(0); InsertCopies(module_.get()); - EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); - EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction which points-to a parameter. @@ -1060,13 +942,11 @@ TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { // TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); - + auto old_init = while_hlo->operand(0); InsertCopies(module_.get()); - EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); - EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction which has an ambiguous points-to set. @@ -1095,34 +975,15 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { // TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); - + auto old_init = while_hlo->operand(0); InsertCopies(module_.get()); - EXPECT_EQ(CountCopies(*module_), 4); - // The entry computation requires three copies to resolve the ambiguity of two - // init elements and the constant passed in as one of the init elements. 
- EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), - op::Tuple(op::Copy(op::GetTupleElement()), - op::Copy(op::GetTupleElement())))); - - // The body requires one copy because the buffer set is not distinct: the - // result of one of the adds is written into two elements of the output of the - // loop body. Either element might be copied. - EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); - if (while_hlo->while_body() - ->root_instruction() - ->operand(1) - ->operand(0) - ->opcode() == HloOpcode::kCopy) { - EXPECT_THAT( - while_hlo->while_body()->root_instruction(), - op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); - } else { - EXPECT_THAT( - while_hlo->while_body()->root_instruction(), - op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); - } + + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple( + op::Copy(old_init->operand(0)), + op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))), + op::Copy(op::GetTupleElement(old_init->operand(1)))))); } // Tests while init instruction which has a non-distinct points-to set. @@ -1150,43 +1011,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { // TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); - + auto old_init = while_hlo->operand(0); InsertCopies(module_.get()); - // The entry computation requires two copies to resolve the non-disinctness of - // two init elements and the constant passed in as one of the init - // elements. Either element can be copied for the distinctness issue. 
- EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); - if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == - HloOpcode::kCopy) { - EXPECT_THAT( - while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), - op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); - } else { - EXPECT_THAT( - while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), - op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); - } - - // The body requires one copy because the buffer set is not distinct: the - // result of one of the adds is written into two elements of the output of the - // loop body. Either element might be copied. - EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); - if (while_hlo->while_body() - ->root_instruction() - ->operand(1) - ->operand(0) - ->opcode() == HloOpcode::kCopy) { - EXPECT_THAT( - while_hlo->while_body()->root_instruction(), - op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); - } else { - EXPECT_THAT( - while_hlo->while_body()->root_instruction(), - op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); - } + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(old_init->operand(0)), + op::Tuple(op::Copy(old_init->operand(1)->operand(0)), + op::Copy(old_init->operand(1)->operand(0))))); } // Tests while init instruction buffer which interferes with while result @@ -1200,13 +1031,11 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { // TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); - + auto old_init = while_hlo->operand(0); InsertCopies(module_.get()); - EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); + EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), + op::Copy(old_init->operand(1)))); } // Tests while init instruction buffer which has 
a non-distinct points-to set: @@ -1215,21 +1044,18 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { // Parameter(F32, {8}))) // // where the second and third parameters are identical *and* the tuple shared -// by another while instruction. +// by another while instruction.. // // Verifies that the resulting point-to set is distinct in the resulting Tuple // (non-identical Copys). In other words, verifies that copy sharing does not // insert identical copies to the resulting tuple. TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { + auto condition1 = + module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition2 = + module_->AddEmbeddedComputation(BuildConditionComputation()); // Loop body that outputs tuple comprises two elements dependent on the init // tuple. - const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_, data_shape_}); - - auto condition1 = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape)); - auto condition2 = module_->AddEmbeddedComputation( - BuildConditionComputation(loop_state_shape)); auto body1 = module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); auto body2 = @@ -1246,6 +1072,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( @@ -1253,479 +1081,43 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition2, body2, loop_init)); - // Add add instruction so neither while is dead. 
- auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); - auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); - builder.AddInstruction( - HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); - module_->AddEntryComputation(builder.Build()); - InsertCopies(module_.get()); - - // None of the bodies should have copies or control flow edges. - EXPECT_EQ(CountCopies(*body1), 0); - EXPECT_EQ(CountCopies(*body2), 0); + auto points_to_analysis = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally - // these should not need to be copied before either while. However, copy - // insertion is not able to reason about the transparency of elements through - // while bodies in all circumstances so extra copies are added (b/xxx). - EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); + // Asserts that the init tuples before copy insertion is non-distinct. + ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); + ASSERT_FALSE( + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); - EXPECT_THAT(while_hlo1->operand(0), - op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); - EXPECT_THAT(while_hlo2->operand(0), - op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); -} + auto old_init1 = while_hlo1->operand(0); + auto old_init2 = while_hlo2->operand(0); -TEST_F(CopyInsertionTest, SwizzlingWhile) { - // Test a while instruction with a body which permutes its tuple parameter - // elements. - auto module = CreateNewModule(); - const Shape loop_state_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - // Body simply interchanges the two tuple elements in the loop state. 
- auto body_builder = HloComputation::Builder("body"); - auto body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto body_element_0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); - auto body_element_1 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({body_element_1, body_element_0})); - HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); - - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); - cond_builder.AddInstruction(HloInstruction::CreateUnary( - cond_constant->shape(), HloOpcode::kNot, cond_constant)); - HloComputation* condition = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); - auto tuple = builder.AddInstruction( - HloInstruction::CreateTuple({constant1, constant2})); - auto xla_while = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); - module->AddEntryComputation(builder.Build()); - - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 6); - - // The loop state elements should be copied at the parameter and at the root - // with a control edge in between (see DeepCopyAndAddControlEdges). 
This is - // technically one more copy than is strictly necessary, but in order to have - // only three copies the copies of different loop state elements must be - // ordered with a control edge. - EXPECT_EQ(CountCopies(*body), 4); - EXPECT_EQ(CountControlEdges(*body), 2); - - EXPECT_THAT(body->root_instruction(), - op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); - - EXPECT_EQ(CountCopies(*module->entry_computation()), 2); - EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); -} - -TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { - // Test a while instruction with a body which permutes its tuple parameter - // elements and applies one operation to one of the elements. The addition of - // the operation (instruction) on the element makes the live range of the - // respective input and output elements different than if the instruction were - // not there (as in the SwizzlingWhile test above). - auto module = CreateNewModule(); - const Shape loop_state_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - // Body interchanges the two tuple elements in the loop state and negates one - // of them. 
- auto body_builder = HloComputation::Builder("body"); - auto body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto body_element_0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); - auto body_element_1 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); - auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( - scalar_shape_, HloOpcode::kNegate, body_element_1)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({negate, body_element_0})); - HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); - - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); - cond_builder.AddInstruction(HloInstruction::CreateUnary( - cond_constant->shape(), HloOpcode::kNot, cond_constant)); - HloComputation* condition = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); - auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0))); - auto tuple = builder.AddInstruction( - HloInstruction::CreateTuple({constant1, constant2})); - auto xla_while = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); - module->AddEntryComputation(builder.Build()); - - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 6); - - // The loop state elements should be copied at the parameter and at the root - // with a control edge in between (see DeepCopyAndAddControlEdges). 
- EXPECT_EQ(CountCopies(*body), 4); - EXPECT_EQ(CountControlEdges(*body), 2); - - EXPECT_THAT( - body->root_instruction(), - op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); - - EXPECT_EQ(CountCopies(*module->entry_computation()), 2); - EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); -} - -TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { - // Test a while instruction with a body which permutes it's tuple parameter - // elements similar to SwizzlinWhile above. However, in this test the input to - // the while body is a single constant (both loop state elements are the same - // constant). This means no copies are necessary because both loop state - // elements are the same so interchanging them is a no-op. - auto module = CreateNewModule(); - const Shape loop_state_shape = - ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - - // Body simply interchanges the two tuple elements in the loop state. - auto body_builder = HloComputation::Builder("body"); - auto body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto body_element_0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); - auto body_element_1 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({body_element_1, body_element_0})); - HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); - - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); - cond_builder.AddInstruction(HloInstruction::CreateUnary( - cond_constant->shape(), HloOpcode::kNot, cond_constant)); - HloComputation* condition = - 
module->AddEmbeddedComputation(cond_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0))); - auto tuple = - builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); - builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); - module->AddEntryComputation(builder.Build()); - - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 2); - EXPECT_EQ(CountCopies(*body), 0); - - EXPECT_EQ(CountCopies(*module->entry_computation()), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(), op::Copy())); -} - -TEST_F(CopyInsertionTest, SequentialWhiles) { - // Construct a computation with a series of sequential while instructions - // containing four loop state elements: - // - // element 0 is passed to each while directly from an entry parameter. - // - // element 1 is passed transparently in series through all the while bodies. - // - // element 2 is negated in each while body. (in-place possible) - // - // element 3 is reversed in each while body. (in-place not possible) - // - const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); - const Shape loop_state_shape = ShapeUtil::MakeTupleShape( - {element_shape, element_shape, element_shape, element_shape}); - - auto module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto param_0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, element_shape, "param_0")); - auto param_1 = builder.AddInstruction( - HloInstruction::CreateParameter(1, element_shape, "param_1")); - auto param_2 = builder.AddInstruction( - HloInstruction::CreateParameter(2, element_shape, "param_2")); - auto param_3 = builder.AddInstruction( - HloInstruction::CreateParameter(3, element_shape, "param_3")); - - // The number of sequential kWhile instructions. 
- const int kNumWhiles = 3; - - HloInstruction* prev_element_1 = param_1; - HloInstruction* prev_element_2 = param_2; - HloInstruction* prev_element_3 = param_3; - - // Vector containing all of the while instructions. - std::vector<const HloInstruction*> whiles; - for (int i = 0; i < kNumWhiles; ++i) { - auto body_builder = HloComputation::Builder("body"); - auto body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto body_element_0 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); - auto body_element_1 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); - auto body_element_2 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); - auto body_element_3 = body_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); - auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( - element_shape, HloOpcode::kNegate, body_element_2)); - auto reverse = body_builder.AddInstruction( - HloInstruction::CreateReverse(element_shape, body_element_3, {0})); - body_builder.AddInstruction(HloInstruction::CreateTuple( - {body_element_0, body_element_1, negate, reverse})); - HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); - - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "param")); - auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); - cond_builder.AddInstruction(HloInstruction::CreateUnary( - cond_constant->shape(), HloOpcode::kNot, cond_constant)); - HloComputation* condition = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( - {param_0, prev_element_1, 
prev_element_2, prev_element_3})); - - auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape, condition, body, while_init)); - whiles.push_back(xla_while); - if (i != kNumWhiles - 1) { - prev_element_1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); - prev_element_2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); - prev_element_3 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); - } - } - - module->AddEntryComputation(builder.Build()); - - InsertCopies(module.get()); - - // Each while body has one copy. And each loop state element is copied once in - // the entry computation. - EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); - - // Each while body should have exactly one copy for element three which is an - // op (kReverse) which cannot be done in place. - for (const HloInstruction* xla_while : whiles) { - EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); - } - - EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), - op::Copy(), op::Copy())); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), - op::GetTupleElement())); -} - -TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { - // Test a while body and condition which are each simply a constant (root of - // computation is a constant). Each constant should be copied. The copy in the - // condition is not strictly necessary, but added due to b/32248867. 
- auto module = CreateNewModule(); - auto builder = HloComputation::Builder(TestName()); - auto param_0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); - - auto body_builder = HloComputation::Builder("body"); - body_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0))); - HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); - - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape_, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); - HloComputation* condition = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto xla_while = builder.AddInstruction( - HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); - - module->AddEntryComputation(builder.Build()); - - InsertCopies(module.get()); - - EXPECT_EQ(CountCopies(*module), 3); - - EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); - EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); - EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant())); -} - -std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) { - auto builder = HloComputation::Builder("trivial_condition"); - builder.AddInstruction( - HloInstruction::CreateParameter(0, shape, "loop_state")); - auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0<bool>(false))); - builder.AddInstruction(HloInstruction::CreateUnary( - constant->shape(), HloOpcode::kNot, constant)); - return builder.Build(); -} - -std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() { - auto builder = HloComputation::Builder("benchmark_loop_body"); - const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); - const Shape 
loop_state_shape = - ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape}); - HloInstruction* param = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); - HloInstruction* element_0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, param, 0)); - HloInstruction* element_1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, param, 1)); - HloInstruction* element_2 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(element_shape, param, 2)); - - HloInstruction* rev_1 = builder.AddInstruction( - HloInstruction::CreateReverse(element_shape, element_1, {0})); - HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( - element_shape, HloOpcode::kAdd, element_1, element_2)); - - builder.AddInstruction( - HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); - return builder.Build(); -} - -void BM_SequentialWhiles(int num_iters, int num_whiles) { - // This benchmark constructs a chain of sequential while instructions. 
- tensorflow::testing::StopTiming(); - for (int i = 0; i < num_iters; ++i) { - HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); - - auto builder = HloComputation::Builder("BM_SequentialWhiles"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {42}), "x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {42}), "y")); - HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(F32, {42}), "z")); - HloInstruction* init = - builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); - - HloInstruction* prev_loop_state = init; - for (int w = 0; w < num_whiles; ++w) { - HloComputation* condition = - module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); - HloComputation* body = - module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); - prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( - init->shape(), condition, body, prev_loop_state)); - } - module.AddEntryComputation(builder.Build()); - - CopyInsertion copy_insertion; - - tensorflow::testing::StartTiming(); - ASSERT_IS_OK(copy_insertion.Run(&module).status()); - tensorflow::testing::StopTiming(); - - // The entry computation should have three copies, and each body has one. - ASSERT_EQ(CountCopies(module), 3 + num_whiles); - } -} - -void BM_ParallelWhiles(int num_iters, int num_whiles) { - // This benchmark constructs a fan-out of parallel while instructions. 
- tensorflow::testing::StopTiming(); - for (int i = 0; i < num_iters; ++i) { - HloModuleConfig config; - config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); - HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), - config); - - auto builder = HloComputation::Builder("BM_ParallelWhiles"); - HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {42}), "x")); - HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {42}), "y")); - HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( - 2, ShapeUtil::MakeShape(F32, {42}), "z")); - HloInstruction* init = - builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); - - HloInstruction* sum = nullptr; - for (int w = 0; w < num_whiles; ++w) { - HloComputation* condition = - module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); - HloComputation* body = - module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); - - HloInstruction* xla_while = builder.AddInstruction( - HloInstruction::CreateWhile(init->shape(), condition, body, init)); - - if (sum == nullptr) { - sum = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); - } else { - HloInstruction* element_0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); - sum = builder.AddInstruction(HloInstruction::CreateBinary( - x->shape(), HloOpcode::kAdd, sum, element_0)); - } - } - module.AddEntryComputation(builder.Build()); - - CopyInsertion copy_insertion; + InsertCopies(module_.get()); - tensorflow::testing::StartTiming(); - ASSERT_IS_OK(copy_insertion.Run(&module).status()); - tensorflow::testing::StopTiming(); + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Copy(old_init1->operand(0)), + op::Copy(old_init1->operand(1)), + op::Copy(old_init1->operand(2)))); - // Each body receives of copy of two of the parameters 
(the corresponding - // elements in the body are modifed), and there is one copy in each body. - ASSERT_EQ(CountCopies(module), 3 * num_whiles); - } + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Copy(old_init2->operand(0)), + op::Copy(old_init2->operand(1)), + op::Copy(old_init2->operand(2)))); + + // Verifies the init tuples after copy insertion is distinct. + points_to_analysis = + TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + const auto& points_to1 = + points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); + EXPECT_TRUE(points_to1.IsDistinct()); + + const auto& points_to2 = + points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); + EXPECT_TRUE(points_to2.IsDistinct()); } -BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); -BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index af2bd6d5d7..46e83282d5 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -243,81 +243,6 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_; }; - -// This copy insertion pass is a hack to address deficiencies in buffer -// assignment. Buffer assignment uses TuplePointsToAnalysis which is -// computation-scoped and thus has limited visibility across computation -// boundaries. However, CopyInsertion uses module-scoped HloAliasAnalysis and -// expects buffer assignment to have the same understanding of the graph. This -// mismatch manifests in the parallel cpu backend, where the HLO outlining -// results is a minefield of potential problems. This pass conservatively adds -// copies to avoid any potential problems in buffer assignemnt. -// -// Technically these issues exist in all the backends. 
However, they only -// manifest in the parallel cpu backend because of the outlining. Moving this -// into the main copy insertion pass results in performance regressions n the -// other backends. -// -// TODO(b/62548313): Remove this. -class CpuParallelCopyInsertion : public HloPassInterface { - public: - tensorflow::StringPiece name() const override { - return "cpu-parallel-copy-insertion"; - } - - StatusOr<bool> Run(HloModule* module) override { - // Copy roots of all non-entry sequentially-called (eg, kCall, kWhile) - // computations. - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); - TF_RETURN_IF_ERROR( - call_graph->VisitNodes([module](const CallGraphNode& node) -> Status { - if (node.context() == CallContext::kSequential && - !node.caller_callsites().empty()) { - TF_ASSIGN_OR_RETURN(HloInstruction * root_copy, - node.computation()->DeepCopyInstruction( - node.computation()->root_instruction())); - node.computation()->set_root_instruction(root_copy); - } - return Status::OK(); - })); - - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow, - HloDataflowAnalysis::Run(module)); - - // Add copies to the operand of dynamic update slices which have read-only - // values (constants and parameters). Buffer assignment which is based on - // computation-scoped tuple points-to analysis does not properly track these - // read-only values across kCall instructions. This can result in cases - // where a outlined computation parameter operand of a dynamic update slice - // aliases a constant or parameter in the entry computation and the dynamic - // update slice is attempted in-place. 
- for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { - HloInstruction* operand = instruction->mutable_operand(0); - for (const HloValue* value : - dataflow->GetValueSet(operand).values()) { - if (value->defining_instruction()->opcode() == - HloOpcode::kConstant || - value->defining_instruction()->opcode() == - HloOpcode::kParameter) { - HloInstruction* operand_copy = - instruction->parent()->AddInstruction( - HloInstruction::CreateUnary(operand->shape(), - HloOpcode::kCopy, operand)); - TF_RETURN_IF_ERROR( - operand->ReplaceUseWith(instruction, operand_copy)); - break; - } - } - } - } - } - - return true; - } -}; - } // namespace Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { @@ -406,16 +331,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass<HloDCE>(); - pipeline.AddPass<FlattenCallGraph>(); pipeline.AddPass<CopyInsertion>(); if (options::CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. 
pipeline.AddPass<ParallelizationPreparation>(max_parallelism, ShapeSizeBytesFunction()); - pipeline.AddPass<CpuParallelCopyInsertion>(); } pipeline.AddPass<HloDCE>(); + pipeline.AddPass<FlattenCallGraph>(); return pipeline.Run(module).status(); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index df7e128217..b9c4adce93 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -350,8 +350,8 @@ cc_library( ":ir_emission_utils", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", - "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/compiler/xla/service:logical_buffer", + "//tensorflow/compiler/xla/service:tuple_points_to_analysis", "//tensorflow/core:lib", ], ) @@ -573,14 +573,11 @@ tf_cc_test( deps = [ ":instruction_fusion", ":while_transformer", - "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", - "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc index f7a3260641..3dc8555201 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc @@ -22,53 +22,41 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/logging.h" namespace xla { namespace gpu { -StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy( - HloInstruction* hlo) { - auto copy_it = inserted_copies_.find(hlo); - if (copy_it == inserted_copies_.end()) { - HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); - inserted_copies_.insert({hlo, copy}); - return copy; - } else { - return copy_it->second; - } -} - StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) { - CopyInsertion generic_copy_insertion; + TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module)); - TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); - - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow, - HloDataflowAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(auto points_to_analysis, + TuplePointsToAnalysis::Run(module)); // Make sure all operands of a library call are in memory instead of constants - // in IR. + // in IR. The top-level (index {}) of the points-to set of each operand + // indicates the source(s) of the array buffer. If any of these are constant, + // then add a copy to materialize the array. 
HloComputation* computation = module->entry_computation(); for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { if (ImplementedAsLibraryCall(*hlo)) { for (int64 i = 0; i < hlo->operand_count(); ++i) { HloInstruction* operand = hlo->mutable_operand(i); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - bool copy_operand = false; - for (const HloValue* value : dataflow->GetValueSet(operand).values()) { - if (value->defining_instruction()->opcode() == HloOpcode::kConstant) { - copy_operand = true; - break; - } - } - if (copy_operand) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); + const PointsToSet& points_to = + points_to_analysis->GetPointsToSet(operand); + const auto& element = points_to.element(/*index=*/{}); + if (std::any_of(element.begin(), element.end(), + [](const LogicalBuffer* buffer_source) { + return buffer_source->instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + CopyInsertion::FindOrInsertCopy(operand)); TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); changed = true; } @@ -76,31 +64,6 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) { } } - // Init values of a while nodes cannot be constants. Insert copies for any - // constants found at the operand of a while. 
- tensorflow::gtl::FlatSet<HloInstruction*> copied_constants; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - for (auto& pair : - dataflow->GetInstructionValueSet(instruction->operand(0))) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->opcode() == - HloOpcode::kConstant && - !ContainsKey(copied_constants, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - copied_constants.insert(constant); - } - } - } - } - } - } - return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/copy_insertion.h index 2ca9a13fd8..11077dad2e 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { @@ -25,20 +25,9 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. 
-class GpuCopyInsertion : public HloPassInterface { +class GpuCopyInsertion : public CopyInsertion { public: - tensorflow::StringPiece name() const override { return "copy-insertion"; } - StatusOr<bool> Run(HloModule* module) override; - - protected: - // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making - // duplicate copies. - StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted to materialize operands of library - // calls. The key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 80dccf5b65..2caa8f6051 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -220,8 +220,9 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass<HloDCE>(); - pipeline.AddPass<FlattenCallGraph>(); pipeline.AddPass<GpuCopyInsertion>(); + pipeline.AddPass<HloDCE>(); + pipeline.AddPass<FlattenCallGraph>(); return pipeline.Run(hlo_module).status(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index f16daa0b54..44188473d3 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -17,12 +17,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" -#include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -36,6 +33,8 @@ class WhileTransformerTest : public HloTestBase { : module_(CreateNewModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), + loop_state_shape_(ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} std::unique_ptr<HloComputation> BuildConditionComputation( @@ -43,8 +42,8 @@ class WhileTransformerTest : public HloTestBase { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit))); - auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( - 0, GetLoopStateShape(tuple_index), "loop_state")); + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); @@ -59,8 +58,8 @@ class WhileTransformerTest : public HloTestBase { const int64 increment) { auto builder = HloComputation::Builder(TestName() + ".Body"); // Create param instruction to access loop state. 
- auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( - 0, GetLoopStateShape(ind_var_tuple_index), "loop_state")); + auto loop_state = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); // Update the induction variable GTE(ind_var_tuple_index). auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -74,7 +73,7 @@ class WhileTransformerTest : public HloTestBase { data_shape_, loop_state, data_tuple_index)); // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {})); + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. @@ -99,9 +98,8 @@ class WhileTransformerTest : public HloTestBase { HloInstruction::CreateTuple({induction_var_init, data_init})) : builder.AddInstruction( HloInstruction::CreateTuple({data_init, induction_var_init})); - auto while_hlo = builder.AddInstruction( - HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index), - condition, body, loop_state_init)); + auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape_, condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -117,34 +115,18 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier([](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); - }); - TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; - TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); - } - - Shape GetLoopStateShape(const int64 ind_var_tuple_index) { - if (ind_var_tuple_index == 0) { - return ShapeUtil::MakeTupleShape( - 
{induction_variable_shape_, data_shape_}); - } else { - return ShapeUtil::MakeTupleShape( - {data_shape_, induction_variable_shape_}); - } + EXPECT_IS_OK(copy_insertion.Run(module_.get()).status()); } std::unique_ptr<HloModule> module_; Shape induction_variable_shape_; Shape data_shape_; + Shape loop_state_shape_; Shape condition_result_shape_; }; -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -155,16 +137,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); + ASSERT_TRUE(result.ok()); // Check results. EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple<int64, int64, int64>(0, 10, 1))); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { +TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -175,16 +154,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - TF_ASSERT_OK(result.status()); + ASSERT_TRUE(result.ok()); // Check results. 
EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple<int64, int64, int64>(0, 10, 1))); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { +TEST_F(WhileTransformerTest, InvalidLoopLimit) { // Build computation with invalid loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); @@ -200,10 +176,7 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { HasSubstr("Loop start must be less than loop limit.")); } -// TODO(b/68830972): The while transformer is far too fragile. It patterns -// matches the exact expressions of opcodes. Re-enable when transformation is -// more general -TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { +TEST_F(WhileTransformerTest, InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 0fb11792b8..6f80994751 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -144,10 +144,8 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - tensorflow::gtl::FlatSet<const HloValue*>& old_value_set = - buffers_.at(old_buffer_number); - old_value_set.erase(&value); - if (old_value_set.empty()) { + buffers_.at(old_buffer_number).erase(&value); + if (buffers_.at(old_buffer_number).empty()) { buffers_.erase(old_buffer_number); } @@ -177,7 +175,7 @@ class BufferValueMap { // Value is init of a while (use is while). 
std::vector<BufferNumber> aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(2) << "use of value " << value.ToShortString() << ": " << use; + VLOG(1) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = @@ -413,7 +411,7 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run( HloModule* module) { - VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); + VLOG(1) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 1677c77f2e..8f595b45e9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -412,18 +412,16 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr<std::unique_ptr<HloComputation>> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map, - const std::function<void(std::unique_ptr<HloComputation>)>& - add_fused_computation, + tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map, HloInstruction* fusion_instruction) { std::vector<std::unique_ptr<HloInstruction>> instructions; tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction, - HloInstruction::CreateFromProto( - module, instruction_proto, instruction_map, - computation_map, add_fused_computation)); + TF_ASSIGN_OR_RETURN( + 
std::unique_ptr<HloInstruction> instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } @@ -533,7 +531,6 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction( if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { - LOG(FATAL) << "DEATH!"; return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " "has incompatible shape", diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 3208197f89..c9782cc981 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -152,18 +152,12 @@ class HloComputation { // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. - // fusion_instruction: if non-null then the newly created computation will - // be + // fusion_instruction: if non-null then the newly created computation will be // constructed as a fused computation with this instruction as its fusion // parent. static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto( HloModule* module, const HloComputationProto& proto, - const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map, - const std::function<void(std::unique_ptr<HloComputation>)>& - add_fused_computation, + tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map, HloInstruction* fusion_instruction = nullptr); // Gets the instructions in this computation. 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 2286cfe488..92261bce62 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -75,41 +75,11 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, std::forward_as_tuple(value_id, instruction, index, is_phi)); CHECK(emplaced.second); - VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString(); - return &emplaced.first->second; } -void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { - HloValue& value = values_.at(value_id); - VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; - - value_ids_to_delete_.push_back(value_id); -} - -void HloDataflowAnalysis::DeleteMarkedValues() { - // Verify that no marked-for-deletion values are in any of the value sets. - tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(), - value_ids_to_delete_.end()); - for (const auto& pair : value_sets_) { - const HloInstruction* instruction = pair.first; - const InstructionValueSet& instruction_value_set = pair.second; - for (const auto& index_value_set : instruction_value_set) { - const HloValueSet& value_set = index_value_set.second; - for (const HloValue* value : value_set.values()) { - DCHECK(!ContainsKey(id_set, value->id())) - << "Value " << value->ToShortString() - << " marked for deletion, but still exists in value set for " - "instruction " - << instruction->name(); - } - } - } - - for (HloValue::Id value_id : value_ids_to_delete_) { - values_.erase(value_id); - } - value_ids_to_delete_.clear(); +void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { + values_.erase(value_id); } string HloDataflowAnalysis::ToString() const { @@ -151,7 +121,6 @@ bool HloDataflowAnalysis::Phi( HloInstruction* instruction, tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) { 
CHECK(ssa_form_); - VLOG(4) << "Phi(" << instruction->name() << ")"; for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); @@ -214,7 +183,7 @@ bool HloDataflowAnalysis::Phi( } else if (current_value != &new_value) { if (current_value_defined_here) { // Remove the existing phi. - MarkValueForDeletion(current_value->id()); + DeleteHloValue(current_value->id()); } value_set.Clear(); value_set.AddValue(&new_value); @@ -224,8 +193,7 @@ bool HloDataflowAnalysis::Phi( // Multiple distinct values reach this point. A phi value is // necessary. CHECK_GT(input_value_ids.size(), 1); - if (current_value == nullptr || - !(current_value->is_phi() && current_value_defined_here)) { + if (current_value == nullptr || !current_value->is_phi()) { value_set.Clear(); value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); changed = true; @@ -468,13 +436,11 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } } -void HloDataflowAnalysis::Propagate() { +void HloDataflowAnalysis::UpdateInstructionsAndPropagate( + tensorflow::gtl::ArraySlice<HloInstruction*> instructions) { std::queue<HloInstruction*> worklist; - - for (HloComputation* computation : module_->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - worklist.push(instruction); - } + for (HloInstruction* instruction : instructions) { + worklist.push(instruction); } while (!worklist.empty()) { @@ -631,10 +597,18 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); - dataflow_analysis->Propagate(); - // Delete all values marked for deletion. - dataflow_analysis->DeleteMarkedValues(); + // Construct list of all instructions to initialize the worklist to propagate + // the data flow. 
For efficiency sort the instruction in post order so + // producers appear before consumers. + std::vector<HloInstruction*> all_instructions; + for (const HloComputation* computation : module->MakeComputationPostOrder()) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + all_instructions.push_back(instruction); + } + } + dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); // Add in positions to all values. for (const HloComputation* computation : module->computations()) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 49b1343873..207e553bf7 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -126,16 +126,13 @@ class HloDataflowAnalysis { HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); - // Mark the HloValue with the given ID for deletion. - void MarkValueForDeletion(HloValue::Id value_id); - - // Delete all HloValues marked for deletion. Should be called after - // propagation is complete. - void DeleteMarkedValues(); + // Delete the HloValue with the given ID. + void DeleteHloValue(HloValue::Id value_id); // Constructs and initializes the InstructionValueSets of all instructions to // contain exactly the HloValues defined by each instruction. These values can - // then propagated throughout the HLO graph by calling Propagate. + // then propagated throughout the HLO graph by calling + // UpdateInstructionsAndPropagate. Status InitializeInstructionValueSets(); // Updates the value set of the given instruction based on the values flowing @@ -153,8 +150,10 @@ class HloDataflowAnalysis { bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); - // Propagate the dataflow through the module. 
- void Propagate(); + // Update the value sets of the given instructions and propagate the + // changes to fixed point. + void UpdateInstructionsAndPropagate( + tensorflow::gtl::ArraySlice<HloInstruction*> instructions); // Return the result of the SSA Phi function applied to the given inputs at // the given instruction. If skip_top_level is true, then the top level of the @@ -190,11 +189,6 @@ class HloDataflowAnalysis { // A map from instruction to InstructionValueSet. std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_; - // Values marked for deletion during construction. We don't delete them - // immediately because references to them may still remain in ValueSets. After - // construction, these values are deleted. - std::vector<HloValue::Id> value_ids_to_delete_; - // A vector containing all HloValues sorted by HloValue::Id. std::vector<const HloValue*> values_vector_; diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 40e67c8780..a4921232f5 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -37,9 +37,6 @@ namespace xla { StatusOr<bool> HloDCE::Run(HloModule* module) { bool changed = false; - VLOG(2) << "Before dce:"; - XLA_VLOG_LINES(2, module->ToString()); - for (auto* computation : module->MakeNonfusionComputations()) { std::unordered_set<HloInstruction*> live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( @@ -61,8 +58,6 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { } for (HloInstruction* dead_root : dead_roots) { - VLOG(1) << "Removing dead root " << dead_root->ToString() - << " and it's unused operands"; TF_RETURN_IF_ERROR( computation->RemoveInstructionAndUnusedOperands(dead_root)); changed = true; @@ -92,9 +87,6 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { } } - VLOG(2) << "After dce:"; - XLA_VLOG_LINES(2, module->ToString()); - return changed; } diff --git 
a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2c7e735a1c..e09899e48d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -51,9 +51,7 @@ using ::tensorflow::strings::StrCat; StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map, - const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map, - const std::function<void(std::unique_ptr<HloComputation>)>& - add_fused_computation) { + tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -79,19 +77,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), - computation_map, add_fused_computation, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back(fused_computation.get()); - add_fused_computation(std::move(fused_computation)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr<HloComputation> fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), computation_map, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back( + module->AddEmbeddedComputation(std::move(fused_computation))); } else { for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(computation_map, computation_name)) + 
TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) << "No computation named " << computation_name; instruction->called_computations_.push_back( - computation_map.at(computation_name)); + computation_map->at(computation_name)); } } @@ -2011,10 +2009,8 @@ string HloInstruction::ToCategory() const { bool saw_rank_1 = false; bool saw_higher_rank = false; for (const auto* operand : operands()) { - if (!ShapeUtil::IsTuple(operand->shape())) { - saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; - saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; - } + saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; + saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; } if (saw_rank_1 && saw_higher_rank) { return "rank-1-broadcast binary fusion"; @@ -2299,8 +2295,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { template Status HloInstruction::Visit(DfsHloVisitor* visitor); template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); -using DFSStack = tensorflow::gtl::InlinedVector< - std::pair<HloInstruction::Id, HloInstruction*>, 16>; +using DFSStack = + tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>; // Push "child" onto the dfs_stack if not already visited. Returns false if a // cycle was detected, and true otherwise. 
@@ -2308,7 +2304,7 @@ template <typename Visitor> inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack, HloInstruction* child) { CHECK(child != nullptr); - const HloInstruction::Id id = child->unique_id(); + const int id = child->unique_id(); CHECK_GE(id, 0) << "instruction may not have a parent computation"; switch (visitor->GetVisitState(id)) { case Visitor::kVisiting: @@ -2325,8 +2321,8 @@ inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack, } using InternalCompareFunction = - std::function<bool(std::pair<HloInstruction::Id, const HloInstruction*>, - std::pair<HloInstruction::Id, const HloInstruction*>)>; + std::function<bool(std::pair<int, const HloInstruction*>, + std::pair<int, const HloInstruction*>)>; template <typename Visitor> static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, @@ -2345,7 +2341,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, do { DCHECK(!dfs_stack.empty()); - HloInstruction::Id current_id = dfs_stack.back().first; + int current_id = dfs_stack.back().first; HloInstruction* current_node = dfs_stack.back().second; CHECK_GE(current_id, 0) << current_id << ": " << current_node << ": instruction may not have parent computation"; @@ -2424,13 +2420,13 @@ Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; - InternalCompareFunction func = - [&operand_order](std::pair<HloInstruction::Id, const HloInstruction*> a, - std::pair<HloInstruction::Id, const HloInstruction*> b) { - // Call the client's comparison function on the actual HloInstruction* - // objects (ignoring the internal ids we also have in our stack entries) - return operand_order(a.second, b.second); - }; + InternalCompareFunction func = [&operand_order]( + std::pair<int, const HloInstruction*> a, + std::pair<int, const 
HloInstruction*> b) { + // Call the client's comparison function on the actual HloInstruction* + // objects (ignoring the internal ids we also have in our stack entries) + return operand_order(a.second, b.second); + }; TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func, /*ignore_control_predecessors=*/false)); if (call_finish_visit) { diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 411f926a87..4d8fe6bc10 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -83,16 +83,12 @@ class HloInstruction { // must contain all operands of the newly constructed instruction. // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed instruction - // calls. - // add_fused_computation: A function to call to add a fused - // computation. Used (clearly) when the instruction is a fusion - // instruction. + // calls. If the instruction is a fusion instruction, then the fusion + // computation is added to this map and the module. static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map, - const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map, - const std::function<void(std::unique_ptr<HloComputation>)>& - add_fused_computation); + tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map); // Creates a parameter-retrieving instruction. 
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number, @@ -981,8 +977,7 @@ class HloInstruction { void UniquifyName(NameUniquer* name_uniquer); // Set the unique id for this instruction to "id" - using Id = int; - void SetUniqueId(Id id) { + void SetUniqueId(int id) { CHECK_EQ(unique_id_, -1); // Should not be assigned already CHECK_GE(id, 0); unique_id_ = id; @@ -990,7 +985,7 @@ class HloInstruction { // Return the unique ID assigned to this node via SetUniqueId (or -1 // if no id has been assigned yet). - Id unique_id() const { return unique_id_; } + int unique_id() const { return unique_id_; } // Sets the debug metadata for this instruction. void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; } @@ -1093,7 +1088,7 @@ class HloInstruction { // Returns how this instruction uses elements of its `i`th operand. UseKind OperandElementUse(int64 i) const; - Id unique_id_; // Unique to this HloInstruction within a HloModule + int unique_id_; // Unique to this HloInstruction within a HloModule // Opcode for this instruction. 
HloOpcode opcode_; diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index d2cee6f8b1..659f3d8c26 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -296,16 +296,9 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( tensorflow::gtl::FlatMap<string, HloComputation*> computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr<HloComputation> computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, computation_map, - /*add_fused_computation=*/ - [&module](std::unique_ptr<HloComputation> fused_computation) { - module->AddComputationInternal(std::move(fused_computation), - /*is_entry=*/false, - /*uniquify_names=*/false); - })); + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, &computation_map)); CHECK_NE(computation.get(), nullptr); TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); string computation_name = computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index 1f9a989961..e6cf0d37b8 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -184,7 +184,7 @@ void HloValue::AddPosition(HloInstruction* instruction, live_out_of_module_ = true; } - if (instruction == defining_instruction()->parent()->root_instruction()) { + if (instruction == instruction->parent()->root_instruction()) { live_out_of_computation_ = true; } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index 2ecf57ad3d..34899b7400 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -55,34 +55,22 @@ static Status 
EmitDynamicUpdateSliceInPlaceImpl( // Calculate output_index, where we'll write the value from update. For // each dimension, // - // output_index[dim] = (start_index[dim] + update_index[dim]) + // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. // IrArray::Index output_index(rank); for (int64 i = 0; i < rank; ++i) { + llvm::Value* dim_size = llvm::ConstantInt::get( + update_index[i]->getType(), output_shape.dimensions(i)); llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( start_index[i], update_index[i]->getType()); - output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]); - } - - // Check if 'index' intersects start/end indices. If it does not (indices - // are out of bounds) then no update is performed. - llvm::Value* in_bounds = llvm::ConstantInt::get(ir_builder->getInt1Ty(), 1); - for (int64 i = 0; i < rank; ++i) { - llvm::Value* dim_size = llvm::ConstantInt::get( - output_index[i]->getType(), output_shape.dimensions(i)); - in_bounds = ir_builder->CreateAnd( - in_bounds, ir_builder->CreateICmpSLT(output_index[i], dim_size), - "in_bounds"); + output_index[i] = ir_builder->CreateURem( + ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); } // Do output[output_index] = update[update_index]. 
TF_ASSIGN_OR_RETURN(llvm::Value * update_data, update_array_generator(update_index)); - llvm::Value* input_data = - output_array.EmitReadArrayElement(output_index, ir_builder); - llvm::Value* to_write_data = - ir_builder->CreateSelect(in_bounds, update_data, input_data); - output_array.EmitWriteArrayElement(output_index, to_write_data, ir_builder); + output_array.EmitWriteArrayElement(output_index, update_data, ir_builder); return Status::OK(); }; diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 5a012c93d6..4920f17a7e 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -180,8 +180,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. -XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc index f568f58154..92b2b1ee77 100644 --- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc +++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" GTEST_API_ int main(int argc, char** argv) { std::vector<tensorflow::Flag> flag_list; @@ -31,7 +30,5 @@ GTEST_API_ int main(int argc, char** argv) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; } - int result = RUN_ALL_TESTS(); - tensorflow::testing::RunBenchmarks(); - return result; + return RUN_ALL_TESTS(); } |