aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 8c6903d766..601a008d9f 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -279,11 +279,11 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
return Status::OK();
}
-void HloComputation::set_root_instruction(
- HloInstruction* new_root_instruction) {
+void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
+ bool accept_different_shape) {
// The shape of the root (ignoring layout) is an invariant of the computation
// for non-fusion cases.
- if (!IsFusionComputation()) {
+ if (!IsFusionComputation() && !accept_different_shape) {
CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
root_instruction_->shape()))
<< new_root_instruction->shape() << " is incompatible with "