author A. Unique TensorFlower <gardener@tensorflow.org> 2017-11-01 19:11:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-11-01 19:15:17 -0700
commit 53a4fcbdbad571e659203733f6a07ba82651d40b (patch)
tree 58f22f4b697ff81126be3f49baf65b841f35762b
parent 67fe8d146a0aa642a29a52a1389000b99b19cc03 (diff)
Fixed HloComputation/HloInstruction clone to allow deep cloning, and to prevent the cloned instructions and computations from keeping live links to their original parent modules and computations.
PiperOrigin-RevId: 174271432
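
For illustration, a minimal usage sketch (not part of the commit) of the new deep-clone entry point; DeepCloneInto is a hypothetical helper name, while Clone and AddEmbeddedComputation are the APIs this diff touches:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// Hypothetical helper (not part of this commit): deep-clone a computation
// into destination_module. Passing a non-null module makes nested
// computations (e.g. those called by fusion instructions) be rebuilt and
// registered in that module, instead of keeping live links back to the
// original computation's module.
HloComputation* DeepCloneInto(HloComputation* original,
                              HloModule* destination_module) {
  std::unique_ptr<HloComputation> clone =
      original->Clone(/*suffix=*/"clone", destination_module);
  return destination_module->AddEmbeddedComputation(std::move(clone));
}

}  // namespace xla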
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.cc 11
-rw-r--r-- tensorflow/compiler/xla/service/hlo_computation.h 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc 37
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h 18
4 files changed, 42 insertions, 30 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index b5b07aeb72..ed776b9933 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -724,7 +724,8 @@ Status HloComputation::Accept(
return this->Accept(&visitor);
}
-std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) {
+std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix,
+ HloModule* module) {
VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
auto postorder = MakeInstructionPostOrder();
std::unordered_map<HloInstruction*, HloInstruction*> clone_map;
@@ -737,12 +738,8 @@ std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix) {
CHECK(new_operand != nullptr);
new_operands.push_back(new_operand);
}
-
- new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands);
- new_instr->set_metadata(instr->metadata());
- if (instr->has_sharding()) {
- new_instr->set_sharding(instr->sharding());
- }
+ new_instr =
+ instr->CloneWithNewOperands(instr->shape(), new_operands, module);
InsertOrDie(&clone_map, instr, new_instr.get());
instructions.push_back(std::move(new_instr));
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index b44a9e417a..fbbbc45c26 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -289,7 +289,11 @@ class HloComputation {
Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const;
// Returns a deep copy of this computation including all instructions.
- std::unique_ptr<HloComputation> Clone(const string& suffix = "clone");
+ // If the module pointer is not nullptr, it will be the module where
+ // the cloned computations will be added to (in order to support deep
+ // cloning).
+ std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
+ HloModule* module = nullptr);
// Returns true if the given instruction can be removed from the
// computation. Instructions such as parameters and send/receive instructions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index c24eb13ad1..d8ab9dde52 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -961,7 +961,8 @@ bool HloInstruction::HasSideEffect() const {
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands) const {
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloModule* module) const {
VLOG(3) << "CloneWithNewOperands:\n " << ToString();
VLOG(3) << " new operands:";
for (const HloInstruction* new_operand : new_operands) {
@@ -1131,7 +1132,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateConstant(literal_->CloneToUnique());
break;
case HloOpcode::kFusion:
- clone = CloneFusionWithNewOperands(shape, new_operands);
+ clone = CloneFusionWithNewOperands(shape, new_operands, module);
break;
case HloOpcode::kParameter:
clone = CreateParameter(parameter_number_, shape, parameter_name_);
@@ -1168,15 +1169,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
clone->set_metadata(metadata_);
+ if (has_sharding()) {
+ clone->set_sharding(sharding());
+ }
+ clone->set_parent(parent_);
return clone;
}
HloInstruction::~HloInstruction() {}
-std::unique_ptr<HloInstruction> HloInstruction::Clone(
- const string& suffix) const {
+std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
+ HloModule* module) const {
std::unique_ptr<HloInstruction> clone =
- CloneWithNewOperands(shape_, operands_);
+ CloneWithNewOperands(shape_, operands_, module);
if (suffix.empty()) {
clone->name_ = name();
} else {
@@ -1210,16 +1215,12 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(
}
}
}
- clone->set_parent(parent_);
- if (has_sharding()) {
- clone->set_sharding(sharding());
- }
return clone;
}
std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) const {
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloModule* module) const {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(parent() != nullptr);
@@ -1236,7 +1237,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
// fused instructions.
for (HloInstruction* old_fused_parameter :
fused_instructions_computation()->parameter_instructions()) {
- new_fused_instructions.push_back(old_fused_parameter->Clone());
+ new_fused_instructions.push_back(
+ old_fused_parameter->Clone("clone", module));
HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
}
@@ -1255,7 +1257,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
}
new_fused_instructions.push_back(
old_fused_instruction->CloneWithNewOperands(
- old_fused_instruction->shape(), new_operands));
+ old_fused_instruction->shape(), new_operands, module));
HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
new_fused_instruction->set_parent(parent_);
InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
@@ -1271,12 +1273,13 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
++new_fused_instruction_iter) {
computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
}
+ if (module == nullptr) {
+ module = GetModule();
+ }
auto fused_root_ = fused_expression_root();
new_instruction->called_computations_.push_back(
- CHECK_NOTNULL(GetModule())
- ->AddEmbeddedComputation(
- computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
- new_instruction->set_parent(parent_);
+ CHECK_NOTNULL(module)->AddEmbeddedComputation(
+ computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
return new_instruction;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 3fba0b59fb..e251dfb399 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -27,6 +27,7 @@ limitations under the License.
#include <memory>
#include <string>
#include <tuple>
+#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -870,12 +871,19 @@ class HloInstruction {
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
// the instruction to form the name of the cloned instruction.
- std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone") const;
+ // If the module pointer is not nullptr, it will be the module where
+ // the cloned computations will be added to (in order to support deep
+ // cloning).
+ std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
+ HloModule* module = nullptr) const;
// Clones the HLO instruction as above but with new shape and operands.
+ // If the module pointer is not nullptr, it will be the module where
+ // the cloned computations will be added to (in order to support deep
+ // cloning).
std::unique_ptr<HloInstruction> CloneWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) const;
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloModule* module = nullptr) const;
// Returns the computations this instruction directly calls (if any).
const std::vector<HloComputation*>& called_computations() const {
@@ -1061,8 +1069,8 @@ class HloInstruction {
// Clones a fusion instruction with a new shape and operands.
std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands) const;
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloModule* module = nullptr) const;
// Returns true if this instruction can legally have the dimensions field
// set. Used for checking precondition of dimensions field accessors.
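
Finally, a hedged sketch of the instruction-level behavior after this change; fusion and other_module are hypothetical stand-ins for a real fusion instruction and a destination module:

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// Hypothetical call site (not part of this commit). The module argument is
// threaded from Clone() through CloneWithNewOperands() down to
// CloneFusionWithNewOperands(), which registers the rebuilt fused
// computation in other_module via AddEmbeddedComputation().
std::unique_ptr<HloInstruction> CloneFusionIntoModule(
    const HloInstruction* fusion, HloModule* other_module) {
  return fusion->Clone(/*suffix=*/"clone", other_module);
}

// Existing callers that pass no module keep the pre-change behavior:
// CloneFusionWithNewOperands() falls back to the instruction's own
// GetModule() when the module argument is nullptr.
std::unique_ptr<HloInstruction> CloneFusionInPlace(
    const HloInstruction* fusion) {
  return fusion->Clone();  // suffix defaults to "clone", module to nullptr
}

}  // namespace xla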