aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/aot/compile.cc10
-rw-r--r--tensorflow/compiler/aot/flags.cc5
-rw-r--r--tensorflow/compiler/aot/flags.h2
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl28
4 files changed, 38 insertions, 7 deletions
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index eac8da0ab1..2b8cc6024c 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -97,11 +97,15 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client,
&computation,
&compile_result->has_context_arg));
- if (!flags.debug_dir.empty()) {
+ if (!flags.out_session_module.empty()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
computation.Snapshot());
- string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb");
- TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module));
+ // Serialize the SessionModule deterministically so that all the outputs of
+ // a tf_library genrule are deterministic.
+ string proto;
+ TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
+ TF_RETURN_IF_ERROR(
+ WriteStringToFile(Env::Default(), flags.out_session_module, proto));
}
xla::cpu::CpuAotCompilationOptions aot_opts(
flags.target_triple, flags.target_cpu, flags.target_features,
diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc
index 5aff10346f..7c2f27e550 100644
--- a/tensorflow/compiler/aot/flags.cc
+++ b/tensorflow/compiler/aot/flags.cc
@@ -33,9 +33,6 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"fetch nodes will be dumped to stdout in a comma-separated list. "
"Typically used to format arguments for other tools, e.g. "
"freeze_graph."},
- {"debug_dir", &flags->debug_dir,
- "Specifies a directory to dump debugging information, including "
- "rewritten graphs and the XLA HLO module."},
// Flags controlling the XLA ahead-of-time compilation, that correspond to
// the fields of xla::cpu::CpuAotCompilationOptions.
//
@@ -64,6 +61,8 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"namespaces are given, within the global namespace."},
{"out_object", &flags->out_object, "Output object file name."},
{"out_header", &flags->out_header, "Output header file name."},
+ {"out_session_module", &flags->out_session_module,
+ "Output session module proto."},
{"gen_name_to_index", &flags->gen_name_to_index,
"Generate name-to-index data for Lookup{Arg,Result}Index methods."},
{"gen_program_shape", &flags->gen_program_shape,
diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h
index 3246dbf95c..3519659e3a 100644
--- a/tensorflow/compiler/aot/flags.h
+++ b/tensorflow/compiler/aot/flags.h
@@ -29,7 +29,6 @@ struct MainFlags {
string graph;
string config;
bool dump_fetch_nodes = false;
- string debug_dir;
string target_triple;
string target_cpu;
string target_features;
@@ -37,6 +36,7 @@ struct MainFlags {
string cpp_class;
string out_object;
string out_header;
+ string out_session_module;
// C++ codegen options
bool gen_name_to_index = false;
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 4888760acd..2adb1dc65e 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -165,6 +165,34 @@ def tf_library(name, graph, config,
tags=tags,
)
+ # Rule that runs tfcompile to produce the SessionModule proto, useful for
+ # debugging. TODO(b/64813587): Once the SessionModule proto is
+ # deterministic, move this into the main rule above.
+ session_module_pb = name + "_session_module.pb"
+ native.genrule(
+ name=(name + "_session_module"),
+ srcs=[
+ tfcompile_graph,
+ config,
+ ],
+ outs=[
+ session_module_pb,
+ ],
+ cmd=("$(location " + tfcompile_tool + ")" +
+ " --graph=$(location " + tfcompile_graph + ")" +
+ " --config=$(location " + config + ")" +
+ " --entry_point=" + ep +
+ " --cpp_class=" + cpp_class +
+ " --target_triple=" + target_llvm_triple() +
+ " --out_session_module=$(@D)/" + session_module_pb +
+ " " + (tfcompile_flags or "")),
+ tools=[tfcompile_tool],
+ visibility=visibility,
+ testonly=testonly,
+ local=1,
+ tags=tags,
+ )
+
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
need_xla_data_proto = (tfcompile_flags and