diff options
author | 2017-08-25 14:03:07 -0700 | |
---|---|---|
committer | 2017-08-25 14:09:49 -0700 | |
commit | 88e510c479fb96678a8cae07d3f757273a1c8952 (patch) | |
tree | a410c5f636fdc397600e8935aae95f1b79dcacf1 /tensorflow/compiler/xla/service/hlo_computation.cc | |
parent | 008910f1122d115a6d7430bfcc63cf4296c7467d (diff) |
Add option to HloComputation::DeepCopyInstruction for selectively copying only
certain indices. Also, add mechanism for returning the kCopy instructions
added to create the deep copy.
PiperOrigin-RevId: 166521917
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 74 |
1 files changed, 45 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 44b54e432a..b8133cda30 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -459,49 +459,65 @@ HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution( return fusion_instruction; } -StatusOr<HloInstruction*> HloComputation::DeepCopyTuple( - HloInstruction* instruction) { - TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape())); - std::vector<HloInstruction*> element_copies; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); - ++i) { - HloInstruction* gte = AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i)); - // Recurse to copy tuple elements. For array elements, insert a kCopy - // because GetTupleElement forwards a pointer to the tuple element buffer. - HloInstruction* element_copy; - if (ShapeUtil::IsTuple(gte->shape())) { - TF_ASSIGN_OR_RETURN(element_copy, DeepCopyTuple(gte)); +StatusOr<HloInstruction*> HloComputation::DeepCopyHelper( + HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, + ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) { + if (ShapeUtil::IsArray(instruction->shape())) { + if (indices_to_copy == nullptr || indices_to_copy->element(*index)) { + // Use kCopy to copy array elements + HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + if (copies_added != nullptr) { + *copies_added->mutable_element(*index) = copy; + } + return copy; } else { - element_copy = AddInstruction( - HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte)); + // Array elements which are not to be copied are passed through + // transparently. + return instruction; + } + } else if (ShapeUtil::IsTuple(instruction->shape())) { + std::vector<HloInstruction*> elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + i++) { + HloInstruction* gte = + AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(instruction->shape(), i), + instruction, i)); + + index->push_back(i); + TF_ASSIGN_OR_RETURN( + HloInstruction * element, + DeepCopyHelper(gte, indices_to_copy, copies_added, index)); + elements.push_back(element); + index->pop_back(); } - element_copies.push_back(element_copy); + return AddInstruction(HloInstruction::CreateTuple(elements)); + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); } - - // Gather element copies into a tuple with a new Tuple instruction. - return AddInstruction(HloInstruction::CreateTuple(element_copies)); } StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction( - HloInstruction* instruction) { + HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy, + ShapeTree<HloInstruction*>* copies_added) { if (instruction->parent() != this) { return FailedPrecondition( "Can't deep copy instruction %s: instruction is not in computation %s", instruction->name().c_str(), name().c_str()); } - // For tuple instructions, perform a deep copy. For array instructions, copy - // with a kCopy instruction. - if (ShapeUtil::IsTuple(instruction->shape())) { - return DeepCopyTuple(instruction); - } else if (ShapeUtil::IsArray(instruction->shape())) { - return AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - } else { + if (indices_to_copy != nullptr && + !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); + "Can't deep copy instruction %s: given shape tree of indices to copy " + "has incompatible shape", + instruction->name().c_str()); } + + ShapeIndex index; + return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index); } ProgramShape HloComputation::ComputeProgramShape() const { |