author     Sanjoy Das <sanjoy@google.com>              2018-04-27 20:06:35 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-27 20:09:38 -0700
commit     74747435c2442084e8de53bc73311152f270ae88 (patch)
tree       d7d9a57b195039d277fc8af9814290a787d04cc4 /tensorflow/compiler/aot
parent     ce8e19a756f71fa66f60a28515c64c106ca7f6a1 (diff)
HLO profiling for tfcompile.
This CL extends the --xla_hlo_profile knob to tfcompile. tf_library rules can now set enable_xla_hlo_profiling to True to:

- Have the generated code update per-HLO profile counters as it executes.
- Have tfcompile generate and serialize an instance of HloProfilePrinterData with a compiled model that can be used to pretty-print the collected profile counters.

PiperOrigin-RevId: 194627272
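For reference, consuming the new metadata at runtime looks roughly like the minimal sketch below. It is modeled on the HloProfiling test added to tfcompile_test.cc in this CL: the generated class MatMulAndAddCompWithProfiling, its hlo_profiling_enabled()/hlo_profile_printer_data()/profile_counters() accessors, and xla::PrintHloProfile come from this change, while the wrapper function, the argument setup, and the Eigen thread-pool boilerplate are illustrative only.

    #define EIGEN_USE_THREADS
    #define EIGEN_USE_CUSTOM_THREAD_POOL

    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
    #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
    #include "tensorflow/compiler/xla/service/hlo_profile_printer.h"

    // Runs the tfcompile'd model once and returns the pretty-printed per-HLO
    // profile, using the HloProfilePrinterData serialized into the object file.
    std::string DumpHloProfile() {
      // The generated class runs on an Eigen thread-pool device, as in the test.
      Eigen::ThreadPool tp(1);
      Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());

      // hlo_profiling_enabled() returns true because the tf_library rule was
      // built with enable_xla_hlo_profiling = True.
      MatMulAndAddCompWithProfiling fn;
      fn.set_thread_pool(&device);

      // ... fill fn.arg0(i, j) and fn.arg1(i, j) with inputs ...
      fn.Run();  // Updates the per-HLO profile counters as a side effect.

      return xla::PrintHloProfile(fn.hlo_profile_printer_data(),
                                  fn.profile_counters(),
                                  /*clock_rate_ghz=*/1.0);
    }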
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--  tensorflow/compiler/aot/codegen.cc                     71
-rw-r--r--  tensorflow/compiler/aot/codegen.h                      10
-rw-r--r--  tensorflow/compiler/aot/codegen_test.cc                 2
-rw-r--r--  tensorflow/compiler/aot/codegen_test_h.golden          11
-rw-r--r--  tensorflow/compiler/aot/compile.cc                      1
-rw-r--r--  tensorflow/compiler/aot/embedded_protocol_buffers.cc   74
-rw-r--r--  tensorflow/compiler/aot/embedded_protocol_buffers.h    83
-rw-r--r--  tensorflow/compiler/aot/tests/BUILD                    13
-rw-r--r--  tensorflow/compiler/aot/tests/tfcompile_test.cc        60
-rw-r--r--  tensorflow/compiler/aot/tfcompile.bzl                  13
-rw-r--r--  tensorflow/compiler/aot/tfcompile_main.cc               2
11 files changed, 252 insertions, 88 deletions
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 2cae85e896..0025842aea 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -333,6 +333,20 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
: "";
+ const string include_hlo_profile_printer_data_proto =
+ opts.gen_hlo_profile_printer_data
+ ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
+ : "";
+
+ // When HLO profiling is disabled we only forward declare the
+ // HloProfilePrinterData protobuf, so we can only conditionally emit this code
+ // calling HloProfilePrinterData::profile_counters_size.
+ const string assign_profile_counters_size =
+ opts.gen_hlo_profile_printer_data
+ ? "data->profile_counters_size = "
+ "data->hlo_profile_printer_data->profile_counters_size();"
+ : "";
+
// Use a poor-man's text templating mechanism; first populate the full header
// with placeholder tokens, and then rewrite the tokens with real values.
*header =
@@ -348,6 +362,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
{{INCLUDE_XLA_DATA_PROTO}}
+{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
@@ -418,6 +433,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
+ data->hlo_profile_printer_data = StaticHloProfilePrinterData();
+ {{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
return *kStaticData;
@@ -487,6 +504,13 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
return kShape;
}
+
+ // Metadata that can be used to pretty-print profile counters.
+ static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
+ static const xla::HloProfilePrinterData* kHloProfilePrinterData =
+ {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
+ return kHloProfilePrinterData;
+ }
};
{{NS_END}}
@@ -501,35 +525,41 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
+ {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
+ {"{{DECLS_FROM_OBJ_FILE}}",
+ str_util::Join(metadata_result.header_variable_decls, "\n")},
{"{{ENTRY}}", compile_result.entry_point},
+ {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
+ metadata_result.hlo_profile_printer_data_access_shim},
{"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
+ {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
+ include_hlo_profile_printer_data_proto},
{"{{METHODS_ARG}}\n", methods_arg},
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
+ metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
- {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
- {"{{DECLS_FROM_OBJ_FILE}}",
- str_util::Join(metadata_result.header_variable_decls, "\n")},
- {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
- metadata_result.program_shape_access_shim}};
+ {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}
-static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) {
+static string CreateUniqueIdentifier(const CodegenOpts& opts,
+ StringPiece suffix) {
string result = "__tfcompile";
for (const string& n : opts.namespaces) {
strings::StrAppend(&result, "_", n);
}
- strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape");
+ strings::StrAppend(&result, "_", opts.class_name, "_", suffix);
return result;
}
@@ -550,18 +580,31 @@ Status GenerateMetadata(const CodegenOpts& opts,
// When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives
// a shim that evaluates to nullptr, which is what we want.
+ ProtobufToEmbed program_shape_protobuf{
+ CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
+ program_shape.get()};
+
+ ProtobufToEmbed hlo_profile_printer_data_protobuf{
+ CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
+ "xla::HloProfilePrinterData",
+ compile_result.aot->hlo_profile_printer_data()};
+
TF_ASSIGN_OR_RETURN(
- EmbeddedProtocolBuffer embedded_program_shape,
- CreateEmbeddedProtocolBuffer(opts.target_triple,
- CreateUniqueIdentifierForProgramShape(opts),
- "xla::ProgramShape", program_shape.get()));
+ EmbeddedProtocolBuffers embedded_protobufs,
+ CreateEmbeddedProtocolBuffers(
+ opts.target_triple,
+ {program_shape_protobuf, hlo_profile_printer_data_protobuf}));
metadata_result->program_shape_access_shim =
- std::move(embedded_program_shape.cpp_shim_expression);
+ std::move(embedded_protobufs.cpp_shims[0].expression);
+ metadata_result->hlo_profile_printer_data_access_shim =
+ std::move(embedded_protobufs.cpp_shims[1].expression);
+ metadata_result->header_variable_decls.emplace_back(
+ std::move(embedded_protobufs.cpp_shims[0].variable_decl));
metadata_result->header_variable_decls.emplace_back(
- std::move(embedded_program_shape.cpp_variable_decl));
+ std::move(embedded_protobufs.cpp_shims[1].variable_decl));
metadata_result->object_file_data =
- std::move(embedded_program_shape.object_file_data);
+ std::move(embedded_protobufs.object_file_data);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 3430b1f96c..83f2d3ee11 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -44,6 +44,10 @@ struct CodegenOpts {
// If true, generate program shape data for the ProgramShape method.
bool gen_program_shape = false;
+
+ // If true, emit a serialized HloProfilePrinterData protobuf that can be used
+ // to pretty print HLO profile counters.
+ bool gen_hlo_profile_printer_data = false;
};
// Describes a generated metadata object file.
@@ -57,6 +61,12 @@ struct MetadataResult {
// GenerateMetadata.
string program_shape_access_shim;
+ // hlo_profile_printer_data_access_shim is a C++ expression that constructs
+ // the xla::HloProfilePrinterData instance for the CompileResult passed to
+ // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a
+ // C++ expression that evaluates to nullptr at runtime.
+ string hlo_profile_printer_data_access_shim;
+
// The contents of the object (".o") file.
string object_file_data;
};
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 2642536c4f..29bc9c13b8 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -172,7 +172,7 @@ TEST(CodegenTest, Golden) {
fetch->set_name("myfetch");
CompileResult compile_result;
compile_result.aot.reset(
- new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
+ new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index ac3b587331..6e050cf564 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -10,6 +10,7 @@
#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
#include "tensorflow/compiler/xla/xla_data.pb.h"
+
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
@@ -23,6 +24,7 @@ extern "C" void entry_point(
extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
+
namespace foo {
namespace bar {
@@ -82,6 +84,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
data->arg_names = StaticArgNames();
data->result_names = StaticResultNames();
data->program_shape = StaticProgramShape();
+ data->hlo_profile_printer_data = StaticHloProfilePrinterData();
+
return data;
}();
return *kStaticData;
@@ -243,6 +247,13 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}();
return kShape;
}
+
+ // Metadata that can be used to pretty-print profile counters.
+ static const xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
+ static const xla::HloProfilePrinterData* kHloProfilePrinterData =
+ nullptr;
+ return kHloProfilePrinterData;
+ }
};
} // end namespace bar
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index e17a7c4bf6..31044ff85d 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -110,6 +110,7 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
flags.target_triple, flags.target_cpu, flags.target_features,
flags.entry_point,
xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
+
return CompileXla(client, computation, aot_opts, compile_result);
}
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
index 0048eec93b..63d22de1ca 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.cc
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc
@@ -36,9 +36,8 @@ namespace tfcompile {
using xla::llvm_ir::AsStringRef;
-static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
- llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine,
- const ::tensorflow::protobuf::MessageLite& proto,
+static void AddEmbeddedProtocolBufferToLlvmModule(
+ llvm::Module* module, const ::tensorflow::protobuf::MessageLite& proto,
StringPiece unique_identifier, string* protobuf_array_symbol_name,
int64* protobuf_array_size) {
string protobuf_array_contents = proto.SerializeAsString();
@@ -46,19 +45,14 @@ static std::unique_ptr<llvm::Module> CreateModuleWithEmbeddedProtocolBuffer(
strings::StrCat(unique_identifier, "_protobuf_array_contents");
*protobuf_array_size = protobuf_array_contents.size();
- std::unique_ptr<llvm::Module> module =
- MakeUnique<llvm::Module>("embedded_data_module", *llvm_context);
-
llvm::Constant* protobuf_array_initializer =
- llvm::ConstantDataArray::getString(*llvm_context,
+ llvm::ConstantDataArray::getString(module->getContext(),
AsStringRef(protobuf_array_contents),
/*AddNull=*/false);
new llvm::GlobalVariable(
*module, protobuf_array_initializer->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage,
protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name));
-
- return module;
}
static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name,
@@ -115,42 +109,44 @@ GetTargetMachineFromTriple(StringPiece target_triple) {
/*Features=*/"", llvm::TargetOptions(), llvm::None));
}
-StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
- StringPiece target_triple, StringPiece symbol_prefix,
- StringPiece qualified_cpp_protobuf_name,
- const ::tensorflow::protobuf::MessageLite* proto) {
+StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
+ StringPiece target_triple,
+ gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<llvm::TargetMachine> target_machine,
GetTargetMachineFromTriple(target_triple));
llvm::LLVMContext llvm_context;
- string object_file, cpp_shim, cpp_variable_decl;
-
- if (proto) {
- string protobuf_array_symbol_name;
- int64 protobuf_array_size;
-
- std::unique_ptr<llvm::Module> module_with_serialized_proto =
- CreateModuleWithEmbeddedProtocolBuffer(
- &llvm_context, target_machine.get(), *proto, symbol_prefix,
- &protobuf_array_symbol_name, &protobuf_array_size);
- TF_ASSIGN_OR_RETURN(object_file,
- CodegenModule(target_machine.get(),
- std::move(module_with_serialized_proto)));
- cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name,
- protobuf_array_symbol_name,
- protobuf_array_size);
-
- cpp_variable_decl = strings::StrCat("extern \"C\" char ",
- protobuf_array_symbol_name, "[];");
- } else {
- TF_ASSIGN_OR_RETURN(
- object_file,
- CodegenModule(target_machine.get(),
- MakeUnique<llvm::Module>("empty_module", llvm_context)));
- cpp_shim = "nullptr";
+ std::unique_ptr<llvm::Module> module_with_serialized_proto =
+ MakeUnique<llvm::Module>("embedded_data_module", llvm_context);
+
+ EmbeddedProtocolBuffers result;
+
+ for (const ProtobufToEmbed& protobuf_to_embed : protobufs_to_embed) {
+ string cpp_shim, cpp_variable_decl;
+ if (protobuf_to_embed.message) {
+ string protobuf_array_symbol_name;
+ int64 protobuf_array_size;
+
+ AddEmbeddedProtocolBufferToLlvmModule(
+ module_with_serialized_proto.get(), *protobuf_to_embed.message,
+ protobuf_to_embed.symbol_prefix, &protobuf_array_symbol_name,
+ &protobuf_array_size);
+ cpp_shim = CreateCPPShimExpression(
+ protobuf_to_embed.qualified_cpp_protobuf_name,
+ protobuf_array_symbol_name, protobuf_array_size);
+
+ cpp_variable_decl = strings::StrCat("extern \"C\" char ",
+ protobuf_array_symbol_name, "[];");
+ } else {
+ cpp_shim = "nullptr";
+ }
+ result.cpp_shims.push_back({cpp_shim, cpp_variable_decl});
}
- return {{cpp_shim, cpp_variable_decl, object_file}};
+ TF_ASSIGN_OR_RETURN(result.object_file_data,
+ CodegenModule(target_machine.get(),
+ std::move(module_with_serialized_proto)));
+ return result;
}
} // namespace tfcompile
diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h
index 8436e0ff67..ebfe4806c2 100644
--- a/tensorflow/compiler/aot/embedded_protocol_buffers.h
+++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h
@@ -21,51 +21,70 @@ limitations under the License.
#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_
#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
namespace tfcompile {
using xla::StatusOr;
-// Represents a protocol buffer embedded into an object file and describes a way
-// to access it at runtime.
-struct EmbeddedProtocolBuffer {
- // cpp_shim_expression is a C++ expression that creates an instance of said
- // protocol buffer when executed.
- string cpp_shim_expression;
-
- // cpp_variable_decl is an "extern C" array declaration that is used in
- // cpp_shim_expression. It must be visible wherever cpp_shim_expression is
- // emitted.
- string cpp_variable_decl;
-
- // The contents of the object (".o") file the protocol buffer is embbed in.
- // This needs to be linked in to any program that wants to execute
- // cpp_variable_decl .
+// Represents a set of protocol buffers embedded into an object file and
+// describes how to access them at runtime.
+struct EmbeddedProtocolBuffers {
+ // Each CPPShim instance describes how to generate C++ code to instantiate a
+ // protobuf instance from the corresponding static data emitted into the
+ // object file.
+ struct CPPShim {
+ // `expression` is a C++ expression that creates an instance of said
+ // protocol buffer when executed.
+ string expression;
+
+ // `variable_decl` is an "extern C" array declaration that is used in
+ // `expression`. It must be visible wherever `expression` is emitted.
+ string variable_decl;
+ };
+
+ // Each cpp_shim corresponds to one embedded protocol buffer.
+ std::vector<CPPShim> cpp_shims;
+
+ // The contents of the object (".o") file the protocol buffers are embedded in.
+ // This needs to be linked in to any program that wants to execute any of the
+ // expressions in `cpp_shims`.
string object_file_data;
};
-// Creates an object file that contains `proto`.
-//
-// `proto` is allowed to be nullptr, in which case the generated C++ shim
-// expression is just `nullptr`, and the generated object file does not define
-// any symbols.
+// Describes a protocol buffer to embed into an object file.
+struct ProtobufToEmbed {
+ // `symbol_prefix` is a prefix that is guaranteed to be unique across the binary
+ // or DSO the generated object file will be linked into.
+ string symbol_prefix;
+
+ // `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
+ // namespace qualified) protocol buffer name. This is only used in
+ // CPPShim::expression so relatively qualified names are fine as long as
+ // they're valid wherever CPPShim::expression is emitted.
+ string qualified_cpp_protobuf_name;
+
+ // `message` is the protocol buffer to be embedded. It is allowed to be
+ // nullptr, in which case the generated C++ shim expression is just `nullptr`,
+ // and the generated object file does not define any symbols.
+ const ::tensorflow::protobuf::MessageLite* message;
+};
+
+// Embeds a sequence of protocol buffers into an object file.
//
// `target_triple` is the target triple for the target architecture for the
// generated object file.
//
-// `symbol_prefix` is prefix that is guaranteed to be unique across the binary
-// or DSO the generated object file will be linked into.
-//
-// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++
-// namespace qualified) protocol buffer name. This needs is only used in
-// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified
-// names are fine as long as they're valid wherever cpp_shim_expression
-// is emitted.
-StatusOr<EmbeddedProtocolBuffer> CreateEmbeddedProtocolBuffer(
- StringPiece target_triple, StringPiece symbol_prefix,
- StringPiece qualified_cpp_protobuf_name,
- const ::tensorflow::protobuf::MessageLite* proto);
+// `protobufs_to_embed` describes the protocol buffers to embed into the
+// resulting object file. The C++ shim for protobufs_to_embed[i] is
+// cpp_shims[i] in the returned EmbeddedProtocolBuffers instance. The contents
+// of all the protocol buffers are embedded into a single .o file whose content
+// is stored in the object_file_data field in the returned
+// EmbeddedProtocolBuffers instance.
+StatusOr<EmbeddedProtocolBuffers> CreateEmbeddedProtocolBuffers(
+ StringPiece target_triple,
+ gtl::ArraySlice<ProtobufToEmbed> protobufs_to_embed);
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index bb73cb19c5..222e26810a 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -164,6 +164,15 @@ tf_library(
)
tf_library(
+ name = "test_graph_tfmatmulandadd_with_profiling",
+ testonly = 1,
+ config = "test_graph_tfmatmulandadd.config.pbtxt",
+ cpp_class = "MatMulAndAddCompWithProfiling",
+ enable_xla_hlo_profiling = True,
+ graph = "test_graph_tfmatmulandadd.pb",
+)
+
+tf_library(
name = "test_graph_tfsplits",
testonly = 1,
config = "test_graph_tfsplits.config.pbtxt",
@@ -189,9 +198,13 @@ tf_cc_test(
":test_graph_tfgather",
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
+ ":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_profile_printer",
+ "//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 67dbd643bf..aa9d968265 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -25,15 +25,22 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace tfcompile {
namespace {
+using ::testing::HasSubstr;
+using ::testing::UnorderedElementsAre;
+
TEST(TFCompileTest, Add) {
AddComp add;
EXPECT_EQ(add.arg0_data(), add.args()[0]);
@@ -484,6 +491,59 @@ TEST(TFCompileTest, ProgramShape) {
EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2));
}
+TEST(TFCompileTest, HloProfiling) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ MatMulAndAddCompWithProfiling fn;
+ ASSERT_TRUE(fn.hlo_profiling_enabled());
+
+ fn.set_thread_pool(&device);
+
+ // x = [[1, 2], [3, 4]]
+ fn.arg0(0, 0) = 1;
+ fn.arg0(0, 1) = 2;
+ fn.arg0(1, 0) = 3;
+ fn.arg0(1, 1) = 4;
+
+ // y = [[10, 20], [30, 40]]
+ fn.arg1(0, 0) = 10;
+ fn.arg1(0, 1) = 20;
+ fn.arg1(1, 0) = 30;
+ fn.arg1(1, 1) = 40;
+
+ EXPECT_TRUE(fn.Run());
+
+ string hlo_profile_as_string =
+ xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
+ /*clock_rate_ghz=*/1.0);
+ VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
+
+ std::vector<string> hlo_profile_lines =
+ tensorflow::str_util::Split(hlo_profile_as_string, '\n');
+
+ auto header = HasSubstr("Execution profile for");
+ auto total_cycles_profile_line = HasSubstr("[total]");
+ auto dot_profile_line = HasSubstr(
+ "%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ auto add_profile_line = HasSubstr(
+ "%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
+ auto tuple_profile_line = HasSubstr(
+ "%tuple.2 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
+ "f32[2,2]{1,0} %add)");
+ auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
+ auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
+
+ hlo_profile_lines.erase(hlo_profile_lines.begin() + 7,
+ hlo_profile_lines.end());
+
+ EXPECT_THAT(
+ hlo_profile_lines,
+ UnorderedElementsAre(header, total_cycles_profile_line, dot_profile_line,
+ add_profile_line, tuple_profile_line,
+ arg0_profile_line, arg1_profile_line));
+}
+
} // namespace
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 3a877c5337..5c57fee326 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -25,7 +25,8 @@ def tf_library(name, graph, config,
visibility=None, testonly=None,
tfcompile_flags=None,
tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
- include_standard_runtime_deps=True, deps=None, tags=None):
+ include_standard_runtime_deps=True,
+ enable_xla_hlo_profiling=False, deps=None, tags=None):
"""Runs tfcompile to compile a TensorFlow graph into executable code.
Given an invocation of tf_library(name="foo", ...), generates the following
@@ -68,6 +69,8 @@ def tf_library(name, graph, config,
include_standard_runtime_deps: If True, the standard list of kernel/runtime
deps is added to deps. If False, deps must contain the full set of deps
needed by the generated library.
+ enable_xla_hlo_profiling: Enable XLA HLO profiling in the generated program,
+ and emit metadata that lets us pretty-print the gathered profile counters.
deps: a list of deps to include on the build rules for the generated
library, added to the standard deps if standard_runtime_deps is True.
tags: tags to apply to subsidiary build rules.
@@ -137,6 +140,10 @@ def tf_library(name, graph, config,
flags = tfcompile_flags
else:
flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])])
+ if enable_xla_hlo_profiling:
+ profiling_flag = "--xla_hlo_profile"
+ else:
+ profiling_flag = ""
native.genrule(
name=("gen_" + name),
srcs=[
@@ -157,7 +164,7 @@ def tf_library(name, graph, config,
" --out_header=$(@D)/" + header_file +
" --out_metadata_object=$(@D)/" + metadata_object_file +
" --out_function_object=$(@D)/" + function_object_file +
- " " + flags),
+ " " + flags + " " + profiling_flag),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
@@ -220,6 +227,8 @@ def tf_library(name, graph, config,
] + (need_xla_data_proto and [
# If we're generating the program shape, we must depend on the proto.
"//tensorflow/compiler/xla:xla_data_proto",
+ ] or []) + (enable_xla_hlo_profiling and [
+ "//tensorflow/compiler/xla/service:hlo_profile_printer_data"
] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index 8ea014c2ee..839e1588b7 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -100,6 +100,8 @@ Status Main(const MainFlags& flags) {
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
+ codegen_opts.gen_hlo_profile_printer_data =
+ xla::legacy_flags::GetDebugOptionsFromFlags().xla_hlo_profile();
TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
&codegen_opts.namespaces));