aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 16:29:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 16:38:49 -0800
commit75e15a2b25f731d7ddf4ffc455a4bf8d1c0fd7ca (patch)
treeb6511c78de81d1c0856c113ce790f06ffe5082ec /tensorflow/compiler/xla/service/hlo_computation.cc
parent721a60801055190dae18fe3e3933950c75fa9d1c (diff)
[XLA] Store the program shape in the HloModuleProto and HloComputationProto.
PiperOrigin-RevId: 188100425
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc2
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 21e6b2ca73..f99c7cf5e4 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -399,6 +399,7 @@ HloComputationProto HloComputation::ToProto() const {
proto.add_instructions()->Swap(&instruction_proto);
}
proto.set_root_name(root_instruction()->name());
+ *proto.mutable_program_shape() = ComputeProgramShape();
return proto;
}
@@ -532,7 +533,6 @@ ProgramShape HloComputation::ComputeProgramShape() const {
}
*program_shape.mutable_result() = root_instruction_->shape();
- LayoutUtil::ClearLayout(&program_shape);
return program_shape;
}