author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-06 16:29:33 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-03-06 16:38:49 -0800
commit    75e15a2b25f731d7ddf4ffc455a4bf8d1c0fd7ca (patch)
tree      b6511c78de81d1c0856c113ce790f06ffe5082ec /tensorflow/compiler/xla
parent    721a60801055190dae18fe3e3933950c75fa9d1c (diff)
[XLA] Store the program shape in the HloModuleProto and HloComputationProto.
PiperOrigin-RevId: 188100425
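
The net effect: a computation's ProgramShape (its parameter and result shapes, with layouts) now travels inside the serialized proto, so consumers read it directly instead of reconstructing it by scanning parameter and root instructions. A minimal sketch of the new access pattern, assuming module_proto is a deserialized HloModuleProto (the accessors are the generated proto API used in this change):

    // Read the stored program shape instead of rebuilding it.
    if (module_proto.has_program_shape()) {
      const xla::ProgramShape& shape = module_proto.program_shape();
      // shape.parameters() and shape.result() carry layouts.
    }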
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto               |   6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc      |   2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h       |   2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc           |  68
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util.cc       | 138
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util_test.cc  | 114
6 files changed, 39 insertions, 291 deletions
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index a43785b4a9..66fd317051 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -145,6 +145,9 @@ message HloComputationProto {
// The name of the root of the computation.
string root_name = 3;
+
+ // The program shape (with layout) of this computation.
+ xla.ProgramShape program_shape = 4;
}
// Serialization of HloModule.
@@ -155,6 +158,9 @@ message HloModuleProto {
// The array of computations is always in a valid dependency order, where
// callees appear before their callers.
repeated HloComputationProto computations = 3;
+
+ // The program shape (with layout) of the entry computation.
+ xla.ProgramShape program_shape = 4;
}
// Serialization of HloOrdering.
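
For reference, a ProgramShape aggregates parameter shapes, parallel parameter names, and a result shape. The sketch below is inferred from the accessors used elsewhere in this change; some_shape and result_shape are hypothetical placeholders:

    // Inferred structure of xla.ProgramShape (placeholders are hypothetical).
    xla::ProgramShape ps;
    *ps.add_parameters() = some_shape;    // parameter shapes, in order
    ps.add_parameter_names("p0");         // names parallel to parameters
    *ps.mutable_result() = result_shape;  // shape of the root/result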
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 21e6b2ca73..f99c7cf5e4 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -399,6 +399,7 @@ HloComputationProto HloComputation::ToProto() const {
proto.add_instructions()->Swap(&instruction_proto);
}
proto.set_root_name(root_instruction()->name());
+ *proto.mutable_program_shape() = ComputeProgramShape();
return proto;
}
@@ -532,7 +533,6 @@ ProgramShape HloComputation::ComputeProgramShape() const {
}
*program_shape.mutable_result() = root_instruction_->shape();
- LayoutUtil::ClearLayout(&program_shape);
return program_shape;
}
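
Note the second hunk: ComputeProgramShape previously called LayoutUtil::ClearLayout on its result, so the returned shapes were layout-free. Dropping that call is what makes the stored program shape "with layout". A before/after sketch (comments only; no new API assumed):

    ProgramShape ps = computation->ComputeProgramShape();
    // Before this change: LayoutUtil::ClearLayout(&ps) had stripped all layouts.
    // After: ps preserves the layouts of the parameter and root shapes.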
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 39d864efcb..dd9d346999 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -248,7 +248,7 @@ class HloComputation {
ShapeTree<HloInstruction*>* copies_added = nullptr);
// Computes and returns the ProgramShape of this computation (shape of
- // parameters and result without layout).
+ // parameters and result with layout).
ProgramShape ComputeProgramShape() const;
// Return whether `*this` and `other` are functionally equivalent.
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index cb2fe9f874..cdea3d5978 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -213,74 +213,23 @@ HloModuleProto HloModule::ToProto() const {
continue;
}
HloComputationProto computation_proto = computation->ToProto();
+ if (computation->name() == entry_computation_->name()) {
+ *proto.mutable_program_shape() = computation_proto.program_shape();
+ }
proto.add_computations()->Swap(&computation_proto);
}
return proto;
}
-namespace {
-
-// Construct a ProgramShape matching the shape of the parameters and root of the
-// given module's entry computation.
-StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) {
- const HloComputationProto* entry_computation = nullptr;
- for (const HloComputationProto& computation : module.computations()) {
- if (computation.name() == module.entry_computation_name()) {
- entry_computation = &computation;
- break;
- }
- }
- TF_RET_CHECK(entry_computation != nullptr)
- << "No computation with entry computation name"
- << module.entry_computation_name();
-
- tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters;
- const HloInstructionProto* root = nullptr;
- for (const HloInstructionProto& instruction :
- entry_computation->instructions()) {
- if (instruction.name() == entry_computation->root_name()) {
- TF_RET_CHECK(root == nullptr) << "Entry computation has more than "
- "one instruction with (root) name "
- << instruction.name();
- root = &instruction;
- }
- if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
- TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number()))
- << "Entry computation has more than one parameter instruction "
- "with parameter number "
- << instruction.parameter_number();
- parameters[instruction.parameter_number()] = {instruction.name(),
- &instruction.shape()};
- }
- }
- TF_RET_CHECK(root != nullptr)
- << "Entry computation is missing root instruction named "
- << entry_computation->root_name();
-
- ProgramShape program_shape;
- *program_shape.mutable_result() = root->shape();
- for (int64 i = 0; i < parameters.size(); ++i) {
- TF_RET_CHECK(ContainsKey(parameters, i))
- << "Entry computation missing parameter number " << i;
- const string& name = parameters.at(i).first;
- const Shape& shape = *parameters.at(i).second;
- *program_shape.add_parameters() = shape;
- program_shape.add_parameter_names(name);
- }
-
- return std::move(program_shape);
-}
-
-} // namespace
-
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config,
const VersionedComputationHandle& entry_computation_handle) {
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
- TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape,
- ProgramShapeFromProto(proto));
+ TF_RET_CHECK(proto.has_program_shape())
+ << "No program shape found in the proto";
+ const auto& expected_program_shape = proto.program_shape();
TF_RET_CHECK(expected_program_shape.parameters_size() ==
module_config.entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
@@ -354,8 +303,9 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
const HloModuleProto& module) {
- TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
- ProgramShapeFromProto(module));
+ TF_RET_CHECK(module.has_program_shape())
+ << "No program shape found in the proto";
+ const auto& program_shape = module.program_shape();
HloModuleConfig module_config(program_shape);
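
With the private ProgramShapeFromProto helper deleted, deserialization now trusts the stored shape: both CreateFromProto and CreateModuleConfigFromProto require has_program_shape() and validate against it. A usage sketch, assuming proto, config, and handle are already in scope (their construction is elided):

    // CreateFromProto fails a TF_RET_CHECK if proto.has_program_shape() is false.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                        HloModule::CreateFromProto(proto, config, handle));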
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index f75c452082..3460679558 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -21,106 +21,6 @@ limitations under the License.
namespace xla {
-namespace {
-
-// Returns the entry computation of the HLO module in the given HloProto.
-StatusOr<const HloComputationProto*> GetEntryComputation(
- const HloProto& hlo_proto) {
- if (!hlo_proto.has_hlo_module()) {
- return NotFound("HloProto missing HloModuleProto.");
- }
-
- if (hlo_proto.hlo_module().entry_computation_name().empty()) {
- return NotFound("HloProto has empty entry computation name.");
- }
-
- const string& entry_computation_name =
- hlo_proto.hlo_module().entry_computation_name();
- const HloComputationProto* entry_computation = nullptr;
- for (const HloComputationProto& computation :
- hlo_proto.hlo_module().computations()) {
- if (computation.name() == entry_computation_name) {
- if (entry_computation == nullptr) {
- entry_computation = &computation;
- } else {
- return InvalidArgument(
- "HloProto has multiple computations with entry computation named "
- "%s.",
- entry_computation_name.c_str());
- }
- }
- }
- if (entry_computation == nullptr) {
- return InvalidArgument("HloProto has no entry computation named %s.",
- entry_computation_name.c_str());
- }
- return entry_computation;
-}
-
-// Returns the root instruction of the given computation proto.
-StatusOr<const HloInstructionProto*> GetRootInstruction(
- const HloComputationProto& computation) {
- if (computation.root_name().empty()) {
- return InvalidArgument("Missing root instruction name.");
- }
-
- const HloInstructionProto* root = nullptr;
- for (const HloInstructionProto& instruction : computation.instructions()) {
- if (instruction.name() == computation.root_name()) {
- if (root == nullptr) {
- root = &instruction;
- } else {
- return InvalidArgument(
- "Computation has multiple instructions named %s.",
- computation.root_name().c_str());
- }
- }
- }
- if (root == nullptr) {
- return InvalidArgument("Computation has no instruction named %s.",
- computation.root_name().c_str());
- }
- return root;
-}
-
-// Returns the parameters of the given computation. Parameter numbers are
-// checked for validity and contiguousness.
-StatusOr<std::vector<const HloInstructionProto*>> GetParameters(
- const HloComputationProto& computation) {
- std::vector<const HloInstructionProto*> parameters;
- for (const HloInstructionProto& instruction : computation.instructions()) {
- if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
- parameters.push_back(&instruction);
- }
- }
-
- // Verify the uniqueness and validity of the parameter numbers.
- tensorflow::gtl::FlatSet<int64> parameter_numbers;
- for (const HloInstructionProto* parameter : parameters) {
- if (parameter->parameter_number() < 0 ||
- parameter->parameter_number() >= parameters.size()) {
- return InvalidArgument(
- "Parameter instruction %s has invalid parameter number %lld.",
- parameter->name().c_str(), parameter->parameter_number());
- }
- if (parameter_numbers.count(parameter->parameter_number()) != 0) {
- return InvalidArgument(
- "Multiple parameter instructions have parameter number %lld.",
- parameter->parameter_number());
- }
- parameter_numbers.insert(parameter->parameter_number());
- }
-
- std::sort(parameters.begin(), parameters.end(),
- [](const HloInstructionProto* a, const HloInstructionProto* b) {
- return a->parameter_number() < b->parameter_number();
- });
-
- return parameters;
-}
-
-} // namespace
-
HloProto MakeHloProto(const HloModule& module,
const BufferAssignment& assignment) {
HloOrderingProto proto_ordering =
@@ -141,33 +41,33 @@ HloProto MakeHloProto(const HloModule& module) {
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
const HloProto& hlo_proto) {
- TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation,
- GetEntryComputation(hlo_proto));
- TF_ASSIGN_OR_RETURN(std::vector<const HloInstructionProto*> parameters,
- GetParameters(*entry_computation));
+ if (!hlo_proto.has_hlo_module()) {
+ return NotFound("HloProto missing HloModuleProto.");
+ }
+ if (!hlo_proto.hlo_module().has_program_shape()) {
+ return NotFound("HloProto missing program shape.");
+ }
+
std::vector<const Shape*> parameter_shapes;
- for (const HloInstructionProto* parameter : parameters) {
- if (!parameter->has_shape()) {
- return InvalidArgument("Parameter instruction %s is missing shape.",
- parameter->name().c_str());
- }
- parameter_shapes.push_back(&parameter->shape());
+ const auto& program_shape = hlo_proto.hlo_module().program_shape();
+ for (const Shape& shape : program_shape.parameters()) {
+ parameter_shapes.push_back(&shape);
}
return parameter_shapes;
}
StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) {
- TF_ASSIGN_OR_RETURN(const HloComputationProto* entry_computation,
- GetEntryComputation(hlo_proto));
-
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* root,
- GetRootInstruction(*entry_computation));
- if (!root->has_shape()) {
- return InvalidArgument("Instruction %s is missing shape.",
- root->name().c_str());
+ if (!hlo_proto.has_hlo_module()) {
+ return NotFound("HloProto missing HloModuleProto.");
+ }
+ if (!hlo_proto.hlo_module().has_program_shape()) {
+ return NotFound("HloProto missing program shape.");
+ }
+ if (!hlo_proto.hlo_module().program_shape().has_result()) {
+ return NotFound("HloProto missing result in its program shape");
}
- return &root->shape();
+ return &hlo_proto.hlo_module().program_shape().result();
}
} // namespace xla
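
The two public helpers keep their signatures but now answer from the stored program shape instead of scanning instructions. A usage sketch, assuming hlo_proto was deserialized elsewhere:

    TF_ASSIGN_OR_RETURN(std::vector<const Shape*> param_shapes,
                        EntryComputationParameterShapes(hlo_proto));
    TF_ASSIGN_OR_RETURN(const Shape* result_shape,
                        EntryComputationOutputShape(hlo_proto));
    // Both return NotFound if the module or its program shape is absent.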
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
index 0c0abf10fa..b9cca13870 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util_test.cc
@@ -29,69 +29,6 @@ namespace {
class HloProtoUtilTest : public ::testing::Test {};
-TEST_F(HloProtoUtilTest, ParamsAndOutputShape) {
- HloProto hlo_proto;
- HloModuleProto* module = hlo_proto.mutable_hlo_module();
- module->set_entry_computation_name("entry");
- HloComputationProto* computation = module->add_computations();
- computation->set_name("entry");
- computation->set_root_name("root");
-
- HloInstructionProto* param0 = computation->add_instructions();
- param0->set_opcode(HloOpcodeString(HloOpcode::kParameter));
- param0->set_parameter_number(0);
- *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42});
-
- HloInstructionProto* param2 = computation->add_instructions();
- param2->set_opcode(HloOpcodeString(HloOpcode::kParameter));
- param2->set_parameter_number(2);
- *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3});
-
- HloInstructionProto* param1 = computation->add_instructions();
- param1->set_opcode(HloOpcodeString(HloOpcode::kParameter));
- param1->set_parameter_number(1);
- *param1->mutable_shape() = ShapeUtil::MakeShape(F64, {});
-
- HloInstructionProto* root = computation->add_instructions();
- root->set_opcode(HloOpcodeString(HloOpcode::kAdd));
- root->set_name("root");
- *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2});
-
- VLOG(1) << hlo_proto.DebugString();
-
- TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes,
- EntryComputationParameterShapes(hlo_proto));
- ASSERT_EQ(parameter_shapes.size(), 3);
- EXPECT_TRUE(
- ShapeUtil::Equal(*parameter_shapes[0], ShapeUtil::MakeShape(F32, {42})));
- EXPECT_TRUE(
- ShapeUtil::Equal(*parameter_shapes[1], ShapeUtil::MakeShape(F64, {})));
- EXPECT_TRUE(ShapeUtil::Equal(*parameter_shapes[2],
- ShapeUtil::MakeShape(S32, {1, 2, 3})));
-
- TF_ASSERT_OK_AND_ASSIGN(const Shape* output_shape,
- EntryComputationOutputShape(hlo_proto));
- EXPECT_TRUE(ShapeUtil::Equal(*output_shape, ShapeUtil::MakeShape(U8, {2})));
-}
-
-TEST_F(HloProtoUtilTest, ParamsAndOutputShapeNoParameters) {
- HloProto hlo_proto;
- HloModuleProto* module = hlo_proto.mutable_hlo_module();
- module->set_entry_computation_name("entry");
- HloComputationProto* computation = module->add_computations();
- computation->set_name("entry");
- computation->set_root_name("root");
-
- HloInstructionProto* root = computation->add_instructions();
- root->set_opcode(HloOpcodeString(HloOpcode::kAdd));
- root->set_name("root");
- *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2});
-
- TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes,
- EntryComputationParameterShapes(hlo_proto));
- ASSERT_EQ(parameter_shapes.size(), 0);
-}
-
TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) {
HloProto hlo_proto;
@@ -101,60 +38,15 @@ TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingModule) {
::testing::HasSubstr("missing HloModuleProto"));
}
-TEST_F(HloProtoUtilTest, ParamsAndOutputShapeMissingEntryComputation) {
+TEST_F(HloProtoUtilTest, MissingProgramShape) {
HloProto hlo_proto;
HloModuleProto* module = hlo_proto.mutable_hlo_module();
- module->set_entry_computation_name("entry");
- HloComputationProto* computation = module->add_computations();
- computation->set_name("not_entry");
-
- auto status = EntryComputationParameterShapes(hlo_proto).status();
- ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
- ::testing::HasSubstr("has no entry computation named"));
-}
-
-TEST_F(HloProtoUtilTest, OutputShapeMissingEntryRoot) {
- HloProto hlo_proto;
- HloModuleProto* module = hlo_proto.mutable_hlo_module();
- module->set_entry_computation_name("entry");
- HloComputationProto* computation = module->add_computations();
- computation->set_name("entry");
- computation->set_root_name("root");
-
- auto status = EntryComputationOutputShape(hlo_proto).status();
- ASSERT_FALSE(status.ok());
- ASSERT_THAT(status.error_message(),
- ::testing::HasSubstr("has no instruction named"));
-}
-
-TEST_F(HloProtoUtilTest, ParamsShapesMissingParameterNumbers) {
- HloProto hlo_proto;
- HloModuleProto* module = hlo_proto.mutable_hlo_module();
- module->set_entry_computation_name("entry");
- HloComputationProto* computation = module->add_computations();
- computation->set_name("entry");
- computation->set_root_name("root");
-
- HloInstructionProto* param0 = computation->add_instructions();
- param0->set_opcode(HloOpcodeString(HloOpcode::kParameter));
- param0->set_parameter_number(0);
- *param0->mutable_shape() = ShapeUtil::MakeShape(F32, {42});
-
- HloInstructionProto* param2 = computation->add_instructions();
- param2->set_opcode(HloOpcodeString(HloOpcode::kParameter));
- param2->set_parameter_number(2);
- *param2->mutable_shape() = ShapeUtil::MakeShape(S32, {1, 2, 3});
-
- HloInstructionProto* root = computation->add_instructions();
- root->set_opcode(HloOpcodeString(HloOpcode::kAdd));
- root->set_name("root");
- *root->mutable_shape() = ShapeUtil::MakeShape(U8, {2});
+ module->set_name("entry");
auto status = EntryComputationParameterShapes(hlo_proto).status();
ASSERT_FALSE(status.ok());
ASSERT_THAT(status.error_message(),
- ::testing::HasSubstr("invalid parameter number"));
+ ::testing::HasSubstr("missing program shape"));
}
} // namespace
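
For completeness, a hypothetical success-path test under the new scheme (not part of this change; it populates program_shape directly, using only APIs that appear in this diff):

    // Hypothetical test: entry shapes come straight from program_shape.
    TEST_F(HloProtoUtilTest, ShapesFromProgramShape) {
      HloProto hlo_proto;
      HloModuleProto* module = hlo_proto.mutable_hlo_module();
      xla::ProgramShape* ps = module->mutable_program_shape();
      *ps->add_parameters() = ShapeUtil::MakeShape(F32, {42});
      *ps->mutable_result() = ShapeUtil::MakeShape(U8, {2});

      TF_ASSERT_OK_AND_ASSIGN(std::vector<const Shape*> parameter_shapes,
                              EntryComputationParameterShapes(hlo_proto));
      ASSERT_EQ(parameter_shapes.size(), 1);
      EXPECT_TRUE(ShapeUtil::Equal(*parameter_shapes[0],
                                   ShapeUtil::MakeShape(F32, {42})));

      TF_ASSERT_OK_AND_ASSIGN(const Shape* output_shape,
                              EntryComputationOutputShape(hlo_proto));
      EXPECT_TRUE(ShapeUtil::Equal(*output_shape, ShapeUtil::MakeShape(U8, {2})));
    }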