diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-04-12 02:48:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-12 04:13:14 -0700 |
commit | 62850f51dd5e978ac243695efab753490a52ca15 (patch) | |
tree | 32fa13ebf92d5d3916e62edda31011212bf05dce | |
parent | b11dff9a7c9571527c50962752456ce9632ebdf3 (diff) |
Support dumping HLO graphs as TF GraphDefs in hlo_graph_dumper
- Added a new --xla_hlo_dump_as_graphdef TF_XLA_FLAGS
- Moved hlo_tfgraph_builder from xla/tools/ to xla/service/
- Refactored GraphRendererInterface a bit to support both dot graph and tf graph.
Change: 152921467
-rw-r--r-- | tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 28 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 44 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_graph_dumper.h | 31 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc (renamed from tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc) | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_tfgraph_builder.h (renamed from tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h) | 11 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc (renamed from tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc) | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/BUILD | 30 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc | 99 |
10 files changed, 126 insertions, 146 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc index 8822f6f610..ba43a59195 100644 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc @@ -36,10 +36,14 @@ static std::once_flag flags_init; static void AllocateFlags() { flags = new HloGraphDumperFlags; flags->xla_hlo_dump_graph_path = "/tmp/"; + flags->xla_hlo_dump_as_graphdef = false; flag_list = new std::vector<tensorflow::Flag>({ tensorflow::Flag("xla_hlo_dump_graph_path", &flags->xla_hlo_dump_graph_path, "Path to write dumped HLO graphs to"), + tensorflow::Flag("xla_hlo_dump_as_graphdef", + &flags->xla_hlo_dump_as_graphdef, + "Dumps HLO graphs as tensorflow GraphDefs"), }); ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h index b6dfced87c..d0b4d092ff 100644 --- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h @@ -34,6 +34,9 @@ void AppendHloGraphDumperFlags(std::vector<tensorflow::Flag>* flag_list); // The values of flags associated with XLA's hlo_graph_dumper module. typedef struct { string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to + // If set, dumps HLO graphs as tensorflow GraphDef; otherwise, dumps HLO + // graphs as DOT graph. + bool xla_hlo_dump_as_graphdef; } HloGraphDumperFlags; // Return a pointer to the HloGraphDumperFlags struct; diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 92e49314d9..c019eff72d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1393,6 +1393,33 @@ cc_test( ) cc_library( + name = "hlo_tfgraph_builder", + srcs = ["hlo_tfgraph_builder.cc"], + hdrs = ["hlo_tfgraph_builder.h"], + visibility = ["//tensorflow/compiler/xla/tools:__pkg__"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "hlo_tfgraph_builder_test", + srcs = ["hlo_tfgraph_builder_test.cc"], + deps = [ + ":hlo_tfgraph_builder", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test_main", + ], +) + +cc_library( name = "hlo_graph_dumper", srcs = [ "hlo_graph_dumper.cc", @@ -1401,6 +1428,7 @@ cc_library( deps = [ ":hlo", ":hlo_execution_profile", + ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 0af4c99d0a..fc8fcfce9e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -414,14 +415,24 @@ namespace { class FileGraphRenderer : public GraphRendererInterface { public: - string RenderGraph(const string& graph) override { + string RenderGraph(const string& graph, GraphKind graph_kind) override { static std::atomic<int> output_num(0); legacy_flags::HloGraphDumperFlags* flags = legacy_flags::GetHloGraphDumperFlags(); - string path = StrCat(flags->xla_hlo_dump_graph_path, "hlo_graph_", - output_num++, ".XXXXXX.dot"); + string file_extension; + switch (graph_kind) { + case DOT_GRAPH: + file_extension = ".dot"; + break; + case TF_GRAPHDEF: + file_extension = ".pbtxt"; + break; + } + string path = + JoinPath(flags->xla_hlo_dump_graph_path, + StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); auto status = Status::OK(); - int fd = mkstemps(&path[0], 4); + int fd = mkstemps(&path[0], file_extension.length()); if (fd < 0) { status = Status(tensorflow::error::Code::UNKNOWN, @@ -446,10 +457,26 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); string DumpGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile) { - string graph = ComputationToDotGraph(computation, label, show_addresses, - show_layouts, hlo_execution_profile); - - string graph_url = GetGraphRenderer()->RenderGraph(graph); + string graph; + string graph_url; + legacy_flags::HloGraphDumperFlags* flags = + legacy_flags::GetHloGraphDumperFlags(); + if (flags->xla_hlo_dump_as_graphdef) { + HloTfGraphBuilder builder; + TF_CHECK_OK(builder.AddComputation(computation)); + CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), + &graph)); + // TODO(b/37198616): Use the default registered renderers when all + // renderers support rendering GraphDefs. Always dump GraphDefs to files + // for now. + graph_url = FileGraphRenderer().RenderGraph( + graph, GraphRendererInterface::TF_GRAPHDEF); + } else { + graph = ComputationToDotGraph(computation, label, show_addresses, + show_layouts, hlo_execution_profile); + graph_url = GetGraphRenderer()->RenderGraph( + graph, GraphRendererInterface::DOT_GRAPH); + } LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; @@ -467,5 +494,4 @@ void DumpText(const HloModule& module, const string& label, } } // namespace hlo_graph_dumper - } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index 5f841da1f3..8ed50c3847 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -25,8 +25,25 @@ limitations under the License. namespace xla { namespace hlo_graph_dumper { -// Dumps a graph of the computation to the GraphViz server and returns -// a description of the rendered graph (e.g., a URL). +// Abstract interface for classes that render HLO graphs (e.g. DOT graph, +// tensorflow GraphDef). +class GraphRendererInterface { + public: + enum GraphKind { + DOT_GRAPH, + TF_GRAPHDEF, + }; + + virtual ~GraphRendererInterface() = default; + + // Renders a DOT graph, returning a description of the rendered output + // (e.g., a URL) + virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0; +}; + +// Dumps a graph of the computation and returns a description of the rendered +// graph (e.g., a URL) based on the renderer. The "best" renderer in the +// registry is used. string DumpGraph(const HloComputation& computation, const string& label, bool show_addresses, bool show_layouts, const HloExecutionProfile* hlo_execution_profile = nullptr); @@ -40,16 +57,6 @@ string DumpGraph(const HloComputation& computation, const string& label, void DumpText(const HloModule& module, const string& label, const string& directory_path, bool do_prefix = true); -// Abstract interface for classes that render DOT graphs. -class GraphRendererInterface { - public: - virtual ~GraphRendererInterface() = default; - - // Renders a DOT graph, returning a description of the rendered output - // (e.g., a URL) - virtual string RenderGraph(const string& graph) = 0; -}; - // Graph renderers may be added using a registration mechanism, e.g.: // XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100) // The renderer with the highest numeric priority value is used. diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index fe835a20c4..7f2f5bedee 100644 --- a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -31,7 +31,7 @@ using ::tensorflow::strings::StrCat; using ::tensorflow::str_util::Join; namespace xla { -namespace tools { +namespace hlo_graph_dumper { namespace { string GetOpDefName(const HloInstruction* instruction) { @@ -200,5 +200,5 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { return Status::OK(); } -} // namespace tools +} // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h index 3052eae113..b2c578af91 100644 --- a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -13,16 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/graph/graph.h" namespace xla { -namespace tools { +namespace hlo_graph_dumper { // This constructs a tensorflow graph for HLO computations. class HloTfGraphBuilder { @@ -53,7 +52,7 @@ class HloTfGraphBuilder { // Cleans the node name to make it a valid name in a tensorflow graph. void CleanNodeName(string* name); -} // namespace tools +} // namespace hlo_graph_dumper } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_ +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index 626bcc6d85..3190f2d703 100644 --- a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -13,12 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" +#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" namespace xla { -namespace tools { +namespace hlo_graph_dumper { namespace { using ::tensorflow::GraphDef; @@ -40,7 +42,7 @@ class HloTfGraphBuilderTest : public HloTestBase { // Creates a computation which calls map with the given computation. std::unique_ptr<HloComputation> CreateMapComputation( - HloComputation* map_computation) { + HloComputation *map_computation) { auto builder = HloComputation::Builder("Map"); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32_, "param0")); @@ -48,18 +50,18 @@ class HloTfGraphBuilderTest : public HloTestBase { HloInstruction::CreateMap(r0f32_, {param}, map_computation)); return builder.Build(); } - Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {}); }; TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { auto builder = HloComputation::Builder("Concatenate"); - Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2}); auto param_1 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); auto param_2 = builder.AddInstruction( HloInstruction::CreateParameter(1, shape, "param1")); builder.AddInstruction(HloInstruction::CreateConcatenate( - ShapeUtil::MakeShape(F32, {2, 4}), {param_1, param_2}, 1)); + ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1)); TF_CHECK_OK(generator_.AddComputation(*builder.Build())); GraphDef graph_def = generator_.GetGraphDef(); EXPECT_EQ(graph_def.node_size(), 3); @@ -150,5 +152,5 @@ TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) { } } // namespace -} // namespace tools +} // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index ab598b8edd..1d9baf5de1 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -176,49 +176,21 @@ cc_binary( ], ) -cc_library( - name = "hlo_tfgraph_builder", - srcs = ["hlo_tfgraph_builder.cc"], - hdrs = ["hlo_tfgraph_builder.h"], - deps = [ - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], -) - -cc_test( - name = "hlo_tfgraph_builder_test", - srcs = ["hlo_tfgraph_builder_test.cc"], - deps = [ - ":hlo_tfgraph_builder", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/core:test_main", - ], -) - cc_binary( name = "dumped_computation_to_tf_graphdef", srcs = ["dumped_computation_to_tf_graphdef.cc"], deps = [ - ":hlo_tfgraph_builder", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc index 1aa769ee5a..850267d319 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Usage: dumped_computation_to_tf_graph \ -// --output_dir=/tmp/graphs/ some_binary_snapshot_proto* +// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto* // // Dumps a tensorflow GraphDef in text format for a snapshot computation. The // dumped graph is an HLO computation with HLO instructions as nodes and can be @@ -31,87 +30,31 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" #include "tensorflow/compiler/xla/service/service.h" #include "tensorflow/compiler/xla/service/session.pb.h" #include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" using tensorflow::Env; -using tensorflow::io::JoinPath; -using tensorflow::strings::StrAppend; namespace xla { namespace tools { -namespace { -// Dumps all computations in the module to the given directory. -void DumpTfGraph(const HloModule& module, const string& directory_path) { - Env* env = Env::Default(); - TF_CHECK_OK(env->RecursivelyCreateDir(directory_path)); - string fname = module.name(); - std::replace(fname.begin(), fname.end(), '/', '_'); - // Since the file name will be used as the top-level scope name, clean it up - // to make it a valid scope name. - CleanNodeName(&fname); - StrAppend(&fname, ".pbtxt"); - string path = JoinPath(directory_path, fname); - HloTfGraphBuilder builder; - TF_CHECK_OK(builder.AddComputation(*module.entry_computation())); - std::cout << "Dumping " << module.name() << " to " << path << std::endl; - TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef())); -} - -} // namespace - -void RealMain(tensorflow::gtl::ArraySlice<char*> args, - const string& output_dir) { - LocalClient* client = ClientLibrary::LocalClientOrDie(); - // To avoid adding a new flag, use local service and lower the computations - // locally. - LocalService* local_service = - ClientLibrary::GetXlaService(client->platform()); - // Build HloModule for each Computation and dump to file. +void RealMain(tensorflow::gtl::ArraySlice<char*> args) { + Client* client = ClientLibrary::LocalClientOrDie(); for (char* arg : args) { - SessionModule session_module; - TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, - &session_module)); - auto computation_status = client->LoadSnapshot(session_module); - if (!computation_status.ok()) { - fprintf(stderr, "could not load snapshot for %s: %s\n", arg, - computation_status.status().ToString().c_str()); - continue; - } - Computation computation = computation_status.ConsumeValueOrDie(); - - StatusOr<UserComputation*> user_computation_status = - local_service->computation_tracker().Resolve(computation.handle()); - if (!user_computation_status.ok()) { - fprintf(stderr, - "failed to resolve computation to UserComputation %s: %s\n", arg, - user_computation_status.status().ToString().c_str()); - continue; - } - - auto* user_computation = user_computation_status.ValueOrDie(); - StatusOr<std::unique_ptr<HloModule>> module_status = - local_service->computation_tracker().BuildHloModule( - user_computation->GetVersionedHandle()); - - if (!module_status.ok()) { - fprintf(stderr, "failed to build HloModule %s: %s\n", arg, - module_status.status().ToString().c_str()); - continue; - } - - DumpTfGraph(*module_status.ValueOrDie(), output_dir); + SessionModule module; + TF_CHECK_OK( + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module)); + Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie(); + ComputationStats stats = + client->GetComputationStats(computation).ConsumeValueOrDie(); + fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str()); } } @@ -119,21 +62,17 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args, } // namespace xla int main(int argc, char** argv) { - string output_dir = ""; - const std::vector<tensorflow::Flag> flag_list = { - tensorflow::Flag("output_dir", &output_dir, - "Directory to write GraphDef data to."), - }; - - string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - if (!parse_ok || output_dir.empty()) { - LOG(QFATAL) << usage; - } tensorflow::port::InitMain(argv[0], &argc, &argv); + xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags(); + flags->xla_generate_hlo_graph = ".*"; + + xla::legacy_flags::HloGraphDumperFlags* dumper_flags = + xla::legacy_flags::GetHloGraphDumperFlags(); + dumper_flags->xla_hlo_dump_as_graphdef = true; + tensorflow::gtl::ArraySlice<char*> args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - xla::tools::RealMain(args, output_dir); + xla::tools::RealMain(args); return 0; } |