diff options
author | 2017-11-01 19:11:49 -0700 | |
---|---|---|
committer | 2017-11-01 19:15:17 -0700 | |
commit | 53a4fcbdbad571e659203733f6a07ba82651d40b (patch) | |
tree | 58f22f4b697ff81126be3f49baf65b841f35762b | |
parent | 67fe8d146a0aa642a29a52a1389000b99b19cc03 (diff) |
Fixed HloComputation/HloInstruction clone to allow deep cloning, and to prevent the cloned instructions and computations from retaining live links to their original parent modules and computations.
PiperOrigin-RevId: 174271432
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 37 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 18 |
4 files changed, 42 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index b5b07aeb72..ed776b9933 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -724,7 +724,8 @@ Status HloComputation::Accept( return this->Accept(&visitor); } -std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) { +std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix, + HloModule* module) { VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; auto postorder = MakeInstructionPostOrder(); std::unordered_map<HloInstruction*, HloInstruction*> clone_map; @@ -737,12 +738,8 @@ std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) { CHECK(new_operand != nullptr); new_operands.push_back(new_operand); } - - new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands); - new_instr->set_metadata(instr->metadata()); - if (instr->has_sharding()) { - new_instr->set_sharding(instr->sharding()); - } + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, module); InsertOrDie(&clone_map, instr, new_instr.get()); instructions.push_back(std::move(new_instr)); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index b44a9e417a..fbbbc45c26 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -289,7 +289,11 @@ class HloComputation { Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const; // Returns a deep copy of this computation including all instructions. - std::unique_ptr<HloComputation> Clone(const string& suffix = "clone"); + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). 
+ std::unique_ptr<HloComputation> Clone(const string& suffix = "clone", + HloModule* module = nullptr); // Returns true if the given instruction can be removed from the // computation. Instructions such as parameters and send/receive instructions diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c24eb13ad1..d8ab9dde52 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -961,7 +961,8 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( const Shape& shape, - tensorflow::gtl::ArraySlice<HloInstruction*> new_operands) const { + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloModule* module) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { @@ -1131,7 +1132,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( clone = CreateConstant(literal_->CloneToUnique()); break; case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands); + clone = CloneFusionWithNewOperands(shape, new_operands, module); break; case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, parameter_name_); @@ -1168,15 +1169,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } clone->set_metadata(metadata_); + if (has_sharding()) { + clone->set_sharding(sharding()); + } + clone->set_parent(parent_); return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr<HloInstruction> HloInstruction::Clone( - const string& suffix) const { +std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix, + HloModule* module) const { std::unique_ptr<HloInstruction> clone = - CloneWithNewOperands(shape_, operands_); + 
CloneWithNewOperands(shape_, operands_, module); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1210,16 +1215,12 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone( } } } - clone->set_parent(parent_); - if (has_sharding()) { - clone->set_sharding(sharding()); - } return clone; } std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice<HloInstruction*> operands) const { + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloModule* module) const { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(parent() != nullptr); @@ -1236,7 +1237,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( // fused instructions. for (HloInstruction* old_fused_parameter : fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back(old_fused_parameter->Clone()); + new_fused_instructions.push_back( + old_fused_parameter->Clone("clone", module)); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); } @@ -1255,7 +1257,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( } new_fused_instructions.push_back( old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands)); + old_fused_instruction->shape(), new_operands, module)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); new_fused_instruction->set_parent(parent_); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); @@ -1271,12 +1273,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands( ++new_fused_instruction_iter) { computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); } + if (module == nullptr) { + module = GetModule(); + } auto fused_root_ = fused_expression_root(); new_instruction->called_computations_.push_back( - 
CHECK_NOTNULL(GetModule()) - ->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - new_instruction->set_parent(parent_); + CHECK_NOTNULL(module)->AddEmbeddedComputation( + computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); return new_instruction; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3fba0b59fb..e251dfb399 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -27,6 +27,7 @@ limitations under the License. #include <memory> #include <string> #include <tuple> +#include <unordered_map> #include <unordered_set> #include <vector> @@ -870,12 +871,19 @@ class HloInstruction { // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of // the instruction to form the name of the cloned instruction. - std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone") const; + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). + std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone", + HloModule* module = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). std::unique_ptr<HloInstruction> CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice<HloInstruction*> operands) const; + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloModule* module = nullptr) const; // Returns the computations this instruction directly calls (if any). 
const std::vector<HloComputation*>& called_computations() const { @@ -1061,8 +1069,8 @@ class HloInstruction { // Clones a fusion instruction with a new shape and operands. std::unique_ptr<HloInstruction> CloneFusionWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice<HloInstruction*> operands) const; + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + HloModule* module = nullptr) const; // Returns true if this instruction can legally have the dimensions field // set. Used for checking precondition of dimensions field accessors. |