From bea7cabfd0c32638d4c102c636270cbfecc6665b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 7 Apr 2017 00:46:42 -0800 Subject: Add a tool that converts HLO computations to tensorflow GraphDef which can be visualized on Tensorboard. This CL defines basic tensorflow::OpDef for each HLO instruction/node. More attributes (e.g. shapes, colors) will be added in the future. Change: 152477918 --- tensorflow/compiler/xla/service/hlo_opcode.cc | 13 ++ tensorflow/compiler/xla/service/hlo_opcode.h | 3 + tensorflow/compiler/xla/tools/BUILD | 44 +++++++ .../xla/tools/dumped_computation_to_tf_graphdef.cc | 139 +++++++++++++++++++++ .../compiler/xla/tools/hlo_tfgraph_builder.cc | 129 +++++++++++++++++++ .../compiler/xla/tools/hlo_tfgraph_builder.h | 59 +++++++++ .../compiler/xla/tools/hlo_tfgraph_builder_test.cc | 107 ++++++++++++++++ 7 files changed, 494 insertions(+) create mode 100644 tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc create mode 100644 tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc create mode 100644 tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h create mode 100644 tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 616b239a93..ceb0cdaa31 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -165,4 +165,17 @@ bool HloOpcodeIsComparison(HloOpcode opcode) { } } +bool HloOpcodeIsVariadic(HloOpcode opcode) { + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + return true; + default: + return false; + } +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 978ed5e79b..e2cdbfdfa7 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -104,6 +104,9 @@ inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { // Returns true iff the given opcode is a comparison operation. bool HloOpcodeIsComparison(HloOpcode opcode); +// Returns true iff the given opcode has variadic operands. +bool HloOpcodeIsVariadic(HloOpcode opcode); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_ diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 46eab7f02b..1c4ca44631 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -176,6 +176,50 @@ cc_binary( ], ) +cc_library( + name = "hlo_tfgraph_builder", + srcs = ["hlo_tfgraph_builder.cc"], + hdrs = ["hlo_tfgraph_builder.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "hlo_tfgraph_builder_test", + srcs = ["hlo_tfgraph_builder_test.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_binary( + name = "dumped_computation_to_tf_graphdef", + srcs = ["dumped_computation_to_tf_graphdef.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc new file mode 100644 index 0000000000..1aa769ee5a --- /dev/null +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: dumped_computation_to_tf_graph \ +// --output_dir=/tmp/graphs/ some_binary_snapshot_proto* +// +// Dumps a tensorflow GraphDef in text format for a snapshot computation. The +// dumped graph is an HLO computation with HLO instructions as nodes and can be +// visualized on Tensorboard. Upload the dumped files on Tensorboard. +// +// some_binary_snapshot_proto is obtained by serializing the SessionModule from +// ServiceInterface::SnapshotComputation to disk. + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/service.h" +#include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +using tensorflow::Env; +using tensorflow::io::JoinPath; +using tensorflow::strings::StrAppend; + +namespace xla { +namespace tools { +namespace { + +// Dumps all computations in the module to the given directory. +void DumpTfGraph(const HloModule& module, const string& directory_path) { + Env* env = Env::Default(); + TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); + string fname = module.name(); + std::replace(fname.begin(), fname.end(), '/', '_'); + // Since the file name will be used as the top-level scope name, clean it up + // to make it a valid scope name. + CleanNodeName(&fname); + StrAppend(&fname, ".pbtxt"); + string path = JoinPath(directory_path, fname); + HloTfGraphBuilder builder; + TF_CHECK_OK(builder.AddComputation(*module.entry_computation())); + std::cout << "Dumping " << module.name() << " to " << path << std::endl; + TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef())); +} + +} // namespace + +void RealMain(tensorflow::gtl::ArraySlice args, + const string& output_dir) { + LocalClient* client = ClientLibrary::LocalClientOrDie(); + // To avoid adding a new flag, use local service and lower the computations + // locally. + LocalService* local_service = + ClientLibrary::GetXlaService(client->platform()); + // Build HloModule for each Computation and dump to file. + for (char* arg : args) { + SessionModule session_module; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, + &session_module)); + auto computation_status = client->LoadSnapshot(session_module); + if (!computation_status.ok()) { + fprintf(stderr, "could not load snapshot for %s: %s\n", arg, + computation_status.status().ToString().c_str()); + continue; + } + Computation computation = computation_status.ConsumeValueOrDie(); + + StatusOr user_computation_status = + local_service->computation_tracker().Resolve(computation.handle()); + if (!user_computation_status.ok()) { + fprintf(stderr, + "failed to resolve computation to UserComputation %s: %s\n", arg, + user_computation_status.status().ToString().c_str()); + continue; + } + + auto* user_computation = user_computation_status.ValueOrDie(); + StatusOr> module_status = + local_service->computation_tracker().BuildHloModule( + user_computation->GetVersionedHandle()); + + if (!module_status.ok()) { + fprintf(stderr, "failed to build HloModule %s: %s\n", arg, + module_status.status().ToString().c_str()); + continue; + } + + DumpTfGraph(*module_status.ValueOrDie(), output_dir); + } +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + string output_dir = ""; + const std::vector flag_list = { + tensorflow::Flag("output_dir", &output_dir, + "Directory to write GraphDef data to."), + }; + + string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_ok || output_dir.empty()) { + LOG(QFATAL) << usage; + } + tensorflow::port::InitMain(argv[0], &argc, &argv); + + tensorflow::gtl::ArraySlice args(argv, argc); + args.pop_front(); // Pop off the binary name, argv[0] + xla::tools::RealMain(args, output_dir); + return 0; +} diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc new file mode 100644 index 0000000000..8e04ea3ae4 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +LIcensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/strings/strcat.h" + +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + +namespace xla { +namespace tools { + +static string GetOpDefName(const HloInstruction* instruction) { + string name = StrCat("hlo-", HloOpcodeString(instruction->opcode())); + tensorflow::str_util::TitlecaseString(&name, "-"); + name.erase(std::remove(name.begin(), name.end(), '-'), name.end()); + + if (instruction->opcode() == HloOpcode::kFusion) { + string fusion_name = ToString(instruction->fusion_kind()); + StrAppend(&name, tensorflow::StringPiece(fusion_name).substr(1)); + } + return name; +} + +void CleanNodeName(string* name) { + name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); + const string chars_to_replace = "<>[]"; + auto pred = [&](char c) { + return std::find(chars_to_replace.begin(), chars_to_replace.end(), c) != + chars_to_replace.end(); + }; + std::replace_if(name->begin(), name->end(), pred, '_'); +} + +Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { + LOG(INFO) << "Adding computation " << computation.name(); + for (auto embedded : computation.MakeEmbeddedComputationsList()) { + LOG(INFO) << "Adding embedded computation " << embedded->name(); + for (auto& instruction : embedded->instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + } + } + for (auto& instruction : computation.instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + } + return Status::OK(); +} + +const GraphDef& HloTfGraphBuilder::GetGraphDef() const { return graph_def_; } + +const string& HloTfGraphBuilder::GetNodeNameForInstruction( + const HloInstruction* instruction) { + if (ContainsKey(instruction_to_node_name_, instruction)) { + return instruction_to_node_name_[instruction]; + } + // If an instruction is fused, put it in the subgraph of the fusion; + // otherwise, put it in the computation subgraph. + string node_name = + instruction->IsFused() + ? GetNodeNameForInstruction(instruction->fusion_instruction()) + : instruction->parent()->name(); + string instruction_name = instruction->name(); + if (instruction->opcode() == HloOpcode::kParameter) { + StrAppend(&instruction_name, ".", instruction->parameter_number()); + } + StrAppend(&node_name, "/", instruction_name); + CleanNodeName(&node_name); + auto ret = + instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); + CHECK(ret.second); + return ret.first->second; +} + +// TODO(b/36987876): Add more attribute information e.g. shapes, dimensions etc. +void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, + NodeDef* node_def) const { + // Set the number of arguments for instructions that have variadic operands. + if (HloOpcodeIsVariadic(instruction->opcode())) { + tensorflow::AttrValue attr_value; + attr_value.set_i(instruction->operands().size()); + (*node_def->mutable_attr())["ArgNum"] = attr_value; + } +} + +Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { + if (!visited_instructions_.insert(instruction).second) { + // Skip instructions that have already been added. + return Status::OK(); + } + + NodeDef* node_def = graph_def_.add_node(); + node_def->set_name(GetNodeNameForInstruction(instruction)); + node_def->set_op(GetOpDefName(instruction)); + SetNodeAttrs(instruction, node_def); + if (instruction->opcode() == HloOpcode::kFusion) { + for (auto& fused_instruction : instruction->fused_instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get())); + } + } + // Add all edges including control edges. + for (unsigned i = 0; i < instruction->operands().size(); ++i) { + *node_def->add_input() = GetNodeNameForInstruction(instruction->operand(i)); + } + // Called computations are control dependencies. + for (const auto* called_computation : instruction->called_computations()) { + *node_def->add_input() = StrCat( + "^", GetNodeNameForInstruction(called_computation->root_instruction())); + } + return Status::OK(); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h new file mode 100644 index 0000000000..3052eae113 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" + +namespace xla { +namespace tools { + +// This constructs a tensorflow graph for HLO computations. +class HloTfGraphBuilder { + public: + // Adds a computation to the graph. + Status AddComputation(const HloComputation& computation); + + const tensorflow::GraphDef& GetGraphDef() const; + + private: + // Gets the node name of an instruction. The node name is hierarchical. For + // example, if an instruction is fused, it will be put in a subgraph of the + // fusion instruction. + const string& GetNodeNameForInstruction(const HloInstruction* instruction); + + void SetNodeAttrs(const HloInstruction* instruction, + tensorflow::NodeDef* node_def) const; + + Status AddInstruction(const HloInstruction* instruction); + + tensorflow::GraphDef graph_def_; + // This records instructions that have been visited. + std::unordered_set visited_instructions_; + // A cache that maps instruction to the node name. + std::unordered_map instruction_to_node_name_; +}; + +// Cleans the node name to make it a valid name in a tensorflow graph. +void CleanNodeName(string* name); + +} // namespace tools +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc new file mode 100644 index 0000000000..57144d0385 --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace tools { +namespace { + +using ::tensorflow::GraphDef; + +class HloTfGraphBuilderTest : public HloTestBase { + protected: + HloTfGraphBuilderTest() {} + HloTfGraphBuilder generator_; + + // Create a computation which takes a scalar and returns its negation. + std::unique_ptr CreateNegateComputation() { + auto builder = HloComputation::Builder("Negate"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); + return builder.Build(); + } + + // Creates a computation which calls map with the given computation. + std::unique_ptr CreateMapComputation( + HloComputation* map_computation) { + auto builder = HloComputation::Builder("Map"); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map_computation)); + return builder.Build(); + } + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(HloTfGraphBuilderTest, SimpleNegateComputation) { + auto negate_computation = CreateNegateComputation(); + TF_CHECK_OK(generator_.AddComputation(*negate_computation)); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 2); + EXPECT_EQ(graph_def.node(0).name(), "Negate/param0.0"); + EXPECT_EQ(graph_def.node(0).op(), "HloParameter"); + EXPECT_EQ(graph_def.node(1).name(), "Negate/negate"); + EXPECT_EQ(graph_def.node(1).op(), "HloNegate"); + EXPECT_EQ(graph_def.node(1).input_size(), 1); + EXPECT_EQ(graph_def.node(1).input(0), "Negate/param0.0"); +} + +TEST_F(HloTfGraphBuilderTest, GreaterThanOrEqualTo) { + auto builder = HloComputation::Builder("GE"); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32_, "param1")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kGe, param_1, param_2)); + TF_CHECK_OK(generator_.AddComputation(*builder.Build())); + GraphDef graph_def = generator_.GetGraphDef(); + EXPECT_EQ(graph_def.node_size(), 3); + EXPECT_EQ(graph_def.node(0).name(), "GE/param0.0"); + EXPECT_EQ(graph_def.node(1).name(), "GE/param1.1"); + EXPECT_EQ(graph_def.node(2).input_size(), 2); + EXPECT_EQ(graph_def.node(2).name(), "GE/greater-than-or-equal-to"); + EXPECT_EQ(graph_def.node(2).op(), "HloGreaterThanOrEqualTo"); +} + +TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { + // Create computations with a diamond-shaped callgraph. + auto negate_computation = CreateNegateComputation(); + auto map1_computation = CreateMapComputation(negate_computation.get()); + auto map2_computation = CreateMapComputation(negate_computation.get()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32_, "param0")); + auto map1 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); + auto map2 = builder.AddInstruction( + HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); + auto computation = builder.Build(); + TF_CHECK_OK(generator_.AddComputation(*computation)); + EXPECT_GT(generator_.GetGraphDef().node_size(), 0); +} + +} // namespace +} // namespace tools +} // namespace xla -- cgit v1.2.3