diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-06 16:29:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-06 16:38:49 -0800 |
commit | 75e15a2b25f731d7ddf4ffc455a4bf8d1c0fd7ca (patch) | |
tree | b6511c78de81d1c0856c113ce790f06ffe5082ec /tensorflow/compiler/xla/service/hlo_proto_util.cc | |
parent | 721a60801055190dae18fe3e3933950c75fa9d1c (diff) |
[XLA] Store the program shape in the HloModuleProto and HloComputationProto.
PiperOrigin-RevId: 188100425
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_proto_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util.cc | 138 |
1 file changed, 19 insertions, 119 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index f75c452082..3460679558 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -21,106 +21,6 @@ limitations under the License. namespace xla { -namespace { - -// Returns the entry computation of the HLO module in the given HloProto. -StatusOr<const HloComputationProto*> GetEntryComputation( - const HloProto& hlo_proto) { - if (!hlo_proto.has_hlo_module()) { - return NotFound("HloProto missing HloModuleProto."); - } - - if (hlo_proto.hlo_module().entry_computation_name().empty()) { - return NotFound("HloProto has empty entry computation name."); - } - - const string& entry_computation_name = - hlo_proto.hlo_module().entry_computation_name(); - const HloComputationProto* entry_computation = nullptr; - for (const HloComputationProto& computation : - hlo_proto.hlo_module().computations()) { - if (computation.name() == entry_computation_name) { - if (entry_computation == nullptr) { - entry_computation = &computation; - } else { - return InvalidArgument( - "HloProto has multiple computations with entry computation named " - "%s.", - entry_computation_name.c_str()); - } - } - } - if (entry_computation == nullptr) { - return InvalidArgument("HloProto has no entry computation named %s.", - entry_computation_name.c_str()); - } - return entry_computation; -} - -// Returns the root instruction of the given computation proto. 
-StatusOr<const HloInstructionProto*> GetRootInstruction( - const HloComputationProto& computation) { - if (computation.root_name().empty()) { - return InvalidArgument("Missing root instruction name."); - } - - const HloInstructionProto* root = nullptr; - for (const HloInstructionProto& instruction : computation.instructions()) { - if (instruction.name() == computation.root_name()) { - if (root == nullptr) { - root = &instruction; - } else { - return InvalidArgument( - "Computation has multiple instructions named %s.", - computation.root_name().c_str()); - } - } - } - if (root == nullptr) { - return InvalidArgument("Computation has no instruction named %s.", - computation.root_name().c_str()); - } - return root; -} - -// Returns the parameters of the given computation. Parameter numbers are -// checked for validity and contiguousness. -StatusOr<std::vector<const HloInstructionProto*>> GetParameters( - const HloComputationProto& computation) { - std::vector<const HloInstructionProto*> parameters; - for (const HloInstructionProto& instruction : computation.instructions()) { - if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { - parameters.push_back(&instruction); - } - } - - // Verify the uniqueness and validity of the parameter numbers. 
- tensorflow::gtl::FlatSet<int64> parameter_numbers; - for (const HloInstructionProto* parameter : parameters) { - if (parameter->parameter_number() < 0 || - parameter->parameter_number() >= parameters.size()) { - return InvalidArgument( - "Parameter instruction %s has invalid parameter number %lld.", - parameter->name().c_str(), parameter->parameter_number()); - } - if (parameter_numbers.count(parameter->parameter_number()) != 0) { - return InvalidArgument( - "Multiple parameter instructions have parameter number %lld.", - parameter->parameter_number()); - } - parameter_numbers.insert(parameter->parameter_number()); - } - - std::sort(parameters.begin(), parameters.end(), - [](const HloInstructionProto* a, const HloInstructionProto* b) { - return a->parameter_number() < b->parameter_number(); - }); - - return parameters; -} - -} // namespace - HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { HloOrderingProto proto_ordering = @@ -141,33 +41,33 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( const HloProto& hlo_proto) { - TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation, - GetEntryComputation(hlo_proto)); - TF_ASSIGN_OR_RETURN(std::vector<const HloInstructionProto*> parameters, - GetParameters(*entry_computation)); + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + std::vector<const Shape*> parameter_shapes; - for (const HloInstructionProto* parameter : parameters) { - if (!parameter->has_shape()) { - return InvalidArgument("Parameter instruction %s is missing shape.", - parameter->name().c_str()); - } - parameter_shapes.push_back(¶meter->shape()); + const auto& program_shape = hlo_proto.hlo_module().program_shape(); + for (const Shape& shape : program_shape.parameters()) { + 
parameter_shapes.push_back(&shape); } return parameter_shapes; } StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) { - TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation, - GetEntryComputation(hlo_proto)); - - TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, - GetRootInstruction(*entry_computation)); - if (!root->has_shape()) { - return InvalidArgument("Instruction %s is missing shape.", - root->name().c_str()); + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + if (!hlo_proto.hlo_module().program_shape().has_result()) { + return NotFound("HloProto missing result in its program shape"); } - return &root->shape(); + return &hlo_proto.hlo_module().program_shape().result(); } } // namespace xla |