aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc72
1 files changed, 49 insertions, 23 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 34b18b0e21..441288da1a 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -284,9 +284,8 @@ void HloComputation::set_root_instruction(
if (!IsFusionComputation()) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
- << new_root_instruction->shape().ShortDebugString()
- << " is incompatible with "
- << root_instruction_->shape().ShortDebugString();
+ << new_root_instruction->shape() << " is incompatible with "
+ << root_instruction_->shape();
}
bool root_found = false;
for (auto& instruction : instructions_) {
@@ -529,8 +528,10 @@ HloInstruction* HloComputation::CreateFusionInstruction(
}
StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
- HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
- ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
+ HloInstruction* instruction, ShapeIndex* index,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf) {
if (ShapeUtil::IsTuple(instruction->shape())) {
std::vector<HloInstruction*> elements;
for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
@@ -541,9 +542,8 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
instruction, i));
index->push_back(i);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * element,
- DeepCopyHelper(gte, indices_to_copy, copies_added, index));
+ TF_ASSIGN_OR_RETURN(HloInstruction * element,
+ DeepCopyHelper(gte, index, copy_leaf));
elements.push_back(element);
index->pop_back();
}
@@ -557,19 +557,7 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
// Array shape.
TF_RET_CHECK(ShapeUtil::IsArray(instruction->shape()));
- if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
- // Use kCopy to copy array elements
- HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- if (copies_added != nullptr) {
- *copies_added->mutable_element(*index) = copy;
- }
- return copy;
- } else {
- // Elements which are not to be copied are passed through
- // transparently.
- return instruction;
- }
+ return copy_leaf(instruction, *index, this);
}
StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
@@ -591,7 +579,36 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
}
ShapeIndex index;
- return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index);
+ auto copy_leaf = [indices_to_copy, copies_added](
+ HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation) {
+ if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
+ HloInstruction* copy = computation->AddInstruction(
+ HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
+ if (copies_added != nullptr) {
+ *copies_added->mutable_element(leaf_index) = copy;
+ }
+ return copy;
+ }
+ // Elements which are not to be copied are passed through
+ // transparently.
+ return leaf;
+ };
+ return DeepCopyHelper(instruction, &index, copy_leaf);
+}
+
+StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
+ HloInstruction* instruction,
+ const std::function<
+ HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
+ HloComputation* computation)>& copy_leaf) {
+ if (instruction->parent() != this) {
+ return FailedPrecondition(
+ "Can't deep copy instruction %s: instruction is not in computation %s",
+ instruction->name().c_str(), name().c_str());
+ }
+ ShapeIndex index;
+ return DeepCopyHelper(instruction, &index, copy_leaf);
}
ProgramShape HloComputation::ComputeProgramShape() const {
@@ -664,7 +681,7 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
- result->SetReachabilityToUnion(inputs, hlo);
+ result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
}
@@ -881,4 +898,13 @@ void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
name_ = name_uniquer->GetUniqueName(name_);
}
+HloInstruction* HloComputation::GetInstructionWithName(
+ tensorflow::StringPiece name) {
+ auto instructions_in_computation = instructions();
+ auto it = c_find_if(instructions_in_computation, [&](HloInstruction* instr) {
+ return instr->name() == name;
+ });
+ return it == instructions_in_computation.end() ? nullptr : *it;
+}
+
} // namespace xla