aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_proto_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 16:29:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 16:38:49 -0800
commit75e15a2b25f731d7ddf4ffc455a4bf8d1c0fd7ca (patch)
treeb6511c78de81d1c0856c113ce790f06ffe5082ec /tensorflow/compiler/xla/service/hlo_proto_util.cc
parent721a60801055190dae18fe3e3933950c75fa9d1c (diff)
[XLA] Store the program shape in the HloModuleProto and HloComputationProto.
PiperOrigin-RevId: 188100425
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_proto_util.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_proto_util.cc138
1 files changed, 19 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index f75c452082..3460679558 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -21,106 +21,6 @@ limitations under the License.
namespace xla {
-namespace {
-
-// Returns the entry computation of the HLO module in the given HloProto.
-StatusOr<const HloComputationProto*> GetEntryComputation(
- const HloProto& hlo_proto) {
- if (!hlo_proto.has_hlo_module()) {
- return NotFound("HloProto missing HloModuleProto.");
- }
-
- if (hlo_proto.hlo_module().entry_computation_name().empty()) {
- return NotFound("HloProto has empty entry computation name.");
- }
-
- const string& entry_computation_name =
- hlo_proto.hlo_module().entry_computation_name();
- const HloComputationProto* entry_computation = nullptr;
- for (const HloComputationProto& computation :
- hlo_proto.hlo_module().computations()) {
- if (computation.name() == entry_computation_name) {
- if (entry_computation == nullptr) {
- entry_computation = &computation;
- } else {
- return InvalidArgument(
- "HloProto has multiple computations with entry computation named "
- "%s.",
- entry_computation_name.c_str());
- }
- }
- }
- if (entry_computation == nullptr) {
- return InvalidArgument("HloProto has no entry computation named %s.",
- entry_computation_name.c_str());
- }
- return entry_computation;
-}
-
-// Returns the root instruction of the given computation proto.
-StatusOr<const HloInstructionProto*> GetRootInstruction(
- const HloComputationProto& computation) {
- if (computation.root_name().empty()) {
- return InvalidArgument("Missing root instruction name.");
- }
-
- const HloInstructionProto* root = nullptr;
- for (const HloInstructionProto& instruction : computation.instructions()) {
- if (instruction.name() == computation.root_name()) {
- if (root == nullptr) {
- root = &instruction;
- } else {
- return InvalidArgument(
- "Computation has multiple instructions named %s.",
- computation.root_name().c_str());
- }
- }
- }
- if (root == nullptr) {
- return InvalidArgument("Computation has no instruction named %s.",
- computation.root_name().c_str());
- }
- return root;
-}
-
-// Returns the parameters of the given computation. Parameter numbers are
-// checked for validity and contiguousness.
-StatusOr<std::vector<const HloInstructionProto*>> GetParameters(
- const HloComputationProto& computation) {
- std::vector<const HloInstructionProto*> parameters;
- for (const HloInstructionProto& instruction : computation.instructions()) {
- if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
- parameters.push_back(&instruction);
- }
- }
-
- // Verify the uniqueness and validity of the parameter numbers.
- tensorflow::gtl::FlatSet<int64> parameter_numbers;
- for (const HloInstructionProto* parameter : parameters) {
- if (parameter->parameter_number() < 0 ||
- parameter->parameter_number() >= parameters.size()) {
- return InvalidArgument(
- "Parameter instruction %s has invalid parameter number %lld.",
- parameter->name().c_str(), parameter->parameter_number());
- }
- if (parameter_numbers.count(parameter->parameter_number()) != 0) {
- return InvalidArgument(
- "Multiple parameter instructions have parameter number %lld.",
- parameter->parameter_number());
- }
- parameter_numbers.insert(parameter->parameter_number());
- }
-
- std::sort(parameters.begin(), parameters.end(),
- [](const HloInstructionProto* a, const HloInstructionProto* b) {
- return a->parameter_number() < b->parameter_number();
- });
-
- return parameters;
-}
-
-} // namespace
-
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
HloOrderingProto proto_ordering =
@@ -141,33 +41,33 @@ HloProto MakeHloProto(const HloModule& module) {
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
const HloProto& hlo_proto) {
- TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation,
- GetEntryComputation(hlo_proto));
- TF_ASSIGN_OR_RETURN(std::vector<const HloInstructionProto*> parameters,
- GetParameters(*entry_computation));
+ if (!hlo_proto.has_hlo_module()) {
+ return NotFound("HloProto missing HloModuleProto.");
+ }
+ if (!hlo_proto.hlo_module().has_program_shape()) {
+ return NotFound("HloProto missing program shape.");
+ }
+
std::vector<const Shape*> parameter_shapes;
- for (const HloInstructionProto* parameter : parameters) {
- if (!parameter->has_shape()) {
- return InvalidArgument("Parameter instruction %s is missing shape.",
- parameter->name().c_str());
- }
- parameter_shapes.push_back(&parameter->shape());
+ const auto& program_shape = hlo_proto.hlo_module().program_shape();
+ for (const Shape& shape : program_shape.parameters()) {
+ parameter_shapes.push_back(&shape);
}
return parameter_shapes;
}
StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) {
- TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation,
- GetEntryComputation(hlo_proto));
-
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
- GetRootInstruction(*entry_computation));
- if (!root->has_shape()) {
- return InvalidArgument("Instruction %s is missing shape.",
- root->name().c_str());
+ if (!hlo_proto.has_hlo_module()) {
+ return NotFound("HloProto missing HloModuleProto.");
+ }
+ if (!hlo_proto.hlo_module().has_program_shape()) {
+ return NotFound("HloProto missing program shape.");
+ }
+ if (!hlo_proto.hlo_module().program_shape().has_result()) {
+ return NotFound("HloProto missing result in its program shape");
}
- return &root->shape();
+ return &hlo_proto.hlo_module().program_shape().result();
}
} // namespace xla