aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-11-03 13:26:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-03 13:31:12 -0700
commit456929281592f14d50443cfbdaa2f6b36167a134 (patch)
tree8a1a18245a6ef9baed8bd9a9f35b7c250ab64901
parent5b166f495ae79b6e8144bbd3a1109f4b8d9fb1aa (diff)
Rollback copy insertion change because it results in a DCHECK with an internal model.
END_PUBLIC BEGIN_PUBLIC Automated g4 rollback of changelist 174423881 PiperOrigin-RevId: 174505237
-rw-r--r--tensorflow/compiler/xla/service/BUILD10
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc1
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc78
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc1526
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.h34
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc948
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc78
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD7
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_insertion.cc73
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_insertion.h15
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc61
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h10
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc64
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h22
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc54
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h17
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc13
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ops.cc24
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/xla_internal_test_main.cc5
25 files changed, 879 insertions, 2200 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7fe06655cf..c6f6c6c38b 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1644,14 +1644,10 @@ cc_library(
deps = [
":buffer_liveness",
":hlo",
- ":hlo_alias_analysis",
- ":hlo_dce",
- ":hlo_graph_dumper",
- ":hlo_ordering",
":hlo_pass",
":liveness_util",
":logical_buffer",
- ":tuple_simplifier",
+ ":tuple_points_to_analysis",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
@@ -1666,17 +1662,15 @@ tf_cc_test(
deps = [
":copy_insertion",
":hlo",
- ":hlo_graph_dumper",
":hlo_matchers",
+ ":tuple_points_to_analysis",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 5c9714d7ea..8536429846 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -1235,6 +1235,7 @@ const LogicalBuffer* AddBufferToColocatedSet(
// CopyInsertion ensures root points-to set is unambiguous and distinct.
const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
DCHECK(!points_to.IsAmbiguous());
+ DCHECK(points_to.IsDistinct());
colocated_set->push_back(points_to.element(index)[0]);
return colocated_set->back();
}
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 4d4c5b953e..89410f42bd 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1538,6 +1538,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ auto output1 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -1554,8 +1556,10 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto body1 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output1}));
auto while1 = builder.AddInstruction(
- HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
+ HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
@@ -1672,37 +1676,34 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
- auto gte0 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
- auto gte1 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
- while0->shape(), HloOpcode::kAdd, gte0, gte1));
-
+ while0->shape(), HloOpcode::kAdd, while0, while1));
module->AddEntryComputation(builder.Build());
+ RunCopyInsertion(module.get());
+
{
FlattenCallGraph flatten;
TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
EXPECT_TRUE(result);
}
- RunCopyInsertion(module.get());
-
auto sequence =
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo sequence for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
- sequence[module->entry_computation()] = {
- input1, weights1, one, output1, while1->operand(0), while1,
- input0, weights0, zero, output0, while0->operand(0), while0,
- gte0, gte1, root_add};
+ std::vector<const HloInstruction*> sequence_for_buffer_assigment = {
+ input1, weights1, one, output1, tuple1, while1, input0,
+ weights0, zero, output0, tuple0, while0, root_add};
// If this ASSERT_TRUE fails, we constructed a bogus sequence above
// and this test itself is buggy.
- ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
+ ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment));
+
+ sequence[module->entry_computation()] =
+ std::move(sequence_for_buffer_assigment);
auto assignment =
BufferAssigner::Run(
@@ -1714,6 +1715,55 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
+// Test buffer assignment for while nodes with multiple uses.
+// TODO(b/37245345): Fix buffer assignment for this case.
+TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
+ auto module = MakeUnique<HloModule>(TestName());
+ auto builder = HloComputation::Builder(TestName());
+
+ auto input0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape_, "input0"));
+ auto weights0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ auto output0 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+ auto cond0 =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+ auto body0 =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+ auto tuple0 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output0}));
+ auto while0 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+ auto while1 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0));
+
+ auto get0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
+ auto get1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1));
+ module->AddEntryComputation(builder.Build());
+
+ RunCopyInsertion(module.get());
+
+ {
+ FlattenCallGraph flatten;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
+ EXPECT_TRUE(result);
+ }
+
+ auto assignment = RunBufferAssignment(module.get());
+
+ EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
+}
+
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
auto module = MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 8f50b29dad..0453a698a0 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,17 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
-#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
+#include <memory>
+
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_dce.h"
-#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -33,1113 +31,597 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace {
-bool IsEntryParameterValue(const HloValue& value) {
- const HloComputation* computation = value.defining_instruction()->parent();
- return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
- computation == computation->parent()->entry_computation();
-}
-
-bool IsConstantValue(const HloValue& value) {
- return value.defining_instruction()->opcode() == HloOpcode::kConstant;
-}
-
-bool ValueIsReadOnly(const HloValue& value) {
- return IsConstantValue(value) || IsEntryParameterValue(value);
-}
+using tensorflow::gtl::FlatMap;
+using tensorflow::gtl::FlatSet;
-// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
-// 'indices_to_copy'. Add control edges from the respective kCopy instructions
-// in deep copy of 'from' to the respective kCopy instruction in the deep copy
-// of 'to'.
-//
-// Requirements: 'from' and 'to' must have compatible shapes.
+// InstructionCopier encapsulates indices at which to copy 'instruction'.
+// All 'instruction' users in 'copy_users' are updated to use the copy.
//
-// For example, suppose 'from' and 'to' are two-element tuples where index 0 is
-// the only index to copy. Prior to deep-copying we have:
+// Instruction copies are generated in two phases:
+// 1) Recording buffer indices at which 'instruction' requires copies (i.e.
+// setting 'indices_to_copy_[index]'=true).
+// 2) Inserting kCopy instructions based on indices recorded in phase 1).
+// *) Array instructions are copied by inserting a single kCopy instruction.
+// *) Tuple-shaped instructions are copied by recursively expanding tuples
+// (and tuple-shaped elements), and inserting kCopy instructions for any
+// tuple elements which require a copy. As the recursion unwinds, new tuple
+// instructions are added to gather the copied (and uncopied) references
+// into the output tuple (i.e. the copy of the tuple-shaped instruction).
//
+// Example two-element tuple with one element that needs a copy:
//
-// 'from'
-// |
-// ...
-// |
-// 'to'
+// original-instruction
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy |
+// \ /
+// Tuple // copied-instruction
//
-// DeepCopyAndAddControlEdges produces:
-//
-// 'from'
-// / \
-// GTE GTE
-// | |
-// Copy |
-// / \ /
-// | Tuple
-// | |
-// ctrl ...
-// edge |
-// | |
-// | 'to'
-// | / \
-// | GTE GTE
-// \ | |
-// Copy |
-// \ /
-// Tuple
-//
-StatusOr<std::pair<HloInstruction*, HloInstruction*>>
-DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
- const ShapeTree<bool>& indices_to_copy) {
- DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
- // to/from_copy_tree hold the kCopy instruction produces by the deep
- // copies. Elements which are not copied (indices_to_copy.element(index) ==
- // false) have nullptr at that index.
- ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
- /*init_value=*/nullptr);
- TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
- from->parent()->DeepCopyInstruction(
- from, &indices_to_copy, &from_copy_tree));
-
- ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * to_deep_copy,
- to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
-
- // Add control edges between the respective kCopy instructions.
- for (const auto& pair : from_copy_tree) {
- const ShapeIndex& index = pair.first;
- HloInstruction* from_copy = pair.second;
- HloInstruction* to_copy = to_copy_tree.element(index);
- if (from_copy == nullptr) {
- TF_RET_CHECK(to_copy == nullptr);
- continue;
- }
- TF_RET_CHECK(to_copy != nullptr);
- TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
+// As an optimization, if the original instruction is itself a Tuple
+// instruction, we elide the unnecessary extra GTE and Tuple instructions,
+// and just insert the copy into a new Tuple instruction, with control
+// dependencies to ensure the copy occurs after any possible interference.
+class InstructionCopier {
+ public:
+ InstructionCopier(HloInstruction* instruction,
+ const std::vector<HloInstruction*>& copy_users)
+ : instruction_(instruction),
+ copy_users_(copy_users),
+ indices_to_copy_(instruction->shape()),
+ control_predecessors_(instruction->shape()) {}
+
+ // Sets indices that are read-only, and thus do not need to be copied.
+ void SetReadOnlyIndices(const ShapeTree<bool>& read_only_indices) {
+ read_only_indices_ = read_only_indices;
}
- return std::make_pair(from_deep_copy, to_deep_copy);
-}
-
-// Compute the indices of the loop state which need copies in order to avoid
-// live range interference. Generally, an element in the loop state does not
-// need to be copied if the element is passed through transparently through the
-// body.
-//
-// Returns whether any indices need to be copied.
-bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
- const HloInstruction* xla_while,
- ShapeTree<bool>* indices_to_copy) {
- DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
-
- bool any_copies = false;
- const HloInstruction* init = xla_while->operand(0);
- for (auto& pair : *indices_to_copy) {
- const ShapeIndex& index = pair.first;
- bool& should_copy = pair.second;
- // If there is any ambiguity, then loop state must be copied.
- if (dataflow.GetValueSet(init, index).values().size() > 1 ||
- dataflow.GetValueSet(xla_while, index).values().size() > 1) {
- should_copy = true;
- } else {
- // If the output of the while instruction is not the same as the init
- // value of the while, then this element is not passed through the body
- // transparently and must be copied.
- should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
- dataflow.GetUniqueValueAt(init, index);
- }
- any_copies |= should_copy;
+ // Sets copy overrides, which are copy instructions to use at each index. This
+ // is used to share a single copy of read-only entry parameters and constants
+ // between multiple While loops.
+ void SetCopyOverrides(const ShapeTree<HloInstruction*>& copy_overrides) {
+ copy_overrides_ = copy_overrides;
}
- return any_copies;
-}
-
-// Add kCopy instructions around the given kWhile instruction to eliminate any
-// possible live range interference of HLO values assuming a dependency-based
-// ordering (HloDependencyOrdering). Copies are added conservatively. There
-// likely are copies which are not strictly necessary, but there are removed
-// later in the pass via CopyRemover.
-//
-//
-// Elements (each ShapeIndex) in the loop state are considered independently. A
-// copy is added to each element of the loop state which is modified in the
-// while body. For each such element, a total of three kCopy instructions are
-// added at following locations:
-//
-// (1) The init value is copied before the kWhile instruction. Before:
-//
-// (Init)
-// |
-// kWhile
-// |
-// ...
-//
-// After:
-//
-// (Init)
-// |
-// kCopy
-// |
-// kWhile
-// |
-// ...
-//
-// This copy is necessary in case the init value is simultaneously live
-// with the kWhile.
-//
-// (2) Copies are added to the parameter and root of the while body
-// computation. Before:
-//
-// kParameter
-// |
-// ...
-// |
-// (body root)
-//
-// After:
-//
-// kParameter
-// |
-// kCopy ----------+
-// | |
-// ... ctrl
-// | edge
-// (body root) |
-// | |
-// kCopy <---------+
-//
-// The root kCopy becomes the new root of the computation. Both copies are
-// necessary to any potential interference between the parameter value and
-// the root value. The control edge prevents potential interference
-// between the copies themselves.
-//
-// If the loop state is a tuple then the above kCopy instructions are a deep
-// copy constructed of kCopy, KGetTupleElement, and kTuple instruction as
-// constructed by HloInstruction::DeepCopyInstruction.
-Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
- HloInstruction* xla_while) {
- VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
- TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
- ShapeTree<bool> indices_to_copy(xla_while->shape());
- if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
- &indices_to_copy)) {
- VLOG(2) << "No copies necessary for kWhile instruction "
- << xla_while->name();
- return Status::OK();
- }
+ // Returns true if all recorded indices are false (returns true otherwise).
+ bool HasAllIndicesFalse() const;
- VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
- for (auto& pair : indices_to_copy) {
- if (pair.second) {
- VLOG(2) << " " << pair.first;
- }
- }
+ // Records instruction buffer indices which point-to a Parameter or Constant.
+ Status RecordIndicesWhichPointToParamOrConstant(
+ const TuplePointsToAnalysis& points_to_analysis);
- // Deep copy init.
- HloInstruction* while_init = xla_while->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * while_init_copy,
- xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
- TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
+ // Records instruction buffer indices to copy which are necessary to ensure:
+ // *) PointsToSet of 'instruction_' is unambiguous and distinct.
+ // *) No liveness interference between 'instruction_' and 'other_instruction'.
+ //
+ // If 'read_only_indices_out' is non-null, read-only indices are set to true.
+ Status RecordIndicesToCopyForColocatingBuffers(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out);
- // Deep copy the parameter and the root. Extend a control edge from the copy
- // of the parameter value to the corresponding copy value of the root.
- HloComputation* body = xla_while->while_body();
- HloInstruction* param = body->parameter_instruction(0);
- HloInstruction* root = body->root_instruction();
+ // Records control predecessors to add for inserted copy instructions.
+ // 'parameter' must have the same shape as the instruction that will be
+ // copied, and must define all buffers in the shape. Control predecessors are
+ // only recorded for indices that have already been marked for copying.
+ Status RecordControlPredecessors(
+ const TuplePointsToAnalysis& points_to_analysis,
+ HloInstruction* parameter);
- // If param is the root then all indices should have been passed through the
- // while body and we should have returned early above.
- TF_RET_CHECK(param != root);
+ // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy',
+ // and replaces all uses for instructions in 'copy_users_' with copy.
+ // Returns the instruction which is a copy 'instruction'.
+ HloInstruction* Copy();
- // Copy users before making a deep copy of the parameter as the deep copy
- // will create new users of the parameter (eg, the GTE instructions of the
- // deep copy).
- std::vector<HloInstruction*> param_users = param->users();
+ HloInstruction* instruction() { return instruction_; }
- ShapeIndex current_index;
- TF_ASSIGN_OR_RETURN(auto pair,
- DeepCopyAndAddControlEdges(param, root, indices_to_copy));
+ const std::vector<HloInstruction*>& copy_users() const { return copy_users_; }
- HloInstruction* param_copy = pair.first;
- HloInstruction* root_copy = pair.second;
+ private:
+ // Does the given index represent a read-only buffer?
+ bool IsReadOnlyIndex(const ShapeIndex& index) const {
+ return !ShapeUtil::IsNil(read_only_indices_.shape()) &&
+ read_only_indices_.element(index);
+ }
- for (HloInstruction* user : param_users) {
- TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
+ // Returns the copy override at the given index, or nullptr.
+ HloInstruction* GetCopyOverride(const ShapeIndex& index) const {
+ return ShapeUtil::IsNil(copy_overrides_.shape())
+ ? nullptr
+ : copy_overrides_.element(index);
}
- body->set_root_instruction(root_copy);
+ // Records instruction buffer indices which have ambiguous or non-distinct
+ // points-to sets.
+ Status RecordAmbiguousOrNonDistinctIndices(
+ const TuplePointsToAnalysis& points_to_analysis);
- return Status::OK();
-}
+ // Records instruction buffer indices which have interfering live ranges
+ // with 'other_instruction' buffers at same index.
+ Status RecordIndicesWhichInterfereWithOtherInstruction(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out);
-// Removes any control dependencies to or from the given instruction.
-Status StripControlDependenciesFrom(HloInstruction* instruction) {
- while (!instruction->control_successors().empty()) {
- TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
- instruction->control_successors().front()));
- }
+ // Recursively inserts copies of 'instruction' tuple elements at indices
+ // specified in 'indices_to_copy', and returns the copy of 'instruction'.
+ HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index);
- while (!instruction->control_predecessors().empty()) {
- TF_RETURN_IF_ERROR(
- instruction->control_predecessors().front()->RemoveControlDependencyTo(
- instruction));
+ void RecordIndex(const ShapeIndex& index) {
+ *indices_to_copy_.mutable_element(index) = true;
}
- return Status::OK();
-}
-
-// Add kCopy instructions to the given module to guarantee there is no
-// live-range interference. Generally interference can only occur around kWhile
-// instructions which have update-in-place semantics.
-Status AddCopiesToResolveInterference(HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
+ HloInstruction* instruction_;
+ const std::vector<HloInstruction*> copy_users_;
+ ShapeTree<bool> indices_to_copy_;
+ ShapeTree<std::vector<HloInstruction*>> control_predecessors_;
+ ShapeTree<bool> read_only_indices_;
+ ShapeTree<HloInstruction*> copy_overrides_;
+};
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kWhile) {
- TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
- }
- }
- }
- return Status::OK();
+bool InstructionCopier::HasAllIndicesFalse() const {
+ bool all_indices_false = true;
+ indices_to_copy_.ForEachElement(
+ [&all_indices_false](const ShapeIndex& /*index*/, bool data) {
+ if (data) {
+ all_indices_false = false;
+ }
+ });
+ return all_indices_false;
}
-// Class for removing unnecessary copies from the module.
-//
-// kCopy instructions are added conservatively to guarantee no live range
-// interference between HLO values. This class uses a more fine-grained analysis
-// to remove some of these added copies which are not strictly necessary.
-class CopyRemover {
- public:
- CopyRemover(const HloAliasAnalysis& alias_analysis,
- const HloOrdering& ordering, HloModule* module)
- : module_(module),
- alias_analysis_(alias_analysis),
- ordering_(ordering),
- buffer_value_tracker_(*module, alias_analysis, ordering) {}
-
- // Try to elide the given copy. The copy is elided if the instruction is not
- // necessary to prevent live-range interference of HLO values. Returns true if
- // copy was elided.
- //
- // The copy instruction is not actually removed here. Instead it is left for
- // dead in the graph. Later calls to DCE will remove the instruction.
- StatusOr<bool> TryElideCopy(HloInstruction* copy) {
- if (buffer_value_tracker_.TryElideCopy(copy)) {
- TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy));
- TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0)));
- return true;
- }
- return false;
+Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant(
+ const TuplePointsToAnalysis& points_to_analysis) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction_);
+ // Shallow copy the instruction if the points-to set of the top-level
+ // buffer is ambiguous. This is necessary because the backends must know
+ // statically what the top-level buffer of the result is.
+ if (points_to.element(/*index=*/{}).size() > 1) {
+ RecordIndex({});
}
- string ToString() const {
- string out = StrCat("CopyRemover, module ", module_->name(), "\n");
- StrAppend(&out, " Buffer values, in dependency order:\n");
- for (const HloBuffer& buffer : alias_analysis_.buffers()) {
- StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
+ // Multiple buffers within a parameter/constant may be live out, so collect
+ // a set of indices at which to copy first.
+ points_to.ForEachElement([this](const ShapeIndex& index,
+ const PointsToSet::BufferList& buffers) {
+ if (IsReadOnlyIndex(index)) {
+ return;
}
- return out;
- }
-
- private:
- // Class which tracks the HLO values within each HLO buffer in the module
- // during copy removal.
- //
- // The values are held in a linked list where there is one list for each
- // buffer. Removing a copy instruction merges together the values in the
- // source buffer of the copy to the destination buffer of the copy. This class
- // tracks these value lists as copies are removed from the graph (and value
- // lists are merged).
- //
- // The BufferValueTracker object is initialized to match the state of
- // HloAliasAnalysis. However, as copies are removed this state diverges. The
- // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
- // a fully updatable alias analysis is very slow.
- class BufferValueTracker {
- public:
- // The values held in a single HLO buffer are represented using a linked
- // list. An element type in this list is ValueNode.
- //
- // This linked list is hand-rolled to enable efficient splicing of lists
- // using only references to list elements without knowing which lists are
- // being spliced. std::list requires a reference to the list object to
- // splice.
- struct ValueNode {
- explicit ValueNode(const HloValue* v) : value(v) {}
-
- const HloValue* value;
-
- // The uses are maintained outside of HloValue::uses() because
- // HloValue::uses() is not updatable (a fully updatable dataflow analysis
- // is slow).
- std::vector<const HloUse*> uses;
-
- // next/prev elements in the linked list. The list is circularly linked so
- // these values are never null for elements in the list.
- ValueNode* prev = nullptr;
- ValueNode* next = nullptr;
- };
-
- BufferValueTracker(const HloModule& module,
- const HloAliasAnalysis& alias_analysis,
- const HloOrdering& ordering)
- : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
- // Construct a list for each HLO buffer in the alias analysis. Maintain a
- // map from HloValue to the respective list element representing that
- // value. The map is used to construct the copy info map below.
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
- for (const HloBuffer& buffer : alias_analysis.buffers()) {
- // Verify values contained in the buffer are strictly ordered. This
- // should always be the case after adding copies to eliminate
- // interference. Specifically, the addition of the control flow edges
- // between copies added around aliased operations (kWhile) guarantees
- // this strict order.
- for (const HloValue* value_a : buffer.values()) {
- for (const HloValue* value_b : buffer.values()) {
- if (value_a != value_b) {
- DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
- dataflow_) ||
- ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
- dataflow_))
- << value_a->ToShortString() << " and "
- << value_b->ToShortString() << " are not ordered";
- }
- }
- }
-
- std::vector<const HloValue*> values = buffer.values();
- std::sort(values.begin(), values.end(),
- [this](const HloValue* a, const HloValue* b) {
- return ordering_.IsDefinedBefore(*a, *b);
- });
-
- // Create a list containing all of the values in the buffer.
- AddValueList(values, &value_to_node);
+ for (const LogicalBuffer* buffer : buffers) {
+ // pointee is the HloInstruction producing the buffer which may be
+ // liveout.
+ HloInstruction* pointee = buffer->instruction();
+ if (pointee->opcode() == HloOpcode::kParameter ||
+ pointee->opcode() == HloOpcode::kConstant) {
+ VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
+ << " index: " << tensorflow::str_util::Join(index, ",")
+ << " may be live out of computation: " << pointee->ToString();
+ RecordIndex(index);
+ break;
}
-
- // Create copy_map_ which contains the source and destination values
- // of all copies.
- CreateCopyMap(module, value_to_node);
-
- XLA_VLOG_LINES(3, ToString());
- TF_DCHECK_OK(Verify());
}
+ });
+ return Status::OK();
+}
- // Add a list containing the given values to BufferValueTracker. This
- // represents the values contained in a single buffer. For each value in
- // 'values' an entry is created in value_to_node which indicates the
- // respective ValueNode representing that value.
- void AddValueList(
- tensorflow::gtl::ArraySlice<const HloValue*> values,
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
- ValueNode* tail = nullptr;
- ValueNode* head = nullptr;
- for (const HloValue* value : values) {
- auto new_node = new ValueNode(value);
- (*value_to_node)[value] = new_node;
-
- // Copy the HLO values's uses into the ValueNode for the value. These
- // uses in ValueNode are updated as copies are removed.
- new_node->uses.reserve(value->uses().size());
- for (const HloUse& use : value->uses()) {
- new_node->uses.push_back(&use);
- }
-
- // Connect the new node into the linked list.
- if (tail == nullptr) {
- head = new_node;
- } else {
- tail->next = new_node;
- new_node->prev = tail;
- }
- tail = new_node;
- }
-
- // The linked list is circular so connect the head and tail.
- tail->next = head;
- head->prev = tail;
- value_lists_.insert(head);
- }
+Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out) {
+ TF_RETURN_IF_ERROR(
+ RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis()));
+ TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction(
+ liveness, other_instruction, read_only_indices_out));
+ return Status::OK();
+}
- // This method also fills in copy_map_ which indicates which nodes
- // in the value lists corresponding to the source and destination values of
- // kCopy instructions. value_to_node should map each HloValue to its
- // respective ValueNode.
- void CreateCopyMap(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
- value_to_node) {
- for (HloComputation* computation : module.computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- // Add copies with unambiguous source values to the map. Copies with
- // ambiguous sources are not removable.
- if (instruction->opcode() == HloOpcode::kCopy) {
- const HloValueSet& src_value_set =
- dataflow_.GetValueSet(instruction->operand(0));
- if (src_value_set.values().size() == 1) {
- CopyNodes& copy_node = copy_map_[instruction];
- copy_node.dest =
- value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
- copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
- }
+Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
+ const TuplePointsToAnalysis& points_to_analysis) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction_);
+ // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
+ FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
+ buffer_to_source_indices;
+ points_to.ForEachElement(
+ [this, &buffer_to_source_indices](
+ const ShapeIndex& index, const PointsToSet::BufferList& buffers) {
+ if (buffers.size() > 1) {
+ // Record ambiguous points-to set at 'index'.
+ if (!indices_to_copy_.element(index)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(index, ",")
+ << " with ambiguous points-to set.";
+ RecordIndex(index);
}
}
- }
- }
+ // For each 'buffer': record a mapping from 'buffer' to 'index'.
+ for (const LogicalBuffer* buffer : buffers) {
+ buffer_to_source_indices[buffer].push_back(index);
+ }
+ });
- ~BufferValueTracker() {
- for (const ValueNode* head : value_lists_) {
- const ValueNode* p = head;
- do {
- const ValueNode* tmp = p->next;
- delete p;
- p = tmp;
- } while (p != head);
- }
+ // Record all non-distinct indices detected in 'buffer_to_source_indices'.
+ for (const auto& buff_to_src : buffer_to_source_indices) {
+ if (buff_to_src.second.size() == 1) {
+ continue;
}
-
- // Verify invariants within the linked lists.
- Status Verify() const {
- for (const ValueNode* head : value_lists_) {
- const ValueNode* p = head;
- do {
- // Verify links between elements are consistent.
- TF_RET_CHECK(p->prev->next == p);
- TF_RET_CHECK(p->next->prev == p);
-
- const HloInstruction* def = p->value->defining_instruction();
- if (def->opcode() == HloOpcode::kCopy &&
- ContainsKey(copy_map_, def)) {
- TF_RET_CHECK(copy_map_.at(def).dest == p);
- }
- for (const HloUse* use : p->uses) {
- if (use->instruction->opcode() == HloOpcode::kCopy &&
- ContainsKey(copy_map_, use->instruction)) {
- TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
- }
- }
-
- p = p->next;
- } while (p != head);
+ for (const ShapeIndex& src_index : buff_to_src.second) {
+ // Record non-distinct points-to set at 'src_index'.
+ if (!indices_to_copy_.element(src_index)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(src_index, ",")
+ << " because of non-distinct points-to set.";
+ RecordIndex(src_index);
}
- return Status::OK();
}
+ }
+ return Status::OK();
+}
- // Try to elide the given copy. Elision of a copy is possible only if no
- // live range interference is introduced by the copy's elimination. If
- // elision is possible, then the internal state (value lists) are updated,
- // and true is returned. Returns false otherwise.
- bool TryElideCopy(const HloInstruction* copy) {
- VLOG(2) << "Trying to remove " << copy->name();
-
- if (!ContainsKey(copy_map_, copy)) {
- VLOG(2) << copy->name() << " is not removable";
- return false;
- }
-
- const CopyNodes& copy_node = copy_map_.at(copy);
- ValueNode* src = copy_node.src;
- ValueNode* dest = copy_node.dest;
- DCHECK(src != nullptr);
- DCHECK(dest != nullptr);
-
- auto is_live_range_before = [this](const ValueNode& a,
- const ValueNode& b) {
- if (LiveRangeBefore(a, b)) {
- VLOG(2) << " Live range of " << a.value->ToShortString()
- << " is before " << b.value->ToShortString();
- return true;
- } else {
- VLOG(2) << " Live range of " << a.value->ToShortString()
- << " is not before " << b.value->ToShortString();
- return false;
+Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out) {
+ // Record all buffer indices for 'instruction_', which interfere with
+ // 'other_instruction' at the same index.
+ ShapeUtil::ForEachSubshape(
+ instruction_->shape(),
+ [this, &liveness, other_instruction, read_only_indices_out](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ if (IsReadOnlyIndex(index)) {
+ return;
}
- };
-
- // A kCopy instruction copies an HLO value from a source buffer and
- // defines an HLO value in a destination buffer. Most generally, the
- // source and destination buffers may each hold more than one value at
- // different points in the computation so we define the following:
- //
- // Values in source buffer: {s_0, ..., s_n}
- // Values in destination buffer: {d_0, ..., d_m}
- //
- // A kCopy instruction between these buffers copies a value s_x in the
- // source buffer and defines a value d_y in the destination buffer. The
- // elision of a copy merges the source and destination buffers together,
- // so the list of values for the source and destination buffers are
- // merged.
- //
- // We handle two different cases for copy elision:
- //
- // (1) the kCopy defines the first value in the destination buffer (d_0).
- //
- // (2) the kCopy copies the last value in the source buffer (s_n).
- //
- // For the remaining case where the kCopy copies a not-last value from the
- // source buffer to a not-first value of the destination buffer, the kCopy
- // instruction cannot be removed. This case is generated, for example, if
- // the kCopy copies a while body parameter of the loop state at one tuple
- // index to a different tuple index in the while body root. Removal of the
- // copy necessarily results in live range interference of values in the
- // loop state at the two different tuple indices.
- //
- // We can only perform copy elision if the resulting merged values have
- // totally ordered live ranges; otherwise the merged buffer would have
- // live range interference.
- if (IsHead(*dest)) {
- // The copy copies an arbitrary value in the source buffer (call it s_x)
- // and defines d_0, the first value in the destination buffer. After
- // merging, the values in the combined buffer must be strictly ordered
- // as follows** to elide the copy:
- //
- // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
- //
- // Removing the copy eliminates d_0, and uses of d_0 become uses of
- // s_x. In the above ordering, the live range of d_m must be ordered
- // before the live range of s_{x+1} and the definition and all uses of
- // s_x must be ordered before the definition of d_1. These conditions
- // are checked below prior to elision.
- //
- // ** Technically it might be possible to have a non-interfering
- // non-trivial interleaving of the values of the source and
- // destination buffers in the resulting order. However, this case is
- // slow and complicated to check and likely not worth it. So instead
- // we simply check for the case where *all* values of the destination
- // buffer (d_1 through d_m) are spliced into the point where the copy
- // used to be.
- VLOG(2) << copy->name() << " defines the first value in its buffer";
- ValueNode* next_dest = Next(*dest);
- if (next_dest != nullptr) {
- // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
- if (!is_live_range_before(*src, *next_dest)) {
- return false;
- }
+ if (indices_to_copy_.element(index)) {
+ // Return if previous pass already set index.
+ return;
}
- ValueNode* next_src = Next(*src);
-
- if (next_src != nullptr) {
- // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
- ValueNode* last_dest = dest->prev;
- DCHECK(IsTail(*last_dest));
- if (!is_live_range_before(*last_dest, *next_src)) {
- return false;
+ const auto& points_to_analysis = liveness.points_to_analysis();
+ // Lookup buffers for 'instruction_' and 'other_instruction'.
+ const auto instruction_buffers =
+ points_to_analysis.GetPointsToSet(instruction_).element(index);
+ // If 'instruction_' has ambiguous points-to-set at 'index', it would
+ // have been recorded in a previous pass (and we would have returned
+ // early at the entry to this function). As a result, here we know that
+ // 'instruction_' has just one buffer in its points-to-set.
+ CHECK_EQ(1, instruction_buffers.size());
+ const LogicalBuffer* instruction_buffer = instruction_buffers[0];
+
+ const auto other_instruction_buffers =
+ points_to_analysis.GetPointsToSet(other_instruction).element(index);
+ // Do not insert a copy if both instructions point at the same buffer.
+ // This eliminates unnecessary copies of read-only tuple elements.
+ // If 'instruction_' and 'other_instruction' point to the same buffer,
+ // then that buffer is not updated on the path between the two
+ // instructions. Therefore, any other (possibly interference-causing)
+ // users of that buffer from 'other_instruction' will see the same data,
+ // irrespective of whether we insert a copy of this buffer at
+ // 'instruction_' or not.
+ if (other_instruction_buffers.size() == 1 &&
+ other_instruction_buffers[0]->id() == instruction_buffer->id()) {
+ if (read_only_indices_out != nullptr) {
+ *read_only_indices_out->mutable_element(index) = true;
}
+ return;
}
-
- // Splice in destination buffer values list right after 'src'.
- SpliceAfter(dest, src);
- } else if (IsTail(*src)) {
- // The copy copies the last value in the source buffer, s_n, and defines
- // an arbitrary value in the destination buffer, d_y. After
- // merging, the values in the combined buffer must be strictly ordered
- // as follows** to elide the copy:
- //
- // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
- //
- // Removing the copy eliminates d_y, and uses of d_y become uses of
- // s_n. To enforce the above order, the live range of d_{y-1} must be
- // before the live range of s_0, and the live range of s_n must be
- // before the live range of d_{y+1}.
- //
- // ** See comment above in the code handling Case (1).
- VLOG(2) << copy->name() << " copies the last value ("
- << src->value->ToShortString() << ") in its buffer";
-
- ValueNode* prev_dest = Prev(*dest);
- // nullptr condition handled above in the first 'if' case.
- DCHECK(prev_dest != nullptr);
- ValueNode* first_src = src->next;
- DCHECK(IsHead(*first_src));
- if (!is_live_range_before(*prev_dest, *first_src)) {
- // Live range of value d_{y-1} is not before s_0.
- return false;
- }
- ValueNode* next_dest = Next(*dest);
- if (next_dest != nullptr) {
- if (!is_live_range_before(*src, *next_dest)) {
- // Live range of value s_n is not before d_{y+1}.
- return false;
+ // We can't say anything about the ambiguity of 'other_instruction' at
+ // this point, so we need to check interference between the single
+ // buffer in the points-to set of 'instruction_' and all buffers in
+ // 'other_instruction_buffers'.
+ for (const LogicalBuffer* other_buffer : other_instruction_buffers) {
+ if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " instruction_buffer: " << instruction_buffer->ToString()
+ << " at index: " << tensorflow::str_util::Join(index, ",")
+ << " because of interference with buffer: "
+ << other_buffer->ToString();
+ RecordIndex(index);
+ break;
}
}
+ });
+ return Status::OK();
+}
- // Splice source buffer values list right after 'prev_dest'.
- SpliceAfter(first_src, prev_dest);
- } else {
- VLOG(2)
- << copy->name()
- << " copies value in middle of source buffer to value in middle "
- "of destination buffer";
- return false;
- }
-
- RemoveCopyValue(dest);
-
- XLA_VLOG_LINES(4, ToString());
- TF_DCHECK_OK(Verify());
-
- return true;
- }
-
- // Delete the given ValueNode associated with a elided kCopy
- // instruction. This should be called after splicing the value lists of the
- // source and destination buffers together.
- void RemoveCopyValue(ValueNode* copy_value_node) {
- CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
- HloOpcode::kCopy);
- ValueNode* operand_node = copy_value_node->prev;
- CHECK(operand_node != copy_value_node);
-
- VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
- << " => " << copy_value_node->value->ToShortString();
-
- // Splice out the copy value node.
- operand_node->next = copy_value_node->next;
- copy_value_node->next->prev = operand_node;
-
- // Patch up uses. Remove use of copy from operand_node uses.
- auto it =
- std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
- [copy_value_node](const HloUse* use) {
- return use->instruction ==
- copy_value_node->value->defining_instruction();
- });
- CHECK(it != operand_node->uses.end());
- operand_node->uses.erase(it);
-
- // If the elided copy has any uses which are themselves kCopy instructions
- // then patch up the copy info to reflect the that this kCopy instruction
- // has a different operand (the operand of the elided copy).
- for (const HloUse* copy_use : copy_value_node->uses) {
- operand_node->uses.push_back(copy_use);
- if (copy_use->instruction->opcode() == HloOpcode::kCopy) {
- copy_map_.at(copy_use->instruction).src = operand_node;
- }
- }
-
- // Delete the copy info and the value node.
- copy_map_.erase(copy_value_node->value->defining_instruction());
- delete copy_value_node;
- }
-
- // Returns true if the live range of given value 'a' is before the live
- // range of 'b'.
- //
- // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
- // updated as copies are removed.
- bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
- if (a.uses.empty()) {
- VLOG(2) << "Empty uses";
- return ordering_.IsDefinedBefore(*a.value, *b.value);
- }
- for (const HloUse* use : a.uses) {
- VLOG(2) << "use: " << *use;
- VLOG(2) << "is before:" << *b.value;
- if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
- VLOG(2) << "Not before";
- return false;
+// This is called when 'instruction_' is a while body root, and 'parameter' is
+// the while body parameter. We record all users of all aliases of 'parameter'
+// as control predecessors, so that when we add a copy of 'instruction_', we can
+// mark the control dependencies. This is necessary because points-to and
+// liveness analysis doesn't know about the aliasing between the while body root
+// and param. Without these control dependencies, the copy might get scheduled
+// to run at a point that interferes with users of the buffer.
+Status InstructionCopier::RecordControlPredecessors(
+ const TuplePointsToAnalysis& points_to_analysis,
+ HloInstruction* parameter) {
+ return indices_to_copy_.ForEachElementWithStatus(
+ [this, &points_to_analysis, parameter](const ShapeIndex& index,
+ bool will_copy) {
+ if (will_copy) {
+ TF_ASSIGN_OR_RETURN(
+ const LogicalBuffer* buffer,
+ points_to_analysis.GetBufferDefinedAt(parameter, index));
+ for (const BufferAlias& alias :
+ points_to_analysis.GetBufferAliases(*buffer)) {
+ for (HloInstruction* user : alias.instruction()->users()) {
+ if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+ user, points_to_analysis)) {
+ continue;
+ }
+
+ if (user != instruction_) {
+ control_predecessors_.mutable_element(index)->push_back(user);
+ }
+ }
+ }
}
- }
- return true;
- }
-
- // Returns whether 'node' is the last node in its list.
- bool IsTail(const ValueNode& node) const {
- return ContainsKey(value_lists_, node.next);
- }
-
- // Returns whether 'node' is the first node in its list.
- bool IsHead(const ValueNode& node) const {
- return ContainsKey(value_lists_, &node);
- }
-
- // Returns the next node in the list after 'node'. If 'node' is the
- // tail, then nullptr is returned.
- ValueNode* Next(const ValueNode& node) const {
- if (IsTail(node)) {
- return nullptr;
- } else {
- return node.next;
- }
- }
-
- // Returns the previous node in the list before 'node'. If 'node'
- // is the head, then nullptr is returned.
- ValueNode* Prev(const ValueNode& node) const {
- if (IsHead(node)) {
- return nullptr;
- } else {
- return node.prev;
- }
- }
-
- // Splices the entire linked list with 'head' as its head right after the
- // node 'insert_after' in another linked list.
- void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
- DCHECK(IsHead(*head));
- value_lists_.erase(head);
-
- ValueNode* tail = head->prev;
- tail->next = insert_after->next;
- insert_after->next->prev = tail;
-
- insert_after->next = head;
- head->prev = insert_after;
- }
-
- string ToString() const {
- string out = StrCat("BufferValueTracker:\n");
- StrAppend(&out, " Def-use chains in each buffer:\n");
- for (const ValueNode* head : value_lists_) {
- StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
- ":\n");
- const ValueNode* p = head;
- do {
- StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
- Join(p->uses, "; ",
- [](string* s, const HloUse* use) {
- StrAppend(s, use->ToString());
- }),
- "\n");
-
- p = p->next;
- } while (p != head);
- }
- StrAppend(&out, " Potentially removable copies:\n");
- for (const auto& pair : copy_map_) {
- const HloInstruction* copy = pair.first;
- const CopyNodes& copy_info = pair.second;
+ return Status::OK();
+ });
+}
- StrAppend(&out, " ", copy->name(), " : ",
- copy_info.src->value->ToShortString(), " => ",
- copy_info.dest->value->ToShortString(), "\n");
- }
- return out;
+// Recursively inserts copies of 'instruction' tuple element buffers at
+// indices in 'indices_to_copy_', expanding tuples as needed.
+HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction,
+ ShapeIndex* index) {
+ const int64 num_tuple_elements =
+ ShapeUtil::TupleElementCount(instruction->shape());
+ std::vector<HloInstruction*> elem_copies(num_tuple_elements);
+ for (int64 i = 0; i < num_tuple_elements; ++i) {
+ HloInstruction* elem;
+ if (instruction->opcode() == HloOpcode::kTuple) {
+ // If the instruction is already a Tuple instruction, we know that the
+ // element buffers are aliased, so we can just grab the operand directly.
+ elem = instruction->mutable_operand(i);
+ } else {
+ // Otherwise we need to add a GTE to unpack the element out of the tuple.
+ elem = instruction->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+ i));
}
-
- private:
- const HloDataflowAnalysis& dataflow_;
- const HloOrdering& ordering_;
-
- // The heads of all the value lists. Each value list represents the HLO
- // values contained in a particular HLO buffer. The values in the list are
- // in dependency order.
- tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
-
- // Copy removal requires fast access to the value list elements
- // corresponding to the source and destination values of the kCopy
- // instruction. This data structure holds pointers to these elements for
- // each kCopy instruction in the graph.
- struct CopyNodes {
- // The source and destinations values of the kCopy instruction.
- ValueNode* src = nullptr;
- ValueNode* dest = nullptr;
- };
- tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
- };
-
- HloModule* module_;
- const HloAliasAnalysis& alias_analysis_;
- const HloOrdering& ordering_;
-
- // Object tracking the HLO values contained in each HLO buffer.
- BufferValueTracker buffer_value_tracker_;
-};
-
-// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<HloInstruction::Id>& copies_to_exclude,
- HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- CopyRemover copy_remover(*alias_analysis, ordering, module);
- XLA_VLOG_LINES(3, copy_remover.ToString());
-
- tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id())) {
- TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
+ index->push_back(i);
+ if (ShapeUtil::IsTuple(elem->shape())) {
+ elem_copies[i] = CopyTuple(elem, index);
+ } else if (!indices_to_copy_.element(*index)) {
+ elem_copies[i] = elem;
+ } else if (HloInstruction* copy_override = GetCopyOverride(*index)) {
+ elem_copies[i] = copy_override;
+ } else {
+ HloInstruction* elem_copy = elem->parent()->AddInstruction(
+ HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem));
+ for (HloInstruction* control_predecessor :
+ control_predecessors_.element(*index)) {
+ VLOG(2) << "Adding control dependency from "
+ << control_predecessor->ToString() << " to "
+ << elem_copy->ToString();
+ TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy));
}
+ elem_copies[i] = elem_copy;
}
+ index->pop_back();
}
-
- return Status::OK();
+ return instruction->parent()->AddInstruction(
+ HloInstruction::CreateTuple(elem_copies));
}
-// Add copies to address special constraints on the roots of computations not
-// related to live range interference:
-//
-// (1) Entry computation root must be unambiguous and distinct.
-//
-// (2) Any computation called by a kCall instruction must have an
-// unambiguous root.
-//
-// (3) Constants and parameters cannot be live out of the entry computation
-//
-Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
-
- // Identify which shape indices of which instructions need to be copied. Store
- // these results in 'instructions_to_copy'.
- std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy;
- auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
- const ShapeIndex& index) {
- auto it = instructions_to_copy.find(instruction);
- if (it == instructions_to_copy.end()) {
- auto it_added = instructions_to_copy.emplace(
- std::piecewise_construct, std::forward_as_tuple(instruction),
- std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
- it = it_added.first;
- }
- *it->second.mutable_element(index) = true;
- };
-
- // Iterate through values of all constants and entry parameters. These values
- // are special because they are held in read-only buffers. If any of these
- // values share a buffer with other values (for example, the init value of a
- // while is a constant) then copy the value at its definition and replace all
- // its uses with the copy.
- for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
- if (ValueIsReadOnly(*value) &&
- alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
- VLOG(2) << "Value " << value->ToShortString()
- << " is read only, but its buffer contains more than one value. "
- "Copying.";
- add_index_to_copy(value->defining_instruction(), value->defining_index());
- }
+// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'.
+HloInstruction* InstructionCopier::Copy() {
+ ShapeIndex index;
+ HloInstruction* copy;
+ if (ShapeUtil::IsTuple(instruction_->shape())) {
+ copy = CopyTuple(instruction_, &index);
+ } else {
+ copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary(
+ instruction_->shape(), HloOpcode::kCopy, instruction_));
}
-
- // Identify copies which must be added at root instructions
- for (HloComputation* computation : module->computations()) {
- const CallGraphNode& node = call_graph.GetNode(computation);
- if (node.context() == CallContext::kParallel) {
- continue;
- }
- TF_RET_CHECK(node.context() == CallContext::kSequential);
-
- const bool is_entry = computation == module->entry_computation();
- HloInstruction* root = computation->root_instruction();
-
- // Mark nondistinct/ambiguous indices.
- tensorflow::gtl::FlatSet<const HloBuffer*> seen;
- ShapeUtil::ForEachSubshape(
- root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
- std::vector<const HloBuffer*> buffers_at_index =
- alias_analysis->ComputeBuffersAt(root, index);
- bool buffer_seen_before = false;
- for (const HloBuffer* buffer : buffers_at_index) {
- buffer_seen_before |= !seen.insert(buffer).second;
- }
- if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) {
- VLOG(2) << "Index " << index << " of root of computation "
- << computation->name() << " (" << root->name()
- << ") has ambiguous or non-distinct buffer. Copying.";
- add_index_to_copy(root, index);
- }
- });
-
- // For entry instructions, mark any parameter or constant values.
- if (is_entry) {
- for (const auto& pair :
- alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
- const ShapeIndex& index = pair.first;
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (ValueIsReadOnly(*value)) {
- VLOG(2) << "Root of entry computation (" << root->name()
- << ") has constant or entry parameter value at index "
- << index << ". Copying.";
- add_index_to_copy(root, index);
- }
- }
- }
- }
+ for (HloInstruction* user : copy_users_) {
+ VLOG(2) << "Adding copy between instruction: " << instruction_->name()
+ << " and user: " << user->name();
+ TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy));
}
+ return copy;
+}
- // TODO(b/62548313): Buffer assignment uses TuplePointsToAnalysis which is
- // computation-scoped. This means the analysis doesn't have visibility to
- // constants and entry parameters that cross computation boundaries. This can
- // cause invalid buffer assignments so additional conservative copies are
- // added to handle these cases. Remove this whole loop when buffer assignment
- // uses alias analysis.
- for (HloComputation* computation : module->computations()) {
- const CallGraphNode& node = call_graph.GetNode(computation);
-
- bool is_while_body = false;
- if (node.context() == CallContext::kSequential &&
- !node.caller_callsites().empty()) {
- CHECK_EQ(node.caller_callsites().size(), 1);
- const HloInstruction* calling_instruction =
- node.caller_callsites()[0].instruction();
- is_while_body = calling_instruction->opcode() == HloOpcode::kWhile &&
- calling_instruction->while_body() == node.computation();
+// The 'read_only_indices' are initialized based on points-to analysis on the
+// while body corresponding to 'while_hlo'. If the init buffer corresponding to
+// a read-only index aliases with a constant, it cannot be considered read-only,
+// and must be copied. This is necessary because BufferAssignment does not
+// currently assign an allocation for constants (b/32248867).
+// This function performs this fix-up of 'read_only_indices'.
+//
+// Returns a ShapeTree of copy_overrides, which implements an optimization to
+// allow multiple while loops that share the same read-only constants to
+// share a single copy.
+StatusOr<ShapeTree<HloInstruction*>> RevertReadOnlyIndicesForConstants(
+ const HloInstruction* while_hlo,
+ const TuplePointsToAnalysis& points_to_analysis,
+ ShapeTree<bool>* read_only_indices,
+ FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
+ const HloInstruction* init_hlo = while_hlo->operand(0);
+ const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
+
+ // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
+ FlatSet<const LogicalBuffer*> buffer_set;
+
+ ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
+ points_to.ForEachElement([init_hlo, read_only_indices, shared_copies,
+ &buffer_set, &copy_overrides](
+ const ShapeIndex& index,
+ const PointsToSet::BufferList& buffers) {
+ // Look for read-only entry parameters.
+ if (!read_only_indices->element(index)) {
+ return;
}
- VLOG(2) << computation->name() << " is_while_body: " << is_while_body;
- HloInstruction* root = computation->root_instruction();
+ for (const LogicalBuffer* buffer : buffers) {
+ HloInstruction* pointee = buffer->instruction();
+ const bool is_constant = pointee->opcode() == HloOpcode::kConstant;
+ if (!is_constant) {
+ continue;
+ }
- for (const auto& pair :
- alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
- const ShapeIndex& index = pair.first;
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (IsConstantValue(*value) && !is_while_body) {
- VLOG(2) << "Root of computation (" << root->name()
- << ") is constant at index " << index << ". Copying.";
- add_index_to_copy(root, index);
+ // We have found an constant that is read-only in
+ // the while body. These buffers are managed by the caller, and cannot
+ // be aliased with HLO buffers. Revert this read-only index,
+ // to allow it to be copied.
+ *read_only_indices->mutable_element(index) = false;
+
+ // Optimization to allow multiple while loops that share the same
+ // read-only entry constants to share a single copy.
+ // Only unambiguous and distinct array-shaped buffers are allowed, to
+ // reduce code complexity. The shape of the entry parameter must be
+ // identical to the shape of the init_hlo at this index, to ensure
+ // there were no intervening bitcast or GTE instructions, which are
+ // also hard to handle.
+ const Shape& pointee_shape = pointee->shape();
+ const Shape& init_shape =
+ ShapeUtil::GetSubshape(init_hlo->shape(), index);
+ if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
+ ShapeUtil::Equal(pointee_shape, init_shape) &&
+ buffer_set.count(buffer) < 1) {
+ HloInstruction** copy = &(*shared_copies)[pointee];
+ if (*copy == nullptr) {
+ *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary(
+ pointee_shape, HloOpcode::kCopy, pointee));
}
+ // Add the copy as an override.
+ *copy_overrides.mutable_element(index) = *copy;
}
- }
- }
- // Add copy instructions indicated in 'instructions_to_copy' to the module.
- for (const auto& pair : instructions_to_copy) {
- HloInstruction* instruction = pair.first;
- const ShapeTree<bool>& indices_to_copy = pair.second;
+ // Tracks whether this current buffer is distinct.
+ buffer_set.insert(buffer);
- std::vector<HloInstruction*> users = instruction->users();
- TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
- instruction->parent()->DeepCopyInstruction(
- instruction, &indices_to_copy));
- for (HloInstruction* user : users) {
- TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
+ // We've already reverted the read-only index and handled the
+ // single-copy optimization above, so there's nothing more to do.
+ break;
}
- if (instruction == instruction->parent()->root_instruction()) {
- instruction->parent()->set_root_instruction(deep_copy);
- }
- }
-
- return Status::OK();
-}
-
-Status VerifyNoLiveRangeInterference(HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- DependencyHloOrdering ordering(module);
- TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
- return Status::OK();
+ });
+ return copy_overrides;
}
-void MaybeDumpModule(const string& message, const HloModule& module) {
- if (VLOG_IS_ON(3)) {
- VLOG(3) << message;
- XLA_VLOG_LINES(3, module.ToString());
- hlo_graph_dumper::MaybeDumpHloModule(module, message);
+} // anonymous namespace
+
+// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the
+// base class, since the regular CopyInsertion logic above selectively copies
+// tuple elements, while this method assumes all buffers need to be deep copied.
+StatusOr<HloInstruction*> CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) {
+ auto copy_it = inserted_copies_.find(hlo);
+ if (copy_it == inserted_copies_.end()) {
+ HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie();
+ inserted_copies_.insert({hlo, copy});
+ return copy;
+ } else {
+ return copy_it->second;
}
}
-} // namespace
-
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
- // Copy insertion is performed in three steps:
- //
- // (1) Add copies conservatively to guarantee that there is no live-range
- // interference. This is done simplistically and usually results in more
- // copies than is strictly necessary.
- //
- // (2) Using a more fine-grained analysis, remove as many copies that were
- // added in (1) as possible while ensuring no live-range interference.
- //
- // (3) Add copies to resolve issues not related to live range interference
- // such as parameters and constants live out of the entry computation.
- //
- // We add copies then remove them (step (1) then (2)) rather than simply
- // adding only the copies that are necessary because, in general, it is
- // difficult to figure out the minimal set of copies to add once there is
- // interference. On the other hand, it is easy to determine if removing a copy
- // will introduce interference.
- //
- // The final copy insertion in (3) is done separately to simplify the
- // implementation of copy removal in (2) which is the most complicated part of
- // the pass. As is, copy removal only has to reason about live range
- // interference. If all copies were added in step (1) then copy removal would
- // also have to reason about things like constants and parameters live out of
- // the computation.
- MaybeDumpModule("before copy insertion", *module);
-
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
- if (!call_graph->IsFlattened()) {
- return FailedPrecondition(
- "Call graph must be flattened before copy insertion.");
- }
+ bool changed = false;
+ VLOG(2) << "CopyInsertion for module " << module->name();
- // Gather Ids of existing kCopy instructions in the module. We avoid removing
- // these copies (except via DCE in TupleSimplifier) because they may have been
- // added for reasons not considered by copy insertion (eg, layout assignment).
- // Instruction id is used instead of HloInstruction* because the pointer
- // values may be recycled.
- tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies;
- for (HloComputation* computation : module->computations()) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferLiveness> liveness,
+ BufferLiveness::Run(module, MakeUnique<DependencyHloOrdering>(module)));
+ const auto& points_to_analysis = liveness->points_to_analysis();
+ XLA_VLOG_LINES(2, points_to_analysis.ToString());
+ XLA_VLOG_LINES(2, module->ToString());
+
+ // Gather all while body computations and while instructions.
+ FlatSet<const HloComputation*> while_body_computations;
+ std::vector<HloInstruction*> while_instructions;
+ for (auto* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- existing_copies.insert(instruction->unique_id());
+ if (instruction->opcode() == HloOpcode::kWhile) {
+ while_body_computations.insert(instruction->while_body());
+ while_instructions.push_back(instruction);
}
}
}
- TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
-
- // Simplify the tuple structures introduced by the deep copies. This should be
- // done before removing copies (RemoveUnnecessaryCopies) because tuple
- // simplification changes dependencies in the graph which changes live range
- // interference in the graph. Also run DCE to remove the dead Tuple/GTE
- // instructions introduced by tuple simplification.
- TupleSimplifier tuple_simplifier;
- HloDCE dce;
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- TF_RETURN_IF_ERROR(dce.Run(module).status());
-
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
-
- MaybeDumpModule("after adding copies to resolve interference", *module);
-
- DependencyHloOrdering ordering(module);
- TF_RETURN_IF_ERROR(
- RemoveUnnecessaryCopies(ordering, existing_copies, module));
-
- MaybeDumpModule("after removing unnecessary copies", *module);
-
- TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
-
- MaybeDumpModule("after adding special-case copies", *module);
+ // Collect instruction buffer indices to copy in 'instructions_to_copy'.
+ std::vector<InstructionCopier> instructions_to_copy;
+
+ // Add copies of computation root instructions, if needed.
+ FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices;
+ for (auto* computation : module->MakeNonfusionComputations()) {
+ VLOG(2) << "computation " << computation->name();
+ InstructionCopier root_copier(computation->root_instruction(),
+ /*copy_users=*/{});
+ if (while_body_computations.count(computation) > 0) {
+ // Record root indices to copy for while body sub-computations. We do not
+ // need to call RecordIndicesWhichPointToParamOrConstant for the while
+ // body root instruction here, because any necessary copies needed to
+ // avoid constants or parameters in the output are handled by while.init
+ // operand copy insertion below (which will share an allocation).
+ HloInstruction* while_body_param = computation->parameter_instruction(0);
+ ShapeTree<bool> read_only_indices(while_body_param->shape());
+ TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
+ *liveness, while_body_param, &read_only_indices));
+ while_body_read_only_indices[computation] = read_only_indices;
+
+ // Mark control predecessors, based on the body param, for any copies
+ // we'll be inserting. This ensures the copy doesn't run too early.
+ TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors(
+ points_to_analysis, while_body_param));
+ } else {
+ // Record root indices to copy for general computations.
+ TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
+ points_to_analysis));
+ }
+ instructions_to_copy.push_back(root_copier);
+ }
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- TF_RETURN_IF_ERROR(dce.Run(module).status());
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
+ // Add copies of while 'init' operand instructions, if needed. 'shared_copies'
+ // is used to ensure that multiple while loops can share a single copy of the
+ // same entry parameter or constant, if all loops use it read-only.
+ //
+ // TODO(b/33301720) Remove redundant while instruction copies.
+ FlatMap<const HloInstruction*, HloInstruction*> shared_copies;
+ for (HloInstruction* while_hlo : while_instructions) {
+ // Fix read_only_indices to account for entry constants. Also
+ // initialize copy_overrides, which ensures a single copy for each read-only
+ // constant that is used in multiple while loops.
+ ShapeTree<bool>* read_only_indices =
+ &while_body_read_only_indices[while_hlo->while_body()];
+ TF_ASSIGN_OR_RETURN(
+ const ShapeTree<HloInstruction*> copy_overrides,
+ RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis,
+ read_only_indices, &shared_copies));
+ // Create InstructionCopier for init operand of while instruction.
+ HloInstruction* init_hlo = while_hlo->mutable_operand(0);
+ InstructionCopier init_copier(init_hlo, {while_hlo});
+ init_copier.SetReadOnlyIndices(*read_only_indices);
+ init_copier.SetCopyOverrides(copy_overrides);
+ // Record 'init' buffer indices which point-to a Constant or Parameter.
+ TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant(
+ points_to_analysis));
+ // Record indices necessary to colocate while and init operand buffers.
+ TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers(
+ *liveness, while_hlo, /*read_only_indices_out=*/nullptr));
+ instructions_to_copy.push_back(init_copier);
+ }
- MaybeDumpModule("after copy insertion", *module);
+ for (InstructionCopier& to_copy : instructions_to_copy) {
+ if (to_copy.HasAllIndicesFalse()) {
+ continue;
+ }
+ changed = true;
- if (VLOG_IS_ON(1)) {
- int64 num_total_copies = 0;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- num_total_copies++;
- }
- }
+ // Copy instruction at recorded buffer indices.
+ HloComputation* computation = to_copy.instruction()->parent();
+ HloInstruction* copy = to_copy.Copy();
+ if (to_copy.instruction() == computation->root_instruction()) {
+ computation->set_root_instruction(copy);
}
- VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
- VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}
- return true;
+ VLOG(3) << "After copy insertion for module " << module->name();
+ XLA_VLOG_LINES(3, module->ToString());
+
+ return changed;
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
index ea3c36b5c7..28bb62e40c 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -25,25 +25,12 @@ limitations under the License.
namespace xla {
-// Copy insertion is a legalization HLO pass which inserts copies (kCopy
-// instructions) to eliminate several kinds of problems in the HLO module.
-//
-// (1) Entry parameter or a constant live out of the entry computation. Entry
-// computation arguments and constants have different lifetimes than the
-// computation result and cannot share the same allocation. Parameters and
-// constants live out of non-entry computations do not need copies.
-//
-// (2) Different values which are simultaneously live and which must be held
-// in the same buffer. This can occur in while bodies. Specifically, the
-// while loop state (the arguments to the while instruction) is updated
-// in-place and the update may clobber the value from the previous
-// iteration before the previous value is dead. Computations called from
-// kCall instructions do not need such copies because kCall has no update
-// in-place semantics.
-//
-// (3) The buffer set of the root instruction of the entry computation must be
-// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and
-// InstructionAliasSet::IsDistinct return true.
+// HLO pass which inserts a copy of the root instruction (creating a new root)
+// if the root is or points-to any constant or parameter instruction.
+// If the root instruction is a Tuple, only tuple elements which point to
+// constant or parameter instructions will be copied.
+// Copy insertion is necessary because constant and parameter arrays have
+// different lifetimes than computation results.
class CopyInsertion : public HloPassInterface {
public:
tensorflow::StringPiece name() const override { return "copy-insertion"; }
@@ -51,6 +38,15 @@ class CopyInsertion : public HloPassInterface {
// Run the pass on the given module. Returns whether the module was changed
// (copies were inserted).
StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+ // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
+ // duplicate copies.
+ StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
+
+ // A map containing all copies inserted during the copy insertion pass. The
+ // key is the copied instruction and the value is the copy.
+ tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 8807c6480b..a2eacc5c7d 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -17,19 +17,18 @@ limitations under the License.
#include <set>
-#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/platform/test_benchmark.h"
namespace op = xla::testing::opcode_matchers;
@@ -38,53 +37,35 @@ namespace {
using ::testing::UnorderedElementsAre;
-int64 CountCopies(const HloComputation& computation) {
- int64 count = 0;
- for (const auto& instruction : computation.instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- count++;
- }
- }
- return count;
-}
-
-int64 CountCopies(const HloModule& module) {
- int64 count = 0;
- for (const auto& computation : module.computations()) {
- count += CountCopies(*computation);
- }
- return count;
-}
-
-int64 CountControlEdges(const HloComputation& computation) {
- int64 count = 0;
- for (const auto& instruction : computation.instructions()) {
- count += instruction->control_successors().size();
- }
- return count;
-}
-
-int64 CountControlEdges(const HloModule& module) {
- int64 count = 0;
- for (const auto& computation : module.computations()) {
- count += CountControlEdges(*computation);
- }
- return count;
-}
-
class CopyInsertionTest : public HloTestBase {
protected:
void InsertCopies(HloModule* module) {
CopyInsertion copy_insertion;
- ASSERT_IS_OK(copy_insertion.Run(module).status());
+ EXPECT_IS_OK(copy_insertion.Run(module).status());
+
+ // Verify the points to set of the root of the computation after copy
+ // insertion contains no constants or parameters, and is distinct and
+ // non-ambiguous.
+ auto points_to_analysis =
+ TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
+ const auto& points_to = points_to_analysis->GetPointsToSet(
+ module->entry_computation()->root_instruction());
+ EXPECT_TRUE(points_to.IsDistinct());
+ EXPECT_TRUE(!points_to.IsAmbiguous());
+
+ auto maybe_live_out_buffers =
+ points_to_analysis
+ ->GetPointsToSet(module->entry_computation()->root_instruction())
+ .CreateFlattenedSet();
+
+ for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
+ EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
+ EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
+ }
}
-
- const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
};
TEST_F(CopyInsertionTest, SingleParameter) {
- // Computation is a single parameter passed into a tuple. The parameter should
- // be copied before entering the tuple.
auto builder = HloComputation::Builder(TestName());
HloInstruction* x = builder.AddInstruction(
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
@@ -96,15 +77,14 @@ TEST_F(CopyInsertionTest, SingleParameter) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
EXPECT_THAT(module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(x)));
+ op::Tuple(op::Copy(old_root->operand(0))));
}
TEST_F(CopyInsertionTest, SingleConstant) {
- // Computation is a single constant passed into a tuple. The parameter should
- // be copied before entering the tuple.
auto builder = HloComputation::Builder(TestName());
HloInstruction* constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
@@ -116,42 +96,11 @@ TEST_F(CopyInsertionTest, SingleConstant) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 1);
EXPECT_THAT(module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(constant)));
-}
-
-TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) {
- // Verify that an kCopy instructions which exist in the pass before
- // copy-insertion remain in the graph after copy-insertion.
- auto module = CreateNewModule();
-
- auto builder = HloComputation::Builder(TestName());
- HloInstruction* constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
- HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kCopy, constant));
- HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
- constant->shape(), HloOpcode::kAdd, copy_1, copy_2));
- HloInstruction* add_copy = builder.AddInstruction(
- HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add));
-
- module->AddEntryComputation(builder.Build());
-
- EXPECT_EQ(CountCopies(*module), 3);
-
- InsertCopies(module.get());
-
- EXPECT_EQ(CountCopies(*module), 3);
-
- EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant()))));
+ op::Tuple(op::Copy(old_root->operand(0))));
}
TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
@@ -178,12 +127,12 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 2);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y)));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Tuple(op::Copy(old_root->operand(0)),
+ op::Copy(old_root->operand(1)), old_root->operand(2)));
}
TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
@@ -216,7 +165,6 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 2);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(op::GetTupleElement(old_root)),
@@ -239,7 +187,6 @@ TEST_F(CopyInsertionTest, BitcastParameter) {
HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 1);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Copy(old_root));
@@ -261,7 +208,6 @@ TEST_F(CopyInsertionTest, BitcastConstant) {
HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 1);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Copy(old_root));
@@ -281,11 +227,11 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast));
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 1);
EXPECT_THAT(module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(bitcast)));
+ op::Tuple(op::Copy(old_root->operand(0))));
}
TEST_F(CopyInsertionTest, NestedTupleParameter) {
@@ -311,8 +257,6 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) {
HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 3);
-
HloInstruction* new_root = module->entry_computation()->root_instruction();
EXPECT_NE(old_root, new_root);
@@ -349,13 +293,12 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
EXPECT_EQ(gte, module->entry_computation()->root_instruction());
+ HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 2);
- EXPECT_THAT(
- module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))),
- op::Copy(op::GetTupleElement(op::GetTupleElement(param)))));
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Tuple(op::Copy(op::GetTupleElement(old_root)),
+ op::Copy(op::GetTupleElement(old_root))));
}
TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
@@ -388,7 +331,6 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
HloInstruction* old_root = module->entry_computation()->root_instruction();
InsertCopies(module.get());
- EXPECT_EQ(CountCopies(*module), 1);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Copy(old_root));
@@ -404,10 +346,12 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// The parameter 'nested' specifies the loop state shape from which to
// read the induction variable.
std::unique_ptr<HloComputation> BuildConditionComputation(
- const Shape& loop_state_shape) {
+ bool nested = false) {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int32>(10)));
+ const Shape& loop_state_shape =
+ nested ? nested_loop_state_shape_ : loop_state_shape_;
auto loop_state = builder.AddInstruction(
HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
auto induction_variable =
@@ -638,7 +582,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto loop_state_init = builder.AddInstruction(
HloInstruction::CreateTuple({induction_var_init, inner_init}));
auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
- loop_state_init->shape(), condition, body, loop_state_init));
+ loop_state_shape_, condition, body, loop_state_init));
module_->AddEntryComputation(builder.Build());
return while_hlo;
}
@@ -714,28 +658,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
// Take a reference to 'data_init' to make it interfere with while result.
- auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data_init, one_vec));
- auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_,
- data_init, &builder);
-
- // Add an additional binary operation operating on the while and the
- // interfering add so that neither operation is dead.
- auto gte = xla_while->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1));
- auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary(
- data_shape_, HloOpcode::kSubtract, add, gte));
- auto gte0 = xla_while->parent()->AddInstruction(
- HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0));
- auto tuple = xla_while->parent()->AddInstruction(
- HloInstruction::CreateTuple({gte0, sub}));
-
- xla_while->parent()->set_root_instruction(tuple);
-
- return xla_while;
+ return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
+ &builder);
}
HloInstruction* BuildWhileInstructionWithCustomInit(
@@ -745,8 +672,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_);
auto induction_var_init = builder->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int32>(0)));
- auto condition = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape));
+ auto condition =
+ module_->AddEmbeddedComputation(BuildConditionComputation(nested));
auto body = module_->AddEmbeddedComputation(
BuildIndependentBodyComputation(nested));
auto loop_state_init = builder->AddInstruction(
@@ -779,21 +706,23 @@ class WhileCopyInsertionTest : public CopyInsertionTest {
// CopyInsertion pass should not generate any copies.
//
TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
- auto condition = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
+ auto condition = module_->AddEmbeddedComputation(BuildConditionComputation());
auto body =
module_->AddEmbeddedComputation(BuildIndependentBodyComputation());
auto while_hlo = BuildWhileInstruction(condition, body);
+ const HloInstruction* old_init = while_hlo->operand(0);
+ HloInstruction* old_root = body->root_instruction();
InsertCopies(module_.get());
+ HloInstruction* new_root = body->root_instruction();
+ const HloInstruction* new_init = while_hlo->operand(0);
- // Body should have no copies as the adds can be done inplace.
- EXPECT_EQ(CountCopies(*body), 0);
- EXPECT_EQ(CountControlEdges(*module_), 0);
+ // No copies should be inserted so root should not be updated.
+ EXPECT_EQ(old_root, new_root);
- // Both init indices need copies as they are constants.
- EXPECT_THAT(while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
+ // Both init indices need copies.
+ EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)),
+ op::Copy(old_init->operand(1))));
}
// Tests while body computation with dependent tuple elements:
@@ -808,33 +737,20 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
// Tuple(Copy(out0), out1)
//
TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
- auto condition = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
+ auto condition = module_->AddEmbeddedComputation(BuildConditionComputation());
auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation());
auto while_hlo = BuildWhileInstruction(condition, body);
+ const HloInstruction* old_init = while_hlo->operand(0);
+ HloInstruction* old_root = body->root_instruction();
InsertCopies(module_.get());
+ HloInstruction* new_root = body->root_instruction();
+ const HloInstruction* new_init = while_hlo->operand(0);
- EXPECT_EQ(CountCopies(*body), 1);
- EXPECT_EQ(CountControlEdges(*body), 0);
-
- EXPECT_THAT(
- body->root_instruction(),
- op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast())));
-
- auto add = body->root_instruction()->operand(0);
- auto bcast = body->root_instruction()->operand(1)->operand(1);
- ASSERT_EQ(add->opcode(), HloOpcode::kAdd);
- ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast);
-
- EXPECT_THAT(
- while_hlo->while_body()->root_instruction(),
- op::Tuple(op::Add(op::Copy(), op::Constant()),
- op::Add(op::GetTupleElement(), op::Broadcast(op::Copy()))));
-
- // Both init indices need copies as they are constants.
- EXPECT_THAT(while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
+ EXPECT_THAT(new_root,
+ op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1)));
+ EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)),
+ op::Copy(old_init->operand(1))));
}
// Tests while body computation with read-only tuple element 0:
@@ -852,26 +768,33 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
//
// CopyInsertion pass should not generate any copies for the while body.
TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
- auto condition = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
+ auto condition = module_->AddEmbeddedComputation(BuildConditionComputation());
auto body = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
- BuildWhileInstruction(condition, body);
+ auto while_hlo = BuildWhileInstruction(condition, body);
+ const HloInstruction* old_init = while_hlo->operand(0);
+ HloInstruction* old_root = body->root_instruction();
InsertCopies(module_.get());
+ HloInstruction* new_root = body->root_instruction();
+ const HloInstruction* new_init = while_hlo->operand(0);
+
+ // No copies should be inserted in the body, so root should not be updated.
+ EXPECT_EQ(old_root, new_root);
- // No copies or control edges should be inserted. The body is legal as is.
- EXPECT_EQ(CountCopies(*body), 0);
- EXPECT_EQ(CountControlEdges(*body), 0);
+ // Both indices need copies, even though Index 0 is read-only, since both are
+ // constants, which must be copied.
+ EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)),
+ op::Copy(old_init->operand(1))));
}
// Same as above, but with two while loops, sharing entry parameters.
TEST_F(WhileCopyInsertionTest,
DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) {
- auto condition1 = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
- auto condition2 = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
+ auto condition1 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto condition2 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
auto body1 = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
auto body2 = module_->AddEmbeddedComputation(
@@ -889,46 +812,30 @@ TEST_F(WhileCopyInsertionTest,
loop_state_shape_, condition1, body1, loop_init));
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition2, body2, loop_init));
-
- // Add a couple elements from each of the while so both whiles are live.
- auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
- auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
- builder.AddInstruction(
- HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
-
- auto entry = module_->AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
InsertCopies(module_.get());
- // Neither body should have any copies or control edges in them.
- EXPECT_EQ(CountCopies(*body1), 0);
- EXPECT_EQ(CountCopies(*body2), 0);
- EXPECT_EQ(CountControlEdges(*body1), 0);
- EXPECT_EQ(CountControlEdges(*body2), 0);
+ // Both while loops alias iter_param, since index 0 is read-only in the body.
+ EXPECT_EQ(while_hlo1->operand(0)->operand(0),
+ while_hlo2->operand(0)->operand(0));
+ EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_param);
- // Only two copies should be necessary. Each of the whiles should have
- // a copy of tuple element 1 (init value is a parameter, and the element is
- // not non-read-only) so each of the while bodies gets its own buffer to write
- // element 1 into.
- EXPECT_EQ(CountCopies(*entry), 2);
-
- EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
- EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy);
-
- // The two copies of element 1 should be different.
+ // Each while loop gets its own copy of data_param, since index 1 is not
+ // read-only in the body.
EXPECT_NE(while_hlo1->operand(0)->operand(1),
while_hlo2->operand(0)->operand(1));
+ EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param));
+ EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param));
}
// Same as above, but with two while loops, sharing non-parameters.
TEST_F(WhileCopyInsertionTest,
DependentTupleElements_OneReadOnly_TwoLoops_NonParams) {
- auto condition1 = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
- auto condition2 = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape_));
+ auto condition1 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto condition2 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
auto body1 = module_->AddEmbeddedComputation(
BuildDependentBodyOneReadOnlyComputation());
auto body2 = module_->AddEmbeddedComputation(
@@ -951,28 +858,21 @@ TEST_F(WhileCopyInsertionTest,
loop_state_shape_, condition1, body1, loop_init));
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape_, condition2, body2, loop_init));
-
- // Add a couple elements from each of the while so both whiles are not dead.
- auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
- auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0));
- builder.AddInstruction(
- HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
- auto entry = module_->AddEntryComputation(builder.Build());
+ module_->AddEntryComputation(builder.Build());
InsertCopies(module_.get());
- // Ideally only one copy should be necessary. One of the whiles should
- // have a copy of tuple element 1 (the non-read-only element) so each of the
- // while bodies gets its own buffer to write element 1 into. However, the
- // analysis isn't perfect and adds an additional copy of element 0.
- EXPECT_EQ(CountCopies(*entry), 2);
+ // No copies of iter_value are necessary, since index 0 is read-only in both
+ // while bodies.
+ EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value);
+ EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value);
- EXPECT_THAT(while_hlo1->operand(0),
- op::Tuple(op::Exp(), op::Copy(op::Exp())));
- EXPECT_THAT(while_hlo2->operand(0),
- op::Tuple(op::Exp(), op::Copy(op::Exp())));
+ // Each while loop gets its own copy of data_value, since index 1 is not
+ // read-only in the body.
+ EXPECT_NE(while_hlo1->operand(0)->operand(1),
+ while_hlo2->operand(0)->operand(1));
+ EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value));
+ EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value));
}
// Tests while body computation with nested tuple elements:
@@ -1005,34 +905,18 @@ TEST_F(WhileCopyInsertionTest,
// Tuple // new root
//
TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
- auto condition = module_->AddEmbeddedComputation(
- BuildConditionComputation(nested_loop_state_shape_));
+ auto condition =
+ module_->AddEmbeddedComputation(BuildConditionComputation(true));
auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation());
BuildWhileInstruction(condition, body, true);
- // HloInstruction* old_root = body->root_instruction();
+ HloInstruction* old_root = body->root_instruction();
InsertCopies(module_.get());
- // The only copy necessary is for the kReverse as it cannot be done
- // in-place (instruction can share buffer with operand). The other elements of
- // the loop state are kAdd instructions which can be done in-place.
- EXPECT_EQ(CountCopies(*body), 1);
-
- // Each element of the init needs a copy as all are constants.
- EXPECT_EQ(CountCopies(*module_), 4);
-
- // Either the kReverse itself must be copied or the operand of the kReverse
- // must be copied.
- if (body->root_instruction()->operand(1)->operand(1)->opcode() ==
- HloOpcode::kCopy) {
- EXPECT_THAT(
- body->root_instruction(),
- op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse()))));
- } else {
- EXPECT_THAT(
- body->root_instruction(),
- op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy()))));
- }
+ EXPECT_THAT(body->root_instruction(),
+ op::Tuple(old_root->operand(0),
+ op::Tuple(old_root->operand(1)->operand(0),
+ op::Copy(old_root->operand(1)->operand(1)))));
}
// Tests while init instruction which points-to a constant.
@@ -1043,13 +927,11 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
//
TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
auto while_hlo = BuildWhileInstruction_InitPointsToConstant();
-
+ auto old_init = while_hlo->operand(0);
InsertCopies(module_.get());
- EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
- EXPECT_EQ(CountCopies(*module_), 2);
- EXPECT_THAT(while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant())));
+ EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
+ op::Copy(old_init->operand(1))));
}
// Tests while init instruction which points-to a parameter.
@@ -1060,13 +942,11 @@ TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
//
TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
auto while_hlo = BuildWhileInstruction_InitPointsToParameter();
-
+ auto old_init = while_hlo->operand(0);
InsertCopies(module_.get());
- EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
- EXPECT_EQ(CountCopies(*module_), 2);
- EXPECT_THAT(while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter())));
+ EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
+ op::Copy(old_init->operand(1))));
}
// Tests while init instruction which has an ambiguous points-to set.
@@ -1095,34 +975,15 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
//
TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous();
-
+ auto old_init = while_hlo->operand(0);
InsertCopies(module_.get());
- EXPECT_EQ(CountCopies(*module_), 4);
- // The entry computation requires three copies to resolve the ambiguity of two
- // init elements and the constant passed in as one of the init elements.
- EXPECT_EQ(CountCopies(*module_->entry_computation()), 3);
- EXPECT_THAT(while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()),
- op::Tuple(op::Copy(op::GetTupleElement()),
- op::Copy(op::GetTupleElement()))));
-
- // The body requires one copy because the buffer set is not distinct: the
- // result of one of the adds is written into two elements of the output of the
- // loop body. Either element might be copied.
- EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
- if (while_hlo->while_body()
- ->root_instruction()
- ->operand(1)
- ->operand(0)
- ->opcode() == HloOpcode::kCopy) {
- EXPECT_THAT(
- while_hlo->while_body()->root_instruction(),
- op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
- } else {
- EXPECT_THAT(
- while_hlo->while_body()->root_instruction(),
- op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
- }
+
+ EXPECT_THAT(
+ while_hlo->operand(0),
+ op::Tuple(
+ op::Copy(old_init->operand(0)),
+ op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))),
+ op::Copy(op::GetTupleElement(old_init->operand(1))))));
}
// Tests while init instruction which has a non-distinct points-to set.
@@ -1150,43 +1011,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
//
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct();
-
+ auto old_init = while_hlo->operand(0);
InsertCopies(module_.get());
- // The entry computation requires two copies to resolve the non-disinctness of
- // two init elements and the constant passed in as one of the init
- // elements. Either element can be copied for the distinctness issue.
- EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
- if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() ==
- HloOpcode::kCopy) {
- EXPECT_THAT(
- while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()),
- op::Tuple(op::Copy(op::Broadcast()), op::Broadcast())));
- } else {
- EXPECT_THAT(
- while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()),
- op::Tuple(op::Broadcast(), op::Copy(op::Broadcast()))));
- }
-
- // The body requires one copy because the buffer set is not distinct: the
- // result of one of the adds is written into two elements of the output of the
- // loop body. Either element might be copied.
- EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1);
- if (while_hlo->while_body()
- ->root_instruction()
- ->operand(1)
- ->operand(0)
- ->opcode() == HloOpcode::kCopy) {
- EXPECT_THAT(
- while_hlo->while_body()->root_instruction(),
- op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add())));
- } else {
- EXPECT_THAT(
- while_hlo->while_body()->root_instruction(),
- op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add()))));
- }
+ EXPECT_THAT(while_hlo->operand(0),
+ op::Tuple(op::Copy(old_init->operand(0)),
+ op::Tuple(op::Copy(old_init->operand(1)->operand(0)),
+ op::Copy(old_init->operand(1)->operand(0)))));
}
// Tests while init instruction buffer which interferes with while result
@@ -1200,13 +1031,11 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
//
TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
-
+ auto old_init = while_hlo->operand(0);
InsertCopies(module_.get());
- EXPECT_EQ(CountCopies(*module_), 2);
- EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0);
- EXPECT_THAT(while_hlo->operand(0),
- op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast())));
+ EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)),
+ op::Copy(old_init->operand(1))));
}
// Tests while init instruction buffer which has a non-distinct points-to set:
@@ -1215,21 +1044,18 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
// Parameter(F32, {8})))
//
// where the second and third parameters are identical *and* the tuple shared
-// by another while instruction.
+// by another while instruction..
//
// Verifies that the resulting point-to set is distinct in the resulting Tuple
// (non-identical Copys). In other words, verifies that copy sharing does not
// insert identical copies to the resulting tuple.
TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
+ auto condition1 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
+ auto condition2 =
+ module_->AddEmbeddedComputation(BuildConditionComputation());
// Loop body that outputs tuple comprises two elements dependent on the init
// tuple.
- const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
- {induction_variable_shape_, data_shape_, data_shape_});
-
- auto condition1 = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape));
- auto condition2 = module_->AddEmbeddedComputation(
- BuildConditionComputation(loop_state_shape));
auto body1 =
module_->AddEmbeddedComputation(BuildDependentBodyComputation2());
auto body2 =
@@ -1246,6 +1072,8 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto loop_init = builder.AddInstruction(
HloInstruction::CreateTuple({iter_param, data_param, data_param}));
+ const Shape& loop_state_shape = ShapeUtil::MakeTupleShape(
+ {induction_variable_shape_, data_shape_, data_shape_});
// Two while loops shares the same loop init tuple.
auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile(
@@ -1253,479 +1081,43 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) {
auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile(
loop_state_shape, condition2, body2, loop_init));
- // Add add instruction so neither while is dead.
- auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0));
- auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0));
- builder.AddInstruction(
- HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2));
-
module_->AddEntryComputation(builder.Build());
- InsertCopies(module_.get());
-
- // None of the bodies should have copies or control flow edges.
- EXPECT_EQ(CountCopies(*body1), 0);
- EXPECT_EQ(CountCopies(*body2), 0);
+ auto points_to_analysis =
+ TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
- // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally
- // these should not need to be copied before either while. However, copy
- // insertion is not able to reason about the transparency of elements through
- // while bodies in all circumstances so extra copies are added (b/xxx).
- EXPECT_EQ(CountCopies(*module_->entry_computation()), 2);
+ // Asserts that the init tuples before copy insertion is non-distinct.
+ ASSERT_FALSE(
+ points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct());
+ ASSERT_FALSE(
+ points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct());
- EXPECT_THAT(while_hlo1->operand(0),
- op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
- EXPECT_THAT(while_hlo2->operand(0),
- op::Tuple(op::Copy(), op::Parameter(), op::Parameter()));
-}
+ auto old_init1 = while_hlo1->operand(0);
+ auto old_init2 = while_hlo2->operand(0);
-TEST_F(CopyInsertionTest, SwizzlingWhile) {
- // Test a while instruction with a body which permutes its tuple parameter
- // elements.
- auto module = CreateNewModule();
- const Shape loop_state_shape =
- ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
-
- // Body simply interchanges the two tuple elements in the loop state.
- auto body_builder = HloComputation::Builder("body");
- auto body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto body_element_0 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
- auto body_element_1 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
- body_builder.AddInstruction(
- HloInstruction::CreateTuple({body_element_1, body_element_0}));
- HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
-
- auto cond_builder = HloComputation::Builder("condition");
- cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- cond_builder.AddInstruction(HloInstruction::CreateUnary(
- cond_constant->shape(), HloOpcode::kNot, cond_constant));
- HloComputation* condition =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
- auto tuple = builder.AddInstruction(
- HloInstruction::CreateTuple({constant1, constant2}));
- auto xla_while = builder.AddInstruction(
- HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
- module->AddEntryComputation(builder.Build());
-
- InsertCopies(module.get());
-
- EXPECT_EQ(CountCopies(*module), 6);
-
- // The loop state elements should be copied at the parameter and at the root
- // with a control edge in between (see DeepCopyAndAddControlEdges). This is
- // technically one more copy than is strictly necessary, but in order to have
- // only three copies the copies of different loop state elements must be
- // ordered with a control edge.
- EXPECT_EQ(CountCopies(*body), 4);
- EXPECT_EQ(CountControlEdges(*body), 2);
-
- EXPECT_THAT(body->root_instruction(),
- op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy())));
-
- EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
- EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
-}
-
-TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
- // Test a while instruction with a body which permutes its tuple parameter
- // elements and applies one operation to one of the elements. The addition of
- // the operation (instruction) on the element makes the live range of the
- // respective input and output elements different than if the instruction were
- // not there (as in the SwizzlingWhile test above).
- auto module = CreateNewModule();
- const Shape loop_state_shape =
- ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
-
- // Body interchanges the two tuple elements in the loop state and negates one
- // of them.
- auto body_builder = HloComputation::Builder("body");
- auto body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto body_element_0 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
- auto body_element_1 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
- auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
- scalar_shape_, HloOpcode::kNegate, body_element_1));
- body_builder.AddInstruction(
- HloInstruction::CreateTuple({negate, body_element_0}));
- HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
-
- auto cond_builder = HloComputation::Builder("condition");
- cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- cond_builder.AddInstruction(HloInstruction::CreateUnary(
- cond_constant->shape(), HloOpcode::kNot, cond_constant));
- HloComputation* condition =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- auto constant1 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto constant2 = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
- auto tuple = builder.AddInstruction(
- HloInstruction::CreateTuple({constant1, constant2}));
- auto xla_while = builder.AddInstruction(
- HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
- module->AddEntryComputation(builder.Build());
-
- InsertCopies(module.get());
-
- EXPECT_EQ(CountCopies(*module), 6);
-
- // The loop state elements should be copied at the parameter and at the root
- // with a control edge in between (see DeepCopyAndAddControlEdges).
- EXPECT_EQ(CountCopies(*body), 4);
- EXPECT_EQ(CountControlEdges(*body), 2);
-
- EXPECT_THAT(
- body->root_instruction(),
- op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy())));
-
- EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
- EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
-}
-
-TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) {
- // Test a while instruction with a body which permutes it's tuple parameter
- // elements similar to SwizzlinWhile above. However, in this test the input to
- // the while body is a single constant (both loop state elements are the same
- // constant). This means no copies are necessary because both loop state
- // elements are the same so interchanging them is a no-op.
- auto module = CreateNewModule();
- const Shape loop_state_shape =
- ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
-
- // Body simply interchanges the two tuple elements in the loop state.
- auto body_builder = HloComputation::Builder("body");
- auto body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto body_element_0 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
- auto body_element_1 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
- body_builder.AddInstruction(
- HloInstruction::CreateTuple({body_element_1, body_element_0}));
- HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
-
- auto cond_builder = HloComputation::Builder("condition");
- cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- cond_builder.AddInstruction(HloInstruction::CreateUnary(
- cond_constant->shape(), HloOpcode::kNot, cond_constant));
- HloComputation* condition =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto tuple =
- builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
- builder.AddInstruction(
- HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple));
- module->AddEntryComputation(builder.Build());
-
- InsertCopies(module.get());
-
- EXPECT_EQ(CountCopies(*module), 2);
- EXPECT_EQ(CountCopies(*body), 0);
-
- EXPECT_EQ(CountCopies(*module->entry_computation()), 2);
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(), op::Copy()));
-}
-
-TEST_F(CopyInsertionTest, SequentialWhiles) {
- // Construct a computation with a series of sequential while instructions
- // containing four loop state elements:
- //
- // element 0 is passed to each while directly from an entry parameter.
- //
- // element 1 is passed transparently in series through all the while bodies.
- //
- // element 2 is negated in each while body. (in-place possible)
- //
- // element 3 is reversed in each while body. (in-place not possible)
- //
- const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
- const Shape loop_state_shape = ShapeUtil::MakeTupleShape(
- {element_shape, element_shape, element_shape, element_shape});
-
- auto module = CreateNewModule();
- auto builder = HloComputation::Builder(TestName());
- auto param_0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, element_shape, "param_0"));
- auto param_1 = builder.AddInstruction(
- HloInstruction::CreateParameter(1, element_shape, "param_1"));
- auto param_2 = builder.AddInstruction(
- HloInstruction::CreateParameter(2, element_shape, "param_2"));
- auto param_3 = builder.AddInstruction(
- HloInstruction::CreateParameter(3, element_shape, "param_3"));
-
- // The number of sequential kWhile instructions.
- const int kNumWhiles = 3;
-
- HloInstruction* prev_element_1 = param_1;
- HloInstruction* prev_element_2 = param_2;
- HloInstruction* prev_element_3 = param_3;
-
- // Vector containing all of the while instructions.
- std::vector<const HloInstruction*> whiles;
- for (int i = 0; i < kNumWhiles; ++i) {
- auto body_builder = HloComputation::Builder("body");
- auto body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto body_element_0 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, body_param, 0));
- auto body_element_1 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, body_param, 1));
- auto body_element_2 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, body_param, 2));
- auto body_element_3 = body_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, body_param, 3));
- auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
- element_shape, HloOpcode::kNegate, body_element_2));
- auto reverse = body_builder.AddInstruction(
- HloInstruction::CreateReverse(element_shape, body_element_3, {0}));
- body_builder.AddInstruction(HloInstruction::CreateTuple(
- {body_element_0, body_element_1, negate, reverse}));
- HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
-
- auto cond_builder = HloComputation::Builder("condition");
- cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "param"));
- auto cond_constant = cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- cond_builder.AddInstruction(HloInstruction::CreateUnary(
- cond_constant->shape(), HloOpcode::kNot, cond_constant));
- HloComputation* condition =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto while_init = builder.AddInstruction(HloInstruction::CreateTuple(
- {param_0, prev_element_1, prev_element_2, prev_element_3}));
-
- auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile(
- loop_state_shape, condition, body, while_init));
- whiles.push_back(xla_while);
- if (i != kNumWhiles - 1) {
- prev_element_1 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1));
- prev_element_2 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2));
- prev_element_3 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3));
- }
- }
-
- module->AddEntryComputation(builder.Build());
-
- InsertCopies(module.get());
-
- // Each while body has one copy. And each loop state element is copied once in
- // the entry computation.
- EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles);
-
- // Each while body should have exactly one copy for element three which is an
- // op (kReverse) which cannot be done in place.
- for (const HloInstruction* xla_while : whiles) {
- EXPECT_EQ(CountCopies(*xla_while->while_body()), 1);
- }
-
- EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(),
- op::Copy(), op::Copy()));
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(),
- op::GetTupleElement()));
-}
-
-TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) {
- // Test a while body and condition which are each simply a constant (root of
- // computation is a constant). Each constant should be copied. The copy in the
- // condition is not strictly necessary, but added due to b/32248867.
- auto module = CreateNewModule();
- auto builder = HloComputation::Builder(TestName());
- auto param_0 = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param_0"));
-
- auto body_builder = HloComputation::Builder("body");
- body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
- body_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(123.0)));
- HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
-
- auto cond_builder = HloComputation::Builder("condition");
- cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
- cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- HloComputation* condition =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto xla_while = builder.AddInstruction(
- HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0));
-
- module->AddEntryComputation(builder.Build());
-
- InsertCopies(module.get());
-
- EXPECT_EQ(CountCopies(*module), 3);
-
- EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter()));
- EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant()));
- EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant()));
-}
-
-std::unique_ptr<HloComputation> MakeTrivialCondition(const Shape& shape) {
- auto builder = HloComputation::Builder("trivial_condition");
- builder.AddInstruction(
- HloInstruction::CreateParameter(0, shape, "loop_state"));
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- builder.AddInstruction(HloInstruction::CreateUnary(
- constant->shape(), HloOpcode::kNot, constant));
- return builder.Build();
-}
-
-std::unique_ptr<HloComputation> MakeBenchmarkWhileBody() {
- auto builder = HloComputation::Builder("benchmark_loop_body");
- const Shape element_shape = ShapeUtil::MakeShape(F32, {42});
- const Shape loop_state_shape =
- ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape});
- HloInstruction* param = builder.AddInstruction(
- HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
- HloInstruction* element_0 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, param, 0));
- HloInstruction* element_1 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, param, 1));
- HloInstruction* element_2 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(element_shape, param, 2));
-
- HloInstruction* rev_1 = builder.AddInstruction(
- HloInstruction::CreateReverse(element_shape, element_1, {0}));
- HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary(
- element_shape, HloOpcode::kAdd, element_1, element_2));
-
- builder.AddInstruction(
- HloInstruction::CreateTuple({element_0, rev_1, add_1_2}));
- return builder.Build();
-}
-
-void BM_SequentialWhiles(int num_iters, int num_whiles) {
- // This benchmark constructs a chain of sequential while instructions.
- tensorflow::testing::StopTiming();
- for (int i = 0; i < num_iters; ++i) {
- HloModuleConfig config;
- config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
- config);
-
- auto builder = HloComputation::Builder("BM_SequentialWhiles");
- HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {42}), "x"));
- HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {42}), "y"));
- HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
- 2, ShapeUtil::MakeShape(F32, {42}), "z"));
- HloInstruction* init =
- builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
-
- HloInstruction* prev_loop_state = init;
- for (int w = 0; w < num_whiles; ++w) {
- HloComputation* condition =
- module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
- HloComputation* body =
- module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
- prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile(
- init->shape(), condition, body, prev_loop_state));
- }
- module.AddEntryComputation(builder.Build());
-
- CopyInsertion copy_insertion;
-
- tensorflow::testing::StartTiming();
- ASSERT_IS_OK(copy_insertion.Run(&module).status());
- tensorflow::testing::StopTiming();
-
- // The entry computation should have three copies, and each body has one.
- ASSERT_EQ(CountCopies(module), 3 + num_whiles);
- }
-}
-
-void BM_ParallelWhiles(int num_iters, int num_whiles) {
- // This benchmark constructs a fan-out of parallel while instructions.
- tensorflow::testing::StopTiming();
- for (int i = 0; i < num_iters; ++i) {
- HloModuleConfig config;
- config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags());
- HloModule module("BM_SequentialWhiles", VersionedComputationHandle(),
- config);
-
- auto builder = HloComputation::Builder("BM_ParallelWhiles");
- HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, ShapeUtil::MakeShape(F32, {42}), "x"));
- HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
- 1, ShapeUtil::MakeShape(F32, {42}), "y"));
- HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
- 2, ShapeUtil::MakeShape(F32, {42}), "z"));
- HloInstruction* init =
- builder.AddInstruction(HloInstruction::CreateTuple({x, y, z}));
-
- HloInstruction* sum = nullptr;
- for (int w = 0; w < num_whiles; ++w) {
- HloComputation* condition =
- module.AddEmbeddedComputation(MakeTrivialCondition(init->shape()));
- HloComputation* body =
- module.AddEmbeddedComputation(MakeBenchmarkWhileBody());
-
- HloInstruction* xla_while = builder.AddInstruction(
- HloInstruction::CreateWhile(init->shape(), condition, body, init));
-
- if (sum == nullptr) {
- sum = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
- } else {
- HloInstruction* element_0 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0));
- sum = builder.AddInstruction(HloInstruction::CreateBinary(
- x->shape(), HloOpcode::kAdd, sum, element_0));
- }
- }
- module.AddEntryComputation(builder.Build());
-
- CopyInsertion copy_insertion;
+ InsertCopies(module_.get());
- tensorflow::testing::StartTiming();
- ASSERT_IS_OK(copy_insertion.Run(&module).status());
- tensorflow::testing::StopTiming();
+ EXPECT_THAT(while_hlo1->operand(0),
+ op::Tuple(op::Copy(old_init1->operand(0)),
+ op::Copy(old_init1->operand(1)),
+ op::Copy(old_init1->operand(2))));
- // Each body receives of copy of two of the parameters (the corresponding
- // elements in the body are modifed), and there is one copy in each body.
- ASSERT_EQ(CountCopies(module), 3 * num_whiles);
- }
+ EXPECT_THAT(while_hlo2->operand(0),
+ op::Tuple(op::Copy(old_init2->operand(0)),
+ op::Copy(old_init2->operand(1)),
+ op::Copy(old_init2->operand(2))));
+
+ // Verifies the init tuples after copy insertion is distinct.
+ points_to_analysis =
+ TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+ const auto& points_to1 =
+ points_to_analysis->GetPointsToSet(while_hlo1->operand(0));
+ EXPECT_TRUE(points_to1.IsDistinct());
+
+ const auto& points_to2 =
+ points_to_analysis->GetPointsToSet(while_hlo2->operand(0));
+ EXPECT_TRUE(points_to2.IsDistinct());
}
-BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
-BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096);
-
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index af2bd6d5d7..46e83282d5 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -243,81 +243,6 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
};
-
-// This copy insertion pass is a hack to address deficiencies in buffer
-// assignment. Buffer assignment uses TuplePointsToAnalysis which is
-// computation-scoped and thus has limited visibility across computation
-// boundaries. However, CopyInsertion uses module-scoped HloAliasAnalysis and
-// expects buffer assignment to have the same understanding of the graph. This
-// mismatch manifests in the parallel cpu backend, where the HLO outlining
-// results is a minefield of potential problems. This pass conservatively adds
-// copies to avoid any potential problems in buffer assignemnt.
-//
-// Technically these issues exist in all the backends. However, they only
-// manifest in the parallel cpu backend because of the outlining. Moving this
-// into the main copy insertion pass results in performance regressions n the
-// other backends.
-//
-// TODO(b/62548313): Remove this.
-class CpuParallelCopyInsertion : public HloPassInterface {
- public:
- tensorflow::StringPiece name() const override {
- return "cpu-parallel-copy-insertion";
- }
-
- StatusOr<bool> Run(HloModule* module) override {
- // Copy roots of all non-entry sequentially-called (eg, kCall, kWhile)
- // computations.
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
- TF_RETURN_IF_ERROR(
- call_graph->VisitNodes([module](const CallGraphNode& node) -> Status {
- if (node.context() == CallContext::kSequential &&
- !node.caller_callsites().empty()) {
- TF_ASSIGN_OR_RETURN(HloInstruction * root_copy,
- node.computation()->DeepCopyInstruction(
- node.computation()->root_instruction()));
- node.computation()->set_root_instruction(root_copy);
- }
- return Status::OK();
- }));
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
- HloDataflowAnalysis::Run(module));
-
- // Add copies to the operand of dynamic update slices which have read-only
- // values (constants and parameters). Buffer assignment which is based on
- // computation-scoped tuple points-to analysis does not properly track these
- // read-only values across kCall instructions. This can result in cases
- // where a outlined computation parameter operand of a dynamic update slice
- // aliases a constant or parameter in the entry computation and the dynamic
- // update slice is attempted in-place.
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
- HloInstruction* operand = instruction->mutable_operand(0);
- for (const HloValue* value :
- dataflow->GetValueSet(operand).values()) {
- if (value->defining_instruction()->opcode() ==
- HloOpcode::kConstant ||
- value->defining_instruction()->opcode() ==
- HloOpcode::kParameter) {
- HloInstruction* operand_copy =
- instruction->parent()->AddInstruction(
- HloInstruction::CreateUnary(operand->shape(),
- HloOpcode::kCopy, operand));
- TF_RETURN_IF_ERROR(
- operand->ReplaceUseWith(instruction, operand_copy));
- break;
- }
- }
- }
- }
- }
-
- return true;
- }
-};
-
} // namespace
Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
@@ -406,16 +331,15 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
// (and sometime after) copy insertion, to avoid dead code from interfering
// with the rewrites.
pipeline.AddPass<HloDCE>();
- pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CopyInsertion>();
if (options::CpuParallelBackendRequested(module->config())) {
// Re-run the outlining, in case any copies were inserted into the entry
// computation.
pipeline.AddPass<ParallelizationPreparation>(max_parallelism,
ShapeSizeBytesFunction());
- pipeline.AddPass<CpuParallelCopyInsertion>();
}
pipeline.AddPass<HloDCE>();
+ pipeline.AddPass<FlattenCallGraph>();
return pipeline.Run(module).status();
}
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index df7e128217..b9c4adce93 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -350,8 +350,8 @@ cc_library(
":ir_emission_utils",
"//tensorflow/compiler/xla/service:copy_insertion",
"//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
- "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
"//tensorflow/core:lib",
],
)
@@ -573,14 +573,11 @@ tf_cc_test(
deps = [
":instruction_fusion",
":while_transformer",
- "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:copy_insertion",
- "//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "//tensorflow/core:test",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc
index f7a3260641..3dc8555201 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc
@@ -22,53 +22,41 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace gpu {
-StatusOr<HloInstruction*> GpuCopyInsertion::FindOrInsertCopy(
- HloInstruction* hlo) {
- auto copy_it = inserted_copies_.find(hlo);
- if (copy_it == inserted_copies_.end()) {
- HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie();
- inserted_copies_.insert({hlo, copy});
- return copy;
- } else {
- return copy_it->second;
- }
-}
-
StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
- CopyInsertion generic_copy_insertion;
+ TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module));
- TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
- HloDataflowAnalysis::Run(module));
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
// Make sure all operands of a library call are in memory instead of constants
- // in IR.
+ // in IR. The top-level (index {}) of the points-to set of each operand
+ // indicates the source(s) of the array buffer. If any of these are constant,
+ // then add a copy to materialize the array.
HloComputation* computation = module->entry_computation();
for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
if (ImplementedAsLibraryCall(*hlo)) {
for (int64 i = 0; i < hlo->operand_count(); ++i) {
HloInstruction* operand = hlo->mutable_operand(i);
- TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
- bool copy_operand = false;
- for (const HloValue* value : dataflow->GetValueSet(operand).values()) {
- if (value->defining_instruction()->opcode() == HloOpcode::kConstant) {
- copy_operand = true;
- break;
- }
- }
- if (copy_operand) {
- TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand));
+ const PointsToSet& points_to =
+ points_to_analysis->GetPointsToSet(operand);
+ const auto& element = points_to.element(/*index=*/{});
+ if (std::any_of(element.begin(), element.end(),
+ [](const LogicalBuffer* buffer_source) {
+ return buffer_source->instruction()->opcode() ==
+ HloOpcode::kConstant;
+ })) {
+ TF_ASSIGN_OR_RETURN(HloInstruction * copy,
+ CopyInsertion::FindOrInsertCopy(operand));
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy));
changed = true;
}
@@ -76,31 +64,6 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
}
}
- // Init values of a while nodes cannot be constants. Insert copies for any
- // constants found at the operand of a while.
- tensorflow::gtl::FlatSet<HloInstruction*> copied_constants;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kWhile) {
- for (auto& pair :
- dataflow->GetInstructionValueSet(instruction->operand(0))) {
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction()->opcode() ==
- HloOpcode::kConstant &&
- !ContainsKey(copied_constants, value->defining_instruction())) {
- HloInstruction* constant = value->defining_instruction();
- TF_ASSIGN_OR_RETURN(HloInstruction * copy,
- FindOrInsertCopy(constant));
- TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy));
- copied_constants.insert(constant);
- }
- }
- }
- }
- }
- }
-
return changed;
}
diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/copy_insertion.h
index 2ca9a13fd8..11077dad2e 100644
--- a/tensorflow/compiler/xla/service/gpu/copy_insertion.h
+++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.h
@@ -16,8 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
namespace gpu {
@@ -25,20 +25,9 @@ namespace gpu {
// Besides the modifications made by the generic xla::CopyInsertion, this
// GPU-specific copy insertion also materializes operands of library calls by
// inserting kCopy instructions.
-class GpuCopyInsertion : public HloPassInterface {
+class GpuCopyInsertion : public CopyInsertion {
public:
- tensorflow::StringPiece name() const override { return "copy-insertion"; }
-
StatusOr<bool> Run(HloModule* module) override;
-
- protected:
- // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
- // duplicate copies.
- StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
-
- // A map containing all copies inserted to materialize operands of library
- // calls. The key is the copied instruction and the value is the copy.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> inserted_copies_;
};
} // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 80dccf5b65..2caa8f6051 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -220,8 +220,9 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
// (and sometime after) copy insertion, to avoid dead code from interfering
// with the rewrites.
pipeline.AddPass<HloDCE>();
- pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<GpuCopyInsertion>();
+ pipeline.AddPass<HloDCE>();
+ pipeline.AddPass<FlattenCallGraph>();
return pipeline.Run(hlo_module).status();
}
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
index f16daa0b54..44188473d3 100644
--- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -17,12 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
-#include "tensorflow/compiler/xla/service/hlo_verifier.h"
-#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
@@ -36,6 +33,8 @@ class WhileTransformerTest : public HloTestBase {
: module_(CreateNewModule()),
induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
data_shape_(ShapeUtil::MakeShape(F32, {8})),
+ loop_state_shape_(ShapeUtil::MakeTupleShape(
+ {induction_variable_shape_, data_shape_})),
condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}
std::unique_ptr<HloComputation> BuildConditionComputation(
@@ -43,8 +42,8 @@ class WhileTransformerTest : public HloTestBase {
auto builder = HloComputation::Builder(TestName() + ".Condition");
auto limit_const = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int32>(limit)));
- auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, GetLoopStateShape(tuple_index), "loop_state"));
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
limit_const->shape(), loop_state, tuple_index));
@@ -59,8 +58,8 @@ class WhileTransformerTest : public HloTestBase {
const int64 increment) {
auto builder = HloComputation::Builder(TestName() + ".Body");
// Create param instruction to access loop state.
- auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
- 0, GetLoopStateShape(ind_var_tuple_index), "loop_state"));
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
// Update the induction variable GTE(ind_var_tuple_index).
auto induction_variable =
builder.AddInstruction(HloInstruction::CreateGetTupleElement(
@@ -74,7 +73,7 @@ class WhileTransformerTest : public HloTestBase {
data_shape_, loop_state, data_tuple_index));
// Use 'induction_variable' in computation with no path to output tuple.
auto update = builder.AddInstruction(
- HloInstruction::CreateBroadcast(data_shape_, induction_variable, {}));
+ HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape_, HloOpcode::kAdd, data, update));
// Create output Tuple.
@@ -99,9 +98,8 @@ class WhileTransformerTest : public HloTestBase {
HloInstruction::CreateTuple({induction_var_init, data_init}))
: builder.AddInstruction(
HloInstruction::CreateTuple({data_init, induction_var_init}));
- auto while_hlo = builder.AddInstruction(
- HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index),
- condition, body, loop_state_init));
+ auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_state_shape_, condition, body, loop_state_init));
module_->AddEntryComputation(builder.Build());
return while_hlo;
}
@@ -117,34 +115,18 @@ class WhileTransformerTest : public HloTestBase {
}
void RunCopyInsertionPass() {
- HloVerifier verifier([](const Shape& shape) {
- return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
- });
- TF_ASSERT_OK(verifier.Run(module_.get()).status());
CopyInsertion copy_insertion;
- TF_ASSERT_OK(copy_insertion.Run(module_.get()).status());
- }
-
- Shape GetLoopStateShape(const int64 ind_var_tuple_index) {
- if (ind_var_tuple_index == 0) {
- return ShapeUtil::MakeTupleShape(
- {induction_variable_shape_, data_shape_});
- } else {
- return ShapeUtil::MakeTupleShape(
- {data_shape_, induction_variable_shape_});
- }
+ EXPECT_IS_OK(copy_insertion.Run(module_.get()).status());
}
std::unique_ptr<HloModule> module_;
Shape induction_variable_shape_;
Shape data_shape_;
+ Shape loop_state_shape_;
Shape condition_result_shape_;
};
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
// Build computation with induction variable at tuple element 0.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
@@ -155,16 +137,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) {
RunCopyInsertionPass();
// Run WhileTransformer.
auto result = gpu::CanTransformWhileToFor(while_hlo);
- TF_ASSERT_OK(result.status());
+ ASSERT_TRUE(result.ok());
// Check results.
EXPECT_THAT(result.ConsumeValueOrDie(),
Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
// Build computation with induction variable at tuple element 1.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(1, 10));
@@ -175,16 +154,13 @@ TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) {
RunCopyInsertionPass();
// Run WhileTransformer.
auto result = gpu::CanTransformWhileToFor(while_hlo);
- TF_ASSERT_OK(result.status());
+ ASSERT_TRUE(result.ok());
// Check results.
EXPECT_THAT(result.ConsumeValueOrDie(),
Eq(std::tuple<int64, int64, int64>(0, 10, 1)));
}
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
+TEST_F(WhileTransformerTest, InvalidLoopLimit) {
// Build computation with invalid loop limit.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 5));
@@ -200,10 +176,7 @@ TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) {
HasSubstr("Loop start must be less than loop limit."));
}
-// TODO(b/68830972): The while transformer is far too fragile. It patterns
-// matches the exact expressions of opcodes. Re-enable when transformation is
-// more general
-TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) {
+TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
// Build computation with invalid loop increment.
auto condition =
module_->AddEmbeddedComputation(BuildConditionComputation(0, 10));
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index 0fb11792b8..6f80994751 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -144,10 +144,8 @@ class BufferValueMap {
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
- tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
- buffers_.at(old_buffer_number);
- old_value_set.erase(&value);
- if (old_value_set.empty()) {
+ buffers_.at(old_buffer_number).erase(&value);
+ if (buffers_.at(old_buffer_number).empty()) {
buffers_.erase(old_buffer_number);
}
@@ -177,7 +175,7 @@ class BufferValueMap {
// Value is init of a while (use is while).
std::vector<BufferNumber> aliased_buffers;
for (const HloUse& use : value.uses()) {
- VLOG(2) << "use of value " << value.ToShortString() << ": " << use;
+ VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
if (use.instruction->opcode() == HloOpcode::kWhile) {
// Determine the while value that this shares a buffer with.
const HloValue& while_value =
@@ -413,7 +411,7 @@ string HloAliasAnalysis::ToString() const {
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloModule* module) {
- VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
+ VLOG(1) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 1677c77f2e..8f595b45e9 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -412,18 +412,16 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation,
+ tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
HloInstruction* fusion_instruction) {
std::vector<std::unique_ptr<HloInstruction>> instructions;
tensorflow::gtl::FlatMap<string, HloInstruction*> instruction_map;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
- HloInstruction::CreateFromProto(
- module, instruction_proto, instruction_map,
- computation_map, add_fused_computation));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloInstruction> instruction,
+ HloInstruction::CreateFromProto(module, instruction_proto,
+ instruction_map, computation_map));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
@@ -533,7 +531,6 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
if (indices_to_copy != nullptr &&
!ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
- LOG(FATAL) << "DEATH!";
return FailedPrecondition(
"Can't deep copy instruction %s: given shape tree of indices to copy "
"has incompatible shape",
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 3208197f89..c9782cc981 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -152,18 +152,12 @@ class HloComputation {
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed computation
// calls.
- // add_fused_computation: A function to call to add a fused
- // computation. Used (clearly) when the instruction is a fusion
- // instruction.
- // fusion_instruction: if non-null then the newly created computation will
- // be
+ // fusion_instruction: if non-null then the newly created computation will be
// constructed as a fused computation with this instruction as its fusion
// parent.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
HloModule* module, const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation,
+ tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map,
HloInstruction* fusion_instruction = nullptr);
// Gets the instructions in this computation.
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 2286cfe488..92261bce62 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -75,41 +75,11 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
std::forward_as_tuple(value_id, instruction, index, is_phi));
CHECK(emplaced.second);
- VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
-
return &emplaced.first->second;
}
-void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
- HloValue& value = values_.at(value_id);
- VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
-
- value_ids_to_delete_.push_back(value_id);
-}
-
-void HloDataflowAnalysis::DeleteMarkedValues() {
- // Verify that no marked-for-deletion values are in any of the value sets.
- tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
- value_ids_to_delete_.end());
- for (const auto& pair : value_sets_) {
- const HloInstruction* instruction = pair.first;
- const InstructionValueSet& instruction_value_set = pair.second;
- for (const auto& index_value_set : instruction_value_set) {
- const HloValueSet& value_set = index_value_set.second;
- for (const HloValue* value : value_set.values()) {
- DCHECK(!ContainsKey(id_set, value->id()))
- << "Value " << value->ToShortString()
- << " marked for deletion, but still exists in value set for "
- "instruction "
- << instruction->name();
- }
- }
- }
-
- for (HloValue::Id value_id : value_ids_to_delete_) {
- values_.erase(value_id);
- }
- value_ids_to_delete_.clear();
+void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
+ values_.erase(value_id);
}
string HloDataflowAnalysis::ToString() const {
@@ -151,7 +121,6 @@ bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
- VLOG(4) << "Phi(" << instruction->name() << ")";
for (const InstructionValueSet* input : inputs) {
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
@@ -214,7 +183,7 @@ bool HloDataflowAnalysis::Phi(
} else if (current_value != &new_value) {
if (current_value_defined_here) {
// Remove the existing phi.
- MarkValueForDeletion(current_value->id());
+ DeleteHloValue(current_value->id());
}
value_set.Clear();
value_set.AddValue(&new_value);
@@ -224,8 +193,7 @@ bool HloDataflowAnalysis::Phi(
// Multiple distinct values reach this point. A phi value is
// necessary.
CHECK_GT(input_value_ids.size(), 1);
- if (current_value == nullptr ||
- !(current_value->is_phi() && current_value_defined_here)) {
+ if (current_value == nullptr || !current_value->is_phi()) {
value_set.Clear();
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
changed = true;
@@ -468,13 +436,11 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
}
}
-void HloDataflowAnalysis::Propagate() {
+void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
std::queue<HloInstruction*> worklist;
-
- for (HloComputation* computation : module_->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- worklist.push(instruction);
- }
+ for (HloInstruction* instruction : instructions) {
+ worklist.push(instruction);
}
while (!worklist.empty()) {
@@ -631,10 +597,18 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
- dataflow_analysis->Propagate();
- // Delete all values marked for deletion.
- dataflow_analysis->DeleteMarkedValues();
+ // Construct list of all instructions to initialize the worklist to propagate
+ // the data flow. For efficiency sort the instruction in post order so
+ // producers appear before consumers.
+ std::vector<HloInstruction*> all_instructions;
+ for (const HloComputation* computation : module->MakeComputationPostOrder()) {
+ for (HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ all_instructions.push_back(instruction);
+ }
+ }
+ dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
// Add in positions to all values.
for (const HloComputation* computation : module->computations()) {
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 49b1343873..207e553bf7 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -126,16 +126,13 @@ class HloDataflowAnalysis {
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
bool is_phi = false);
- // Mark the HloValue with the given ID for deletion.
- void MarkValueForDeletion(HloValue::Id value_id);
-
- // Delete all HloValues marked for deletion. Should be called after
- // propagation is complete.
- void DeleteMarkedValues();
+ // Delete the HloValue with the given ID.
+ void DeleteHloValue(HloValue::Id value_id);
// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
- // then propagated throughout the HLO graph by calling Propagate.
+ // then propagated throughout the HLO graph by calling
+ // UpdateInstructionsAndPropagate.
Status InitializeInstructionValueSets();
// Updates the value set of the given instruction based on the values flowing
@@ -153,8 +150,10 @@ class HloDataflowAnalysis {
bool UpdateTupleValueSet(HloInstruction* tuple);
bool UpdateWhileValueSet(HloInstruction* xla_while);
- // Propagate the dataflow through the module.
- void Propagate();
+ // Update the value sets of the given instructions and propagate the
+ // changes to fixed point.
+ void UpdateInstructionsAndPropagate(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
// Return the result of the SSA Phi function applied to the given inputs at
// the given instruction. If skip_top_level is true, then the top level of the
@@ -190,11 +189,6 @@ class HloDataflowAnalysis {
// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
- // Values marked for deletion during construction. We don't delete them
- // immediately because references to them may still remain in ValueSets. After
- // construction, these values are deleted.
- std::vector<HloValue::Id> value_ids_to_delete_;
-
// A vector containing all HloValues sorted by HloValue::Id.
std::vector<const HloValue*> values_vector_;
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index 40e67c8780..a4921232f5 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -37,9 +37,6 @@ namespace xla {
StatusOr<bool> HloDCE::Run(HloModule* module) {
bool changed = false;
- VLOG(2) << "Before dce:";
- XLA_VLOG_LINES(2, module->ToString());
-
for (auto* computation : module->MakeNonfusionComputations()) {
std::unordered_set<HloInstruction*> live_instructions;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
@@ -61,8 +58,6 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
for (HloInstruction* dead_root : dead_roots) {
- VLOG(1) << "Removing dead root " << dead_root->ToString()
- << " and it's unused operands";
TF_RETURN_IF_ERROR(
computation->RemoveInstructionAndUnusedOperands(dead_root));
changed = true;
@@ -92,9 +87,6 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
}
}
- VLOG(2) << "After dce:";
- XLA_VLOG_LINES(2, module->ToString());
-
return changed;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 2c7e735a1c..e09899e48d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -51,9 +51,7 @@ using ::tensorflow::strings::StrCat;
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation) {
+ tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@@ -79,19 +77,19 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(!proto.fusion_kind().empty());
TF_ASSIGN_OR_RETURN(instruction->fusion_kind_,
StringToFusionKind(proto.fusion_kind()));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> fused_computation,
- HloComputation::CreateFromProto(
- module, proto.fused_instructions_computation(),
- computation_map, add_fused_computation,
- /*fusion_instruction=*/instruction.get()));
- instruction->called_computations_.push_back(fused_computation.get());
- add_fused_computation(std::move(fused_computation));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloComputation> fused_computation,
+ HloComputation::CreateFromProto(
+ module, proto.fused_instructions_computation(), computation_map,
+ /*fusion_instruction=*/instruction.get()));
+ instruction->called_computations_.push_back(
+ module->AddEmbeddedComputation(std::move(fused_computation)));
} else {
for (const string& computation_name : proto.called_computation_names()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_name))
+ TF_RET_CHECK(ContainsKey(*computation_map, computation_name))
<< "No computation named " << computation_name;
instruction->called_computations_.push_back(
- computation_map.at(computation_name));
+ computation_map->at(computation_name));
}
}
@@ -2011,10 +2009,8 @@ string HloInstruction::ToCategory() const {
bool saw_rank_1 = false;
bool saw_higher_rank = false;
for (const auto* operand : operands()) {
- if (!ShapeUtil::IsTuple(operand->shape())) {
- saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
- saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
- }
+ saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1;
+ saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1;
}
if (saw_rank_1 && saw_higher_rank) {
return "rank-1-broadcast binary fusion";
@@ -2299,8 +2295,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
-using DFSStack = tensorflow::gtl::InlinedVector<
- std::pair<HloInstruction::Id, HloInstruction*>, 16>;
+using DFSStack =
+ tensorflow::gtl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
@@ -2308,7 +2304,7 @@ template <typename Visitor>
inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
HloInstruction* child) {
CHECK(child != nullptr);
- const HloInstruction::Id id = child->unique_id();
+ const int id = child->unique_id();
CHECK_GE(id, 0) << "instruction may not have a parent computation";
switch (visitor->GetVisitState(id)) {
case Visitor::kVisiting:
@@ -2325,8 +2321,8 @@ inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
}
using InternalCompareFunction =
- std::function<bool(std::pair<HloInstruction::Id, const HloInstruction*>,
- std::pair<HloInstruction::Id, const HloInstruction*>)>;
+ std::function<bool(std::pair<int, const HloInstruction*>,
+ std::pair<int, const HloInstruction*>)>;
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
const InternalCompareFunction* operand_order,
@@ -2345,7 +2341,7 @@ static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
do {
DCHECK(!dfs_stack.empty());
- HloInstruction::Id current_id = dfs_stack.back().first;
+ int current_id = dfs_stack.back().first;
HloInstruction* current_node = dfs_stack.back().second;
CHECK_GE(current_id, 0) << current_id << ": " << current_node
<< ": instruction may not have parent computation";
@@ -2424,13 +2420,13 @@ Status HloInstruction::AcceptWithOperandOrder(
DfsHloVisitor* visitor, const CompareFunction& operand_order,
bool call_finish_visit) {
VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")";
- InternalCompareFunction func =
- [&operand_order](std::pair<HloInstruction::Id, const HloInstruction*> a,
- std::pair<HloInstruction::Id, const HloInstruction*> b) {
- // Call the client's comparison function on the actual HloInstruction*
- // objects (ignoring the internal ids we also have in our stack entries)
- return operand_order(a.second, b.second);
- };
+ InternalCompareFunction func = [&operand_order](
+ std::pair<int, const HloInstruction*> a,
+ std::pair<int, const HloInstruction*> b) {
+ // Call the client's comparison function on the actual HloInstruction*
+ // objects (ignoring the internal ids we also have in our stack entries)
+ return operand_order(a.second, b.second);
+ };
TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
/*ignore_control_predecessors=*/false));
if (call_finish_visit) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 411f926a87..4d8fe6bc10 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -83,16 +83,12 @@ class HloInstruction {
// must contain all operands of the newly constructed instruction.
// computation_map: a map from computation name to HloComputation*. This map
// must contain all computations which the newly constructed instruction
- // calls.
- // add_fused_computation: A function to call to add a fused
- // computation. Used (clearly) when the instruction is a fusion
- // instruction.
+ // calls. If the instruction is a fusion instruction, then the fusion
+ // computation is added to this map and the module.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
HloModule* module, const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
- const std::function<void(std::unique_ptr<HloComputation>)>&
- add_fused_computation);
+ tensorflow::gtl::FlatMap<string, HloComputation*>* computation_map);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
@@ -981,8 +977,7 @@ class HloInstruction {
void UniquifyName(NameUniquer* name_uniquer);
// Set the unique id for this instruction to "id"
- using Id = int;
- void SetUniqueId(Id id) {
+ void SetUniqueId(int id) {
CHECK_EQ(unique_id_, -1); // Should not be assigned already
CHECK_GE(id, 0);
unique_id_ = id;
@@ -990,7 +985,7 @@ class HloInstruction {
// Return the unique ID assigned to this node via SetUniqueId (or -1
// if no id has been assigned yet).
- Id unique_id() const { return unique_id_; }
+ int unique_id() const { return unique_id_; }
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
@@ -1093,7 +1088,7 @@ class HloInstruction {
// Returns how this instruction uses elements of its `i`th operand.
UseKind OperandElementUse(int64 i) const;
- Id unique_id_; // Unique to this HloInstruction within a HloModule
+ int unique_id_; // Unique to this HloInstruction within a HloModule
// Opcode for this instruction.
HloOpcode opcode_;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index d2cee6f8b1..659f3d8c26 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -296,16 +296,9 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
for (const HloComputationProto& computation_proto : proto.computations()) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<HloComputation> computation,
- HloComputation::CreateFromProto(
- module.get(), computation_proto, computation_map,
- /*add_fused_computation=*/
- [&module](std::unique_ptr<HloComputation> fused_computation) {
- module->AddComputationInternal(std::move(fused_computation),
- /*is_entry=*/false,
- /*uniquify_names=*/false);
- }));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
+ HloComputation::CreateFromProto(
+ module.get(), computation_proto, &computation_map));
CHECK_NE(computation.get(), nullptr);
TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
string computation_name = computation->name();
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 1f9a989961..e6cf0d37b8 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -184,7 +184,7 @@ void HloValue::AddPosition(HloInstruction* instruction,
live_out_of_module_ = true;
}
- if (instruction == defining_instruction()->parent()->root_instruction()) {
+ if (instruction == instruction->parent()->root_instruction()) {
live_out_of_computation_ = true;
}
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc
index 2ecf57ad3d..34899b7400 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc
@@ -55,34 +55,22 @@ static Status EmitDynamicUpdateSliceInPlaceImpl(
// Calculate output_index, where we'll write the value from update. For
// each dimension,
//
- // output_index[dim] = (start_index[dim] + update_index[dim])
+ // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size.
//
IrArray::Index output_index(rank);
for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* dim_size = llvm::ConstantInt::get(
+ update_index[i]->getType(), output_shape.dimensions(i));
llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast(
start_index[i], update_index[i]->getType());
- output_index[i] = ir_builder->CreateAdd(start_index0, update_index[i]);
- }
-
- // Check if 'index' intersects start/end indices. If it does not (indices
- // are out of bounds) then no update is performed.
- llvm::Value* in_bounds = llvm::ConstantInt::get(ir_builder->getInt1Ty(), 1);
- for (int64 i = 0; i < rank; ++i) {
- llvm::Value* dim_size = llvm::ConstantInt::get(
- output_index[i]->getType(), output_shape.dimensions(i));
- in_bounds = ir_builder->CreateAnd(
- in_bounds, ir_builder->CreateICmpSLT(output_index[i], dim_size),
- "in_bounds");
+ output_index[i] = ir_builder->CreateURem(
+ ir_builder->CreateAdd(start_index0, update_index[i]), dim_size);
}
// Do output[output_index] = update[update_index].
TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
update_array_generator(update_index));
- llvm::Value* input_data =
- output_array.EmitReadArrayElement(output_index, ir_builder);
- llvm::Value* to_write_data =
- ir_builder->CreateSelect(in_bounds, update_data, input_data);
- output_array.EmitWriteArrayElement(output_index, to_write_data, ir_builder);
+ output_array.EmitWriteArrayElement(output_index, update_data, ir_builder);
return Status::OK();
};
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 5a012c93d6..4920f17a7e 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -180,8 +180,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) {
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
-// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers.
-XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) {
+XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
ComputationBuilder b(client_, TestName());
ComputationDataHandle v1, v2;
diff --git a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
index f568f58154..92b2b1ee77 100644
--- a/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
+++ b/tensorflow/compiler/xla/tests/xla_internal_test_main.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
GTEST_API_ int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
@@ -31,7 +30,5 @@ GTEST_API_ int main(int argc, char** argv) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return 2;
}
- int result = RUN_ALL_TESTS();
- tensorflow::testing::RunBenchmarks();
- return result;
+ return RUN_ALL_TESTS();
}