diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-16 13:38:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-16 13:42:25 -0700 |
commit | a6ee64cd216b3ac440262e1f4ec7872fe7026df6 (patch) | |
tree | 52f0bcfc458a17f04093e8420f76a0bc7303a496 /tensorflow/compiler | |
parent | 92c31bb620b0f8dd6590380dc6a5674f591ce1cb (diff) |
Conditionally allow changing a non-fusion computation root_instruction shape.
PiperOrigin-RevId: 213191899
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 8 |
2 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 8c6903d766..601a008d9f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -279,11 +279,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { return Status::OK(); } -void HloComputation::set_root_instruction( - HloInstruction* new_root_instruction) { +void HloComputation::set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape) { // The shape of the root (ignoring layout) is an invariant of the computation // for non-fusion cases. - if (!IsFusionComputation()) { + if (!IsFusionComputation() && !accept_different_shape) { CHECK(ShapeUtil::Compatible(new_root_instruction->shape(), root_instruction_->shape())) << new_root_instruction->shape() << " is incompatible with " diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 91c5234a6f..a880e9ab30 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -134,9 +134,11 @@ class HloComputation { Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); // Set the root of the computation to the given instruction. The instruction - // must have already been added to the computation and have the same shape as - // the result of the computation for non fusion computations. - void set_root_instruction(HloInstruction* new_root_instruction); + // must have already been added to the computation. In addition it must have + // the same shape as the result of the computation for non fusion + // computations, except if accept_different_shape is set to true. + void set_root_instruction(HloInstruction* new_root_instruction, + bool accept_different_shape = false); // Return the root instruction of the computation. The root instruction is the // instruction which produces the output of the computation. |