aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-12 02:48:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-12 04:13:14 -0700
commit62850f51dd5e978ac243695efab753490a52ca15 (patch)
tree32fa13ebf92d5d3916e62edda31011212bf05dce
parentb11dff9a7c9571527c50962752456ce9632ebdf3 (diff)
Support dumping HLO graphs as TF GraphDefs in hlo_graph_dumper
- Added a new --xla_hlo_dump_as_graphdef TF_XLA_FLAGS - Moved hlo_tfgraph_builder from xla/tools/ to xla/service/ - Refactored GraphRendererInterface a bit to support both dot graph and tf graph. Change: 152921467
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc4
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h3
-rw-r--r--tensorflow/compiler/xla/service/BUILD28
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc44
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.h31
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc (renamed from tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc)6
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder.h (renamed from tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h)11
-rw-r--r--tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc (renamed from tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc)16
-rw-r--r--tensorflow/compiler/xla/tools/BUILD30
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc99
10 files changed, 126 insertions, 146 deletions
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc
index 8822f6f610..ba43a59195 100644
--- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc
@@ -36,10 +36,14 @@ static std::once_flag flags_init;
static void AllocateFlags() {
flags = new HloGraphDumperFlags;
flags->xla_hlo_dump_graph_path = "/tmp/";
+ flags->xla_hlo_dump_as_graphdef = false;
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag("xla_hlo_dump_graph_path",
&flags->xla_hlo_dump_graph_path,
"Path to write dumped HLO graphs to"),
+ tensorflow::Flag("xla_hlo_dump_as_graphdef",
+ &flags->xla_hlo_dump_as_graphdef,
+ "Dumps HLO graphs as tensorflow GraphDefs"),
});
ParseFlagsFromEnv(*flag_list);
}
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h
index b6dfced87c..d0b4d092ff 100644
--- a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h
@@ -34,6 +34,9 @@ void AppendHloGraphDumperFlags(std::vector<tensorflow::Flag>* flag_list);
// The values of flags associated with XLA's hlo_graph_dumper module.
typedef struct {
string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to
+ // If set, dumps HLO graphs as tensorflow GraphDef; otherwise, dumps HLO
+ // graphs as DOT graph.
+ bool xla_hlo_dump_as_graphdef;
} HloGraphDumperFlags;
// Return a pointer to the HloGraphDumperFlags struct;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 92e49314d9..c019eff72d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1393,6 +1393,33 @@ cc_test(
)
cc_library(
+ name = "hlo_tfgraph_builder",
+ srcs = ["hlo_tfgraph_builder.cc"],
+ hdrs = ["hlo_tfgraph_builder.h"],
+ visibility = ["//tensorflow/compiler/xla/tools:__pkg__"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "hlo_tfgraph_builder_test",
+ srcs = ["hlo_tfgraph_builder_test.cc"],
+ deps = [
+ ":hlo_tfgraph_builder",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
name = "hlo_graph_dumper",
srcs = [
"hlo_graph_dumper.cc",
@@ -1401,6 +1428,7 @@ cc_library(
deps = [
":hlo",
":hlo_execution_profile",
+ ":hlo_tfgraph_builder",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 0af4c99d0a..fc8fcfce9e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -414,14 +415,24 @@ namespace {
class FileGraphRenderer : public GraphRendererInterface {
public:
- string RenderGraph(const string& graph) override {
+ string RenderGraph(const string& graph, GraphKind graph_kind) override {
static std::atomic<int> output_num(0);
legacy_flags::HloGraphDumperFlags* flags =
legacy_flags::GetHloGraphDumperFlags();
- string path = StrCat(flags->xla_hlo_dump_graph_path, "hlo_graph_",
- output_num++, ".XXXXXX.dot");
+ string file_extension;
+ switch (graph_kind) {
+ case DOT_GRAPH:
+ file_extension = ".dot";
+ break;
+ case TF_GRAPHDEF:
+ file_extension = ".pbtxt";
+ break;
+ }
+ string path =
+ JoinPath(flags->xla_hlo_dump_graph_path,
+ StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension));
auto status = Status::OK();
- int fd = mkstemps(&path[0], 4);
+ int fd = mkstemps(&path[0], file_extension.length());
if (fd < 0) {
status =
Status(tensorflow::error::Code::UNKNOWN,
@@ -446,10 +457,26 @@ XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
string DumpGraph(const HloComputation& computation, const string& label,
bool show_addresses, bool show_layouts,
const HloExecutionProfile* hlo_execution_profile) {
- string graph = ComputationToDotGraph(computation, label, show_addresses,
- show_layouts, hlo_execution_profile);
-
- string graph_url = GetGraphRenderer()->RenderGraph(graph);
+ string graph;
+ string graph_url;
+ legacy_flags::HloGraphDumperFlags* flags =
+ legacy_flags::GetHloGraphDumperFlags();
+ if (flags->xla_hlo_dump_as_graphdef) {
+ HloTfGraphBuilder builder;
+ TF_CHECK_OK(builder.AddComputation(computation));
+ CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(),
+ &graph));
+ // TODO(b/37198616): Use the default registered renderers when all
+ // renderers support rendering GraphDefs. Always dump GraphDefs to files
+ // for now.
+ graph_url = FileGraphRenderer().RenderGraph(
+ graph, GraphRendererInterface::TF_GRAPHDEF);
+ } else {
+ graph = ComputationToDotGraph(computation, label, show_addresses,
+ show_layouts, hlo_execution_profile);
+ graph_url = GetGraphRenderer()->RenderGraph(
+ graph, GraphRendererInterface::DOT_GRAPH);
+ }
LOG(INFO) << "computation " << computation.name() << " [" << label
<< "]: " << graph_url;
return graph_url;
@@ -467,5 +494,4 @@ void DumpText(const HloModule& module, const string& label,
}
} // namespace hlo_graph_dumper
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
index 5f841da1f3..8ed50c3847 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -25,8 +25,25 @@ limitations under the License.
namespace xla {
namespace hlo_graph_dumper {
-// Dumps a graph of the computation to the GraphViz server and returns
-// a description of the rendered graph (e.g., a URL).
+// Abstract interface for classes that render HLO graphs (e.g. DOT graph,
+// tensorflow GraphDef).
+class GraphRendererInterface {
+ public:
+ enum GraphKind {
+ DOT_GRAPH,
+ TF_GRAPHDEF,
+ };
+
+ virtual ~GraphRendererInterface() = default;
+
+ // Renders a DOT graph, returning a description of the rendered output
+ // (e.g., a URL)
+ virtual string RenderGraph(const string& graph, GraphKind graph_kind) = 0;
+};
+
+// Dumps a graph of the computation and returns a description of the rendered
+// graph (e.g., a URL) based on the renderer. The "best" renderer in the
+// registry is used.
string DumpGraph(const HloComputation& computation, const string& label,
bool show_addresses, bool show_layouts,
const HloExecutionProfile* hlo_execution_profile = nullptr);
@@ -40,16 +57,6 @@ string DumpGraph(const HloComputation& computation, const string& label,
void DumpText(const HloModule& module, const string& label,
const string& directory_path, bool do_prefix = true);
-// Abstract interface for classes that render DOT graphs.
-class GraphRendererInterface {
- public:
- virtual ~GraphRendererInterface() = default;
-
- // Renders a DOT graph, returning a description of the rendered output
- // (e.g., a URL)
- virtual string RenderGraph(const string& graph) = 0;
-};
-
// Graph renderers may be added using a registration mechanism, e.g.:
// XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100)
// The renderer with the highest numeric priority value is used.
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index fe835a20c4..7f2f5bedee 100644
--- a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -31,7 +31,7 @@ using ::tensorflow::strings::StrCat;
using ::tensorflow::str_util::Join;
namespace xla {
-namespace tools {
+namespace hlo_graph_dumper {
namespace {
string GetOpDefName(const HloInstruction* instruction) {
@@ -200,5 +200,5 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
return Status::OK();
}
-} // namespace tools
+} // namespace hlo_graph_dumper
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
index 3052eae113..b2c578af91 100644
--- a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h
@@ -13,16 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/graph/graph.h"
namespace xla {
-namespace tools {
+namespace hlo_graph_dumper {
// This constructs a tensorflow graph for HLO computations.
class HloTfGraphBuilder {
@@ -53,7 +52,7 @@ class HloTfGraphBuilder {
// Cleans the node name to make it a valid name in a tensorflow graph.
void CleanNodeName(string* name);
-} // namespace tools
+} // namespace hlo_graph_dumper
} // namespace xla
-#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TOOLS_HLO_TFGRAPH_BUILDER_H_
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_
diff --git a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 626bcc6d85..3190f2d703 100644
--- a/tensorflow/compiler/xla/tools/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
namespace xla {
-namespace tools {
+namespace hlo_graph_dumper {
namespace {
using ::tensorflow::GraphDef;
@@ -40,7 +42,7 @@ class HloTfGraphBuilderTest : public HloTestBase {
// Creates a computation which calls map with the given computation.
std::unique_ptr<HloComputation> CreateMapComputation(
- HloComputation* map_computation) {
+ HloComputation *map_computation) {
auto builder = HloComputation::Builder("Map");
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32_, "param0"));
@@ -48,18 +50,18 @@ class HloTfGraphBuilderTest : public HloTestBase {
HloInstruction::CreateMap(r0f32_, {param}, map_computation));
return builder.Build();
}
- Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r0f32_ = ShapeUtil::MakeShape(PrimitiveType::F32, {});
};
TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) {
auto builder = HloComputation::Builder("Concatenate");
- Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ Shape shape = ShapeUtil::MakeShape(PrimitiveType::F32, {2, 2});
auto param_1 = builder.AddInstruction(
HloInstruction::CreateParameter(0, shape, "param0"));
auto param_2 = builder.AddInstruction(
HloInstruction::CreateParameter(1, shape, "param1"));
builder.AddInstruction(HloInstruction::CreateConcatenate(
- ShapeUtil::MakeShape(F32, {2, 4}), {param_1, param_2}, 1));
+ ShapeUtil::MakeShape(PrimitiveType::F32, {2, 4}), {param_1, param_2}, 1));
TF_CHECK_OK(generator_.AddComputation(*builder.Build()));
GraphDef graph_def = generator_.GetGraphDef();
EXPECT_EQ(graph_def.node_size(), 3);
@@ -150,5 +152,5 @@ TEST_F(HloTfGraphBuilderTest, EmbeddedComputationsDiamond) {
}
} // namespace
-} // namespace tools
+} // namespace hlo_graph_dumper
} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index ab598b8edd..1d9baf5de1 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -176,49 +176,21 @@ cc_binary(
],
)
-cc_library(
- name = "hlo_tfgraph_builder",
- srcs = ["hlo_tfgraph_builder.cc"],
- hdrs = ["hlo_tfgraph_builder.h"],
- deps = [
- "//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
-cc_test(
- name = "hlo_tfgraph_builder_test",
- srcs = ["hlo_tfgraph_builder_test.cc"],
- deps = [
- ":hlo_tfgraph_builder",
- "//tensorflow/compiler/xla/client:computation_builder",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/core:test_main",
- ],
-)
-
cc_binary(
name = "dumped_computation_to_tf_graphdef",
srcs = ["dumped_computation_to_tf_graphdef.cc"],
deps = [
- ":hlo_tfgraph_builder",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:session_proto",
"//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
],
)
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
index 1aa769ee5a..850267d319 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_tf_graphdef.cc
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Usage: dumped_computation_to_tf_graph \
-// --output_dir=/tmp/graphs/ some_binary_snapshot_proto*
+// Usage: dumped_computation_to_tf_graph some_binary_snapshot_proto*
//
// Dumps a tensorflow GraphDef in text format for a snapshot computation. The
// dumped graph is an HLO computation with HLO instructions as nodes and can be
@@ -31,87 +30,31 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/service/session.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/tools/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
using tensorflow::Env;
-using tensorflow::io::JoinPath;
-using tensorflow::strings::StrAppend;
namespace xla {
namespace tools {
-namespace {
-// Dumps all computations in the module to the given directory.
-void DumpTfGraph(const HloModule& module, const string& directory_path) {
- Env* env = Env::Default();
- TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
- string fname = module.name();
- std::replace(fname.begin(), fname.end(), '/', '_');
- // Since the file name will be used as the top-level scope name, clean it up
- // to make it a valid scope name.
- CleanNodeName(&fname);
- StrAppend(&fname, ".pbtxt");
- string path = JoinPath(directory_path, fname);
- HloTfGraphBuilder builder;
- TF_CHECK_OK(builder.AddComputation(*module.entry_computation()));
- std::cout << "Dumping " << module.name() << " to " << path << std::endl;
- TF_CHECK_OK(WriteTextProto(env, path, builder.GetGraphDef()));
-}
-
-} // namespace
-
-void RealMain(tensorflow::gtl::ArraySlice<char*> args,
- const string& output_dir) {
- LocalClient* client = ClientLibrary::LocalClientOrDie();
- // To avoid adding a new flag, use local service and lower the computations
- // locally.
- LocalService* local_service =
- ClientLibrary::GetXlaService(client->platform());
- // Build HloModule for each Computation and dump to file.
+void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+ Client* client = ClientLibrary::LocalClientOrDie();
for (char* arg : args) {
- SessionModule session_module;
- TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
- &session_module));
- auto computation_status = client->LoadSnapshot(session_module);
- if (!computation_status.ok()) {
- fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
- computation_status.status().ToString().c_str());
- continue;
- }
- Computation computation = computation_status.ConsumeValueOrDie();
-
- StatusOr<UserComputation*> user_computation_status =
- local_service->computation_tracker().Resolve(computation.handle());
- if (!user_computation_status.ok()) {
- fprintf(stderr,
- "failed to resolve computation to UserComputation %s: %s\n", arg,
- user_computation_status.status().ToString().c_str());
- continue;
- }
-
- auto* user_computation = user_computation_status.ValueOrDie();
- StatusOr<std::unique_ptr<HloModule>> module_status =
- local_service->computation_tracker().BuildHloModule(
- user_computation->GetVersionedHandle());
-
- if (!module_status.ok()) {
- fprintf(stderr, "failed to build HloModule %s: %s\n", arg,
- module_status.status().ToString().c_str());
- continue;
- }
-
- DumpTfGraph(*module_status.ValueOrDie(), output_dir);
+ SessionModule module;
+ TF_CHECK_OK(
+ tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
+ Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+ ComputationStats stats =
+ client->GetComputationStats(computation).ConsumeValueOrDie();
+ fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str());
}
}
@@ -119,21 +62,17 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args,
} // namespace xla
int main(int argc, char** argv) {
- string output_dir = "";
- const std::vector<tensorflow::Flag> flag_list = {
- tensorflow::Flag("output_dir", &output_dir,
- "Directory to write GraphDef data to."),
- };
-
- string usage = tensorflow::Flags::Usage(argv[0], flag_list);
- bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
- if (!parse_ok || output_dir.empty()) {
- LOG(QFATAL) << usage;
- }
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags();
+ flags->xla_generate_hlo_graph = ".*";
+
+ xla::legacy_flags::HloGraphDumperFlags* dumper_flags =
+ xla::legacy_flags::GetHloGraphDumperFlags();
+ dumper_flags->xla_hlo_dump_as_graphdef = true;
+
tensorflow::gtl::ArraySlice<char*> args(argv, argc);
args.pop_front(); // Pop off the binary name, argv[0]
- xla::tools::RealMain(args, output_dir);
+ xla::tools::RealMain(args);
return 0;
}