aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-16 13:38:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-16 13:42:25 -0700
commita6ee64cd216b3ac440262e1f4ec7872fe7026df6 (patch)
tree52f0bcfc458a17f04093e8420f76a0bc7303a496
parent92c31bb620b0f8dd6590380dc6a5674f591ce1cb (diff)
Conditionally allow changing a non-fusion computation root_instruction shape.
PiperOrigin-RevId: 213191899
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h8
2 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 8c6903d766..601a008d9f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -279,11 +279,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
return Status::OK();
}
-void HloComputation::set_root_instruction(
- HloInstruction* new_root_instruction) {
+void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
+ bool accept_different_shape) {
// The shape of the root (ignoring layout) is an invariant of the computation
// for non-fusion cases.
- if (!IsFusionComputation()) {
+ if (!IsFusionComputation() && !accept_different_shape) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
<< new_root_instruction->shape() << " is incompatible with "
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 91c5234a6f..a880e9ab30 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -134,9 +134,11 @@ class HloComputation {
Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction);
// Set the root of the computation to the given instruction. The instruction
- // must have already been added to the computation and have the same shape as
- // the result of the computation for non fusion computations.
- void set_root_instruction(HloInstruction* new_root_instruction);
+ // must have already been added to the computation. In addition it must have
+ // the same shape as the result of the computation for non fusion
+ // computations, except if accept_different_shape is set to true.
+ void set_root_instruction(HloInstruction* new_root_instruction,
+ bool accept_different_shape = false);
// Return the root instruction of the computation. The root instruction is the
// instruction which produces the output of the computation.