aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-08-25 14:03:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-25 14:09:49 -0700
commit88e510c479fb96678a8cae07d3f757273a1c8952 (patch)
treea410c5f636fdc397600e8935aae95f1b79dcacf1 /tensorflow/compiler/xla/service/hlo_computation.cc
parent008910f1122d115a6d7430bfcc63cf4296c7467d (diff)
Add option to HloComputation::DeepCopyInstruction for selectively copying only
certain indices. Also, add mechanism for returning the kCopy instructions added to create the deep copy. PiperOrigin-RevId: 166521917
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc74
1 files changed, 45 insertions, 29 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 44b54e432a..b8133cda30 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -459,49 +459,65 @@ HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution(
return fusion_instruction;
}
-StatusOr<HloInstruction*> HloComputation::DeepCopyTuple(
- HloInstruction* instruction) {
- TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()));
- std::vector<HloInstruction*> element_copies;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
- ++i) {
- HloInstruction* gte = AddInstruction(HloInstruction::CreateGetTupleElement(
- ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i));
- // Recurse to copy tuple elements. For array elements, insert a kCopy
- // because GetTupleElement forwards a pointer to the tuple element buffer.
- HloInstruction* element_copy;
- if (ShapeUtil::IsTuple(gte->shape())) {
- TF_ASSIGN_OR_RETURN(element_copy, DeepCopyTuple(gte));
+StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
+ HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
+ ShapeTree<HloInstruction*>* copies_added, ShapeIndex* index) {
+ if (ShapeUtil::IsArray(instruction->shape())) {
+ if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
+ // Use kCopy to copy array elements
+ HloInstruction* copy = AddInstruction(HloInstruction::CreateUnary(
+ instruction->shape(), HloOpcode::kCopy, instruction));
+ if (copies_added != nullptr) {
+ *copies_added->mutable_element(*index) = copy;
+ }
+ return copy;
} else {
- element_copy = AddInstruction(
- HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte));
+ // Array elements which are not to be copied are passed through
+ // transparently.
+ return instruction;
+ }
+ } else if (ShapeUtil::IsTuple(instruction->shape())) {
+ std::vector<HloInstruction*> elements;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+ i++) {
+ HloInstruction* gte =
+ AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetTupleElementShape(instruction->shape(), i),
+ instruction, i));
+
+ index->push_back(i);
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * element,
+ DeepCopyHelper(gte, indices_to_copy, copies_added, index));
+ elements.push_back(element);
+ index->pop_back();
}
- element_copies.push_back(element_copy);
+ return AddInstruction(HloInstruction::CreateTuple(elements));
+ } else {
+ return FailedPrecondition(
+ "Can only copy array and tuple shaped instructions");
}
-
- // Gather element copies into a tuple with a new Tuple instruction.
- return AddInstruction(HloInstruction::CreateTuple(element_copies));
}
StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
- HloInstruction* instruction) {
+ HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
+ ShapeTree<HloInstruction*>* copies_added) {
if (instruction->parent() != this) {
return FailedPrecondition(
"Can't deep copy instruction %s: instruction is not in computation %s",
instruction->name().c_str(), name().c_str());
}
- // For tuple instructions, perform a deep copy. For array instructions, copy
- // with a kCopy instruction.
- if (ShapeUtil::IsTuple(instruction->shape())) {
- return DeepCopyTuple(instruction);
- } else if (ShapeUtil::IsArray(instruction->shape())) {
- return AddInstruction(HloInstruction::CreateUnary(
- instruction->shape(), HloOpcode::kCopy, instruction));
- } else {
+ if (indices_to_copy != nullptr &&
+ !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
return FailedPrecondition(
- "Can only copy array and tuple shaped instructions");
+ "Can't deep copy instruction %s: given shape tree of indices to copy "
+ "has incompatible shape",
+ instruction->name().c_str());
}
+
+ ShapeIndex index;
+ return DeepCopyHelper(instruction, indices_to_copy, copies_added, &index);
}
ProgramShape HloComputation::ComputeProgramShape() const {