diff options
author | 2018-03-06 16:29:33 -0800 | |
---|---|---|
committer | 2018-03-06 16:38:49 -0800 | |
commit | 75e15a2b25f731d7ddf4ffc455a4bf8d1c0fd7ca (patch) | |
tree | b6511c78de81d1c0856c113ce790f06ffe5082ec /tensorflow/compiler/xla | |
parent | 721a60801055190dae18fe3e3933950c75fa9d1c (diff) |
[XLA] Store the program shape in the HloModuleProto and HloComputationProto.
PiperOrigin-RevId: 188100425
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module.cc | 68 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util.cc | 138 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util_test.cc | 114 |
6 files changed, 39 insertions, 291 deletions
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index a43785b4a9..66fd317051 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -145,6 +145,9 @@ message HloComputationProto { // The name of the root of the computation. string root_name = 3; + + // The program shape (with layout) of this computation. + xla.ProgramShape program_shape = 4; } // Serialization of HloModule. @@ -155,6 +158,9 @@ message HloModuleProto { // The array of computations is always in a valid dependency order, where // callees appear before their callers. repeated HloComputationProto computations = 3; + + // The program shape (with layout) of the entry computation. + xla.ProgramShape program_shape = 4; } // Serialization of HloOrdering. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 21e6b2ca73..f99c7cf5e4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -399,6 +399,7 @@ HloComputationProto HloComputation::ToProto() const { proto.add_instructions()->Swap(&instruction_proto); } proto.set_root_name(root_instruction()->name()); + *proto.mutable_program_shape() = ComputeProgramShape(); return proto; } @@ -532,7 +533,6 @@ ProgramShape HloComputation::ComputeProgramShape() const { } *program_shape.mutable_result() = root_instruction_->shape(); - LayoutUtil::ClearLayout(&program_shape); return program_shape; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 39d864efcb..dd9d346999 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -248,7 +248,7 @@ class HloComputation { ShapeTree<HloInstruction*>* copies_added = nullptr); // Computes and returns the ProgramShape of this computation (shape of - // parameters and result without layout). + // parameters and result with layout). ProgramShape ComputeProgramShape() const; // Return whether `*this` and `other` are functionally equivalent. diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index cb2fe9f874..cdea3d5978 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -213,74 +213,23 @@ HloModuleProto HloModule::ToProto() const { continue; } HloComputationProto computation_proto = computation->ToProto(); + if (computation->name() == entry_computation_->name()) { + *proto.mutable_program_shape() = computation_proto.program_shape(); + } proto.add_computations()->Swap(&computation_proto); } return proto; } -namespace { - -// Construct a ProgramShape matching the shape of the parameters and root of the -// given module's entry computation. -StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) { - const HloComputationProto* entry_computation = nullptr; - for (const HloComputationProto& computation : module.computations()) { - if (computation.name() == module.entry_computation_name()) { - entry_computation = &computation; - break; - } - } - TF_RET_CHECK(entry_computation != nullptr) - << "No computation with entry computation name" - << module.entry_computation_name(); - - tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters; - const HloInstructionProto* root = nullptr; - for (const HloInstructionProto& instruction : - entry_computation->instructions()) { - if (instruction.name() == entry_computation->root_name()) { - TF_RET_CHECK(root == nullptr) << "Entry computation has more than " - "one instruction with (root) name " - << instruction.name(); - root = &instruction; - } - if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { - TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number())) - << "Entry computation has more than one parameter instruction " - "with parameter number " - << instruction.parameter_number(); - parameters[instruction.parameter_number()] = {instruction.name(), - &instruction.shape()}; - } - } - TF_RET_CHECK(root != nullptr) - << "Entry computation is missing root instruction named " - << entry_computation->root_name(); - - ProgramShape program_shape; - *program_shape.mutable_result() = root->shape(); - for (int64 i = 0; i < parameters.size(); ++i) { - TF_RET_CHECK(ContainsKey(parameters, i)) - << "Entry computation missing parameter number " << i; - const string& name = parameters.at(i).first; - const Shape& shape = *parameters.at(i).second; - *program_shape.add_parameters() = shape; - program_shape.add_parameter_names(name); - } - - return std::move(program_shape); -} - -} // namespace - /* static */ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( const HloModuleProto& proto, const HloModuleConfig& module_config, const VersionedComputationHandle& entry_computation_handle) { // The ProgramShape in the passed in module config must match the shapes of // the entry parameters and root. - TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape, - ProgramShapeFromProto(proto)); + TF_RET_CHECK(proto.has_program_shape()) + << "No program shape found in the proto"; + const auto& expected_program_shape = proto.program_shape(); TF_RET_CHECK(expected_program_shape.parameters_size() == module_config.entry_computation_layout().parameter_count()); for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { @@ -354,8 +303,9 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto( /* static */ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto( const HloModuleProto& module) { - TF_ASSIGN_OR_RETURN(ProgramShape program_shape, - ProgramShapeFromProto(module)); + TF_RET_CHECK(module.has_program_shape()) + << "No program shape found in the proto"; + const auto& program_shape = module.program_shape(); HloModuleConfig module_config(program_shape); diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index f75c452082..3460679558 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -21,106 +21,6 @@ limitations under the License. namespace xla { -namespace { - -// Returns the entry computation of the HLO module in the given HloProto. -StatusOr<const HloComputationProto*> GetEntryComputation( - const HloProto& hlo_proto) { - if (!hlo_proto.has_hlo_module()) { - return NotFound("HloProto missing HloModuleProto."); - } - - if (hlo_proto.hlo_module().entry_computation_name().empty()) { - return NotFound("HloProto has empty entry computation name."); - } - - const string& entry_computation_name = - hlo_proto.hlo_module().entry_computation_name(); - const HloComputationProto* entry_computation = nullptr; - for (const HloComputationProto& computation : - hlo_proto.hlo_module().computations()) { - if (computation.name() == entry_computation_name) { - if (entry_computation == nullptr) { - entry_computation = &computation; - } else { - return InvalidArgument( - "HloProto has multiple computations with entry computation named " - "%s.", - entry_computation_name.c_str()); - } - } - } - if (entry_computation == nullptr) { - return InvalidArgument("HloProto has no entry computation named %s.", - entry_computation_name.c_str()); - } - return entry_computation; -} - -// Returns the root instruction of the given computation proto. -StatusOr<const HloInstructionProto*> GetRootInstruction( - const HloComputationProto& computation) { - if (computation.root_name().empty()) { - return InvalidArgument("Missing root instruction name."); - } - - const HloInstructionProto* root = nullptr; - for (const HloInstructionProto& instruction : computation.instructions()) { - if (instruction.name() == computation.root_name()) { - if (root == nullptr) { - root = &instruction; - } else { - return InvalidArgument( - "Computation has multiple instructions named %s.", - computation.root_name().c_str()); - } - } - } - if (root == nullptr) { - return InvalidArgument("Computation has no instruction named %s.", - computation.root_name().c_str()); - } - return root; -} - -// Returns the parameters of the given computation. Parameter numbers are -// checked for validity and contiguousness. -StatusOr<std::vector<const HloInstructionProto*>> GetParameters( - const HloComputationProto& computation) { - std::vector<const HloInstructionProto*> parameters; - for (const HloInstructionProto& instruction : computation.instructions()) { - if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { - parameters.push_back(&instruction); - } - } - - // Verify the uniqueness and validity of the parameter numbers. - tensorflow::gtl::FlatSet<int64> parameter_numbers; - for (const HloInstructionProto* parameter : parameters) { - if (parameter->parameter_number() < 0 || - parameter->parameter_number() >= parameters.size()) { - return InvalidArgument( - "Parameter instruction %s has invalid parameter number %lld.", - parameter->name().c_str(), parameter->parameter_number()); - } - if (parameter_numbers.count(parameter->parameter_number()) != 0) { - return InvalidArgument( - "Multiple parameter instructions have parameter number %lld.", - parameter->parameter_number()); - } - parameter_numbers.insert(parameter->parameter_number()); - } - - std::sort(parameters.begin(), parameters.end(), - [](const HloInstructionProto* a, const HloInstructionProto* b) { - return a->parameter_number() < b->parameter_number(); - }); - - return parameters; -} - -} // namespace - HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { HloOrderingProto proto_ordering = @@ -141,33 +41,33 @@ HloProto MakeHloProto(const HloModule& module) { StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( const HloProto& hlo_proto) { - TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation, - GetEntryComputation(hlo_proto)); - TF_ASSIGN_OR_RETURN(std::vector<const HloInstructionProto*> parameters, - GetParameters(*entry_computation)); + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + std::vector<const Shape*> parameter_shapes; - for (const HloInstructionProto* parameter : parameters) { - if (!parameter->has_shape()) { - return InvalidArgument("Parameter instruction %s is missing shape.", - parameter->name().c_str()); - } - parameter_shapes.push_back(¶meter->shape()); + const auto& program_shape = hlo_proto.hlo_module().program_shape(); + for (const Shape& shape : program_shape.parameters()) { + parameter_shapes.push_back(&shape); } return parameter_shapes; } StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) { - TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation, - GetEntryComputation(hlo_proto)); - - TF_ASSIGN_OR_RETURN(const HloInstructionProto* root, - GetRootInstruction(*entry_computation)); - if (!root->has_shape()) { - return InvalidArgument("Instruction %s is missing shape.", - root->name().c_str()); + if (!hlo_proto.has_hlo_module()) { + return NotFound("HloProto missing HloModuleProto."); + } + if (!hlo_proto.hlo_module().has_program_shape()) { + return NotFound("HloProto missing program shape."); + } + if (!hlo_proto.hlo_module().program_shape().has_result()) { + return NotFound("HloProto missing result in its program shape"); } - return &root->shape(); + return &hlo_proto.hlo_module().program_shape().result(); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc index 0c0abf10fa..b9cca13870 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc @@ -29,69 +29,6 @@ namespace { class HloProtoUtilTest : public ::testing::Test {}; -TEST_F(HloProtoUtilTest, ParamsAndOutputShape) { - HloProto hlo_proto; - HloModuleProto* module = hlo_proto.mutable_hlo_module(); - module->set_entry_computation_name("entry"); - HloComputationProto* computation = module->add_computations(); - computation->set_name("entry"); - computation->set_root_name("root"); - - HloInstructionProto* param0 = computation->add_instructions(); - param0->set_opcode(HloOpcodeString(HloOpcode::kParameter)); - param0->set_parameter_number(0); - *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42}); - - HloInstructionProto* param2 = computation->add_instructions(); - param2->set_opcode(HloOpcodeString(HloOpcode::kParameter)); - param2->set_parameter_number(2); - *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3}); - - HloInstructionProto* param1 = computation->add_instructions(); - param1->set_opcode(HloOpcodeString(HloOpcode::kParameter)); - param1->set_parameter_number(1); - *param1->mutable_shape() = ShapeUtil::MakeShape(F64, {}); - - HloInstructionProto* root = computation->add_instructions(); - root->set_opcode(HloOpcodeString(HloOpcode::kAdd)); - root->set_name("root"); - *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2}); - - VLOG(1) << hlo_proto.DebugString(); - - TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes, - EntryComputationParameterShapes(hlo_proto)); - ASSERT_EQ(parameter_shapes.size(), 3); - EXPECT_TRUE( - ShapeUtil::Equal(*parameter_shapes[0], ShapeUtil::MakeShape(F32, {42}))); - EXPECT_TRUE( - ShapeUtil::Equal(*parameter_shapes[1], ShapeUtil::MakeShape(F64, {}))); - EXPECT_TRUE(ShapeUtil::Equal(*parameter_shapes[2], - ShapeUtil::MakeShape(S32, {1, 2, 3}))); - - TF_ASSERT_OK_AND_ASSIGN(const Shape* output_shape, - EntryComputationOutputShape(hlo_proto)); - EXPECT_TRUE(ShapeUtil::Equal(*output_shape, ShapeUtil::MakeShape(U8, {2}))); -} - -TEST_F(HloProtoUtilTest, ParamsAndOutputShapeNoParameters) { - HloProto hlo_proto; - HloModuleProto* module = hlo_proto.mutable_hlo_module(); - module->set_entry_computation_name("entry"); - HloComputationProto* computation = module->add_computations(); - computation->set_name("entry"); - computation->set_root_name("root"); - - HloInstructionProto* root = computation->add_instructions(); - root->set_opcode(HloOpcodeString(HloOpcode::kAdd)); - root->set_name("root"); - *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2}); - - TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes, - EntryComputationParameterShapes(hlo_proto)); - ASSERT_EQ(parameter_shapes.size(), 0); -} - TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) { HloProto hlo_proto; @@ -101,60 +38,15 @@ TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) { ::testing::HasSubstr("missing HloModuleProto")); } -TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingEntryComputation) { +TEST_F(HloProtoUtilTest, MissingProgramShape) { HloProto hlo_proto; HloModuleProto* module = hlo_proto.mutable_hlo_module(); - module->set_entry_computation_name("entry"); - HloComputationProto* computation = module->add_computations(); - computation->set_name("not_entry"); - - auto status = EntryComputationParameterShapes(hlo_proto).status(); - ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), - ::testing::HasSubstr("has no entry computation named")); -} - -TEST_F(HloProtoUtilTest, OutputShapeMissingEntryRoot) { - HloProto hlo_proto; - HloModuleProto* module = hlo_proto.mutable_hlo_module(); - module->set_entry_computation_name("entry"); - HloComputationProto* computation = module->add_computations(); - computation->set_name("entry"); - computation->set_root_name("root"); - - auto status = EntryComputationOutputShape(hlo_proto).status(); - ASSERT_FALSE(status.ok()); - ASSERT_THAT(status.error_message(), - ::testing::HasSubstr("has no instruction named")); -} - -TEST_F(HloProtoUtilTest, ParamsShapesMissingParameterNumbers) { - HloProto hlo_proto; - HloModuleProto* module = hlo_proto.mutable_hlo_module(); - module->set_entry_computation_name("entry"); - HloComputationProto* computation = module->add_computations(); - computation->set_name("entry"); - computation->set_root_name("root"); - - HloInstructionProto* param0 = computation->add_instructions(); - param0->set_opcode(HloOpcodeString(HloOpcode::kParameter)); - param0->set_parameter_number(0); - *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42}); - - HloInstructionProto* param2 = computation->add_instructions(); - param2->set_opcode(HloOpcodeString(HloOpcode::kParameter)); - param2->set_parameter_number(2); - *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3}); - - HloInstructionProto* root = computation->add_instructions(); - root->set_opcode(HloOpcodeString(HloOpcode::kAdd)); - root->set_name("root"); - *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2}); + module->set_name("entry"); auto status = EntryComputationParameterShapes(hlo_proto).status(); ASSERT_FALSE(status.ok()); ASSERT_THAT(status.error_message(), - ::testing::HasSubstr("invalid parameter number")); + ::testing::HasSubstr("missing program shape")); } } // namespace |