author     A. Unique TensorFlower <gardener@tensorflow.org>   2017-10-02 23:33:20 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-10-02 23:37:06 -0700
commit  263d025fb6dee974eefb30a51372188fb856d6cc (patch)
tree    b32ec04077368f45fbf31da8852b4fe072611e45 /tensorflow/compiler/aot/codegen.cc
parent  955c525d416c163c9dd857e637b0476b112b0ea0 (diff)
Add XlaCompiledCpuFunction, a lightweight API for calling XLA computations that
have been compiled down to functions. The API is a generic form of the original
AOT auto-generated header.

For AOT (tfcompile), this API has been slotted into the auto-generated header.
For JIT, a new XlaJitCompiledCpuFunction class has been added, which compiles a
tensorflow::GraphDef and lets the user create XlaCompiledCpuFunction objects.

XlaCompiledCpuFunction carries optional metadata: mappings from arg/result names
to their indices, and the program shape. This data is always available via JIT,
but is only emitted by AOT when the tfcompile --gen_name_to_index and
--gen_program_shape flags are set. It is off by default for AOT to keep binary
sizes small; the ProgramShape proto pulls in a lot of code and may itself be
large.

PiperOrigin-RevId: 170811579
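For orientation, a minimal usage sketch of a tfcompile-generated class built on
this API. The class name, namespace, header path, and argument shapes below are
hypothetical; the typed argN/resultN accessors and Run() are the methods
generated in this file and inherited from tensorflow::XlaCompiledCpuFunction.

    #include <cstdio>

    #include "my_model.h"  // hypothetical header produced by tfcompile for class MyModel

    int main() {
      // The generated class derives from tensorflow::XlaCompiledCpuFunction and,
      // in the default AllocMode, allocates its arg/result/temp buffers itself.
      foo::MyModel computation;

      // Fill the inputs via the typed, row-major arg accessors.
      computation.arg0(0, 0) = 1.0f;  // assumes arg0 is a 2-D float argument

      // Run the compiled computation; returns false on failure.
      if (!computation.Run()) return 1;

      // Read the outputs via the typed result accessors.
      std::printf("result0(0, 0) = %f\n", computation.result0(0, 0));
      return 0;
    }

The JIT path exposes the same base-class surface (set_arg_data/arg_data and
result_data by index, plus the name-to-index and program-shape metadata added
here) without the generated typed wrappers.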
Diffstat (limited to 'tensorflow/compiler/aot/codegen.cc')
-rw-r--r--  tensorflow/compiler/aot/codegen.cc  |  303
1 file changed, 154 insertions(+), 149 deletions(-)
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index fc5c6ce58d..ae22f7edc4 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -164,10 +164,6 @@ string RewriteWithName(const string& name, string code,
// Generate methods for args (inputs).
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
const CompileResult& compile_result, string* methods) {
- *methods += R"(
- void** args() { return args_; }
- const void *const *args() const { return args_; }
-)";
size_t num_args = ps.parameters_size();
if (compile_result.has_context_arg) {
// If the compiled function needs a XlaLocalRuntimeContext* arg, it's
@@ -184,21 +180,21 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
const string code = R"(
void set_arg{{NAME}}_data(void* data) {
- args_[{{I}}] = data;
+ set_arg_data({{I}}, data);
}
{{TYPE}}* arg{{NAME}}_data() {
- return static_cast<{{TYPE}}*>(args_[{{I}}]);
+ return static_cast<{{TYPE}}*>(arg_data({{I}}));
}
{{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
- args_[{{I}}])){{INDICES}};
+ arg_data({{I}}))){{INDICES}};
}
const {{TYPE}}* arg{{NAME}}_data() const {
- return static_cast<const {{TYPE}}*>(args_[{{I}}]);
+ return static_cast<const {{TYPE}}*>(arg_data({{I}}));
}
const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
- args_[{{I}}])){{INDICES}};
+ arg_data({{I}}))){{INDICES}};
}
)";
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
@@ -213,74 +209,33 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
Status GenResultMethods(const tf2xla::Config& config,
const xla::ProgramShape& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
- // Non-tuple (i.e. single-result) case.
- if (config.fetch_size() != 1) {
- return errors::InvalidArgument(
- "non-tuple result implies 1 fetch, but got ", config.fetch_size(),
- " fetches");
- }
- *methods += R"(
- void** results() { return temps_ + kResultIndex; }
- const void *const *results() const { return temps_ + kResultIndex; }
-)";
- std::vector<std::pair<string, string>> rewrites;
- TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites));
- const string code = R"(
- {{TYPE}}* result{{NAME}}_data() {
- return static_cast<{{TYPE}}*>(temps_[kResultIndex]);
- }
- {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
- return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
- temps_[kResultIndex])){{INDICES}};
- }
- const {{TYPE}}* result{{NAME}}_data() const {
- return static_cast<const {{TYPE}}*>(temps_[kResultIndex]);
- }
- const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
- return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
- temps_[kResultIndex])){{INDICES}};
+ // The XlaCompiler we use to build the xla computation always generates a
+ // tuple result, and we rely on this to simplify code generation.
+ return errors::Internal("codegen requires the XLA result to be a tuple");
}
-)";
- *methods += RewriteWithName("0", code, rewrites);
- if (!config.fetch(0).name().empty()) {
- *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites);
- }
- return Status::OK();
- }
- // Tuple (i.e. multi-result) case.
if (config.fetch_size() != ps.result().tuple_shapes_size()) {
return errors::InvalidArgument("mismatch between fetch_size(",
config.feed_size(), ") and tuple_size(",
ps.result().tuple_shapes_size(), ")");
}
- *methods += R"(
- void** results() {
- return static_cast<void**>(temps_[kResultIndex]);
- }
- const void *const *results() const {
- return static_cast<const void *const *>(temps_[kResultIndex]);
- }
-)";
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
- return static_cast<{{TYPE}}*>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}]);
+ return static_cast<{{TYPE}}*>(result_data({{I}}));
}
{{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
+ result_data({{I}}))){{INDICES}};
}
const {{TYPE}}* result{{NAME}}_data() const {
- return static_cast<{{TYPE}}*>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}]);
+ return static_cast<const {{TYPE}}*>(result_data({{I}}));
}
const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
+ result_data({{I}}))){{INDICES}};
}
)";
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
@@ -291,6 +246,84 @@ Status GenResultMethods(const tf2xla::Config& config,
return Status::OK();
}
+// Generates code implementing {Arg,Result}Names(), where T is one of
+// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
+// literal in the array, with nullptr terminating the array.
+template <typename T>
+string GenNameToIndexCode(const T& entries, bool generate) {
+ // No need for a static array if we're not supposed to generate the data.
+ if (!generate) {
+ return "{\n return nullptr;\n }";
+ }
+ // Determine when to stop. We stop emitting string literals after the last
+ // non-empty name.
+ int end = entries.size();
+ for (int i = entries.size() - 1; i >= 0; --i) {
+ if (!entries[i].name().empty()) {
+ break;
+ }
+ end = i;
+ }
+ // Emit string literals up to the last non-empty name.
+ string code = "{\n static const char* kNames[] = {";
+ for (int i = 0; i < end; ++i) {
+ if (i > 0) {
+ code += ", ";
+ }
+ code += "\"";
+ code += entries[i].name();
+ code += "\"";
+ }
+ if (end > 0) {
+ code += ", ";
+ }
+ code += "nullptr};\n return kNames;\n }";
+ return code;
+}
+
+// Converts the given `str` into a comma-separated list of per-character values.
+string StringToCharList(const string& str) {
+ string list;
+ for (const char c : str) {
+ if (!list.empty()) {
+ list += ",";
+ }
+ list += strings::StrCat(static_cast<int>(c));
+ }
+ return list;
+}
+
+string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) {
+ // No need for any static magic if we're not supposed to generate the data.
+ if (!generate) {
+ return "{\n return nullptr;\n }";
+ }
+ // The parameter names are currently meaningless, and redundant with the rest
+ // of our metadata, so clear them out to avoid confusion and save space.
+ program_shape.clear_parameter_names();
+ const string proto_str = program_shape.SerializeAsString();
+ // Embed the program shape as a serialized protobuf in the header file.
+ //
+ // TODO(toddw): This strategy will likely fail for larger protobufs, depending
+ // on the C++ compiler that is used. Figure out another solution if necessary.
+ string code = R"({
+ static const xla::ProgramShape* kShape = []() {
+ static const char kProto[] = {{{PROTO_LIST}}};
+ static constexpr int kProtoSize = {{PROTO_SIZE}};
+ xla::ProgramShape* shape = new xla::ProgramShape;
+ shape->ParseFromArray(kProto, kProtoSize);
+ return shape;
+ }();
+ return kShape;
+ })";
+ str_util::ReplaceAllPairs(
+ &code, {
+ {"{{PROTO_LIST}}", StringToCharList(proto_str)},
+ {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())},
+ });
+ return code;
+}
+
Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
for (const tf2xla::Feed& feed : config.feed()) {
if (!feed.name().empty()) {
@@ -336,24 +369,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const size_t temp_bytes_total =
total_buffer_bytes(itemp.data(), itemp.size());
- // Create rewrite strings for the optional context arg.
- string context_include;
- string context_set_arg, context_set_thread_pool, context_member_var;
- string run_result = "true";
- string error_msg = "tensorflow::string()";
- if (compile_result.has_context_arg) {
- // NOTE: Extra spaces and newlines are used to ensure nice formatting.
- context_include =
- "#include "
- "\"tensorflow/compiler/tf2xla/"
- "xla_local_runtime_context.h\"\n";
- context_set_arg = " args_[kNumArgs-1] = &context_;\n";
- context_set_thread_pool = " context_.thread_pool = pool;\n";
- context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n";
- run_result = "!context_.error";
- error_msg = "context_.error_msg";
- }
-
// Create rewrite strings for namespace start and end.
string ns_start;
for (const string& n : opts.namespaces) {
@@ -366,6 +381,19 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
ns_end += strings::StrCat("} // end namespace ", n, "\n");
}
+ // Generate metadata.
+ const string arg_names_code =
+ GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
+ const string result_names_code =
+ GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
+ const string include_xla_data_proto =
+ opts.gen_program_shape
+ ?
+ R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
+ : "";
+ const string program_shape_code =
+ GenProgramShapeCode(ps, opts.gen_program_shape);
+
// Use a poor-man's text templating mechanism; first populate the full header
// with placeholder tokens, and then rewrite the tokens with real values.
*header =
@@ -380,22 +408,23 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
-{{CONTEXT_INCLUDE}}
-#include "tensorflow/compiler/aot/runtime.h"
-#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/core/platform/macros.h"
+{{INCLUDE_XLA_DATA_PROTO}}
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { struct ThreadPoolDevice; }
+namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
- void* result, xla::ExecutableRunOptions* run_options,
- void** args, void** temps);
+ void* result, const xla::ExecutableRunOptions* run_options,
+ const void** args, void** temps);
{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
-// TensorFlow graph, now compiled into executable code. Usage example:
+// TensorFlow graph, now compiled into executable code. This extends the generic
+// XlaCompiledCpuFunction class with statically type-safe arg and result
+// methods. Usage example:
//
// {{CLASS}} computation;
// // ...set args using computation.argN methods
@@ -411,9 +440,9 @@ extern "C" void {{ENTRY}}(
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
-// o Calls to non-const methods require exclusive access to the object.
-// o Concurrent calls to const methods are OK, if those calls are made while
-// it is guaranteed that no thread may call a non-const method.
+// o Calls to non-const methods require exclusive access to the object.
+// o Concurrent calls to const methods are OK, if those calls are made while it
+// is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
// {{PROGRAM_SHAPE}}
@@ -423,7 +452,7 @@ extern "C" void {{ENTRY}}(
// arg bytes aligned: {{ARG_BYTES_ALIGNED}}
// temp bytes total: {{TEMP_BYTES_TOTAL}}
// temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
-class {{CLASS}} {
+class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
@@ -434,47 +463,31 @@ class {{CLASS}} {
return kArgSizes;
}
- // AllocMode controls the buffer allocation mode.
- enum class AllocMode {
- // Allocate all buffers - args, results and temps.
- ARGS_RESULTS_AND_TEMPS,
-
- // Only allocate result and temp buffers.
- // Use set_argN_data to set argument buffers before Run is called.
- RESULTS_AND_TEMPS_ONLY,
- };
-
- {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
- if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
- alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
- ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
- }
-{{CONTEXT_SET_ARG}}
- alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
- TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
- }
-
- ~{{CLASS}}() {
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
- }
-
- // Sets the thread pool to use during the Run call.
- {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
- run_options_.set_intra_op_thread_pool(pool);
-{{CONTEXT_SET_THREAD_POOL}}
- return *this;
- }
-
- // Runs the computation, with inputs read from arg buffers, and outputs
- // written to result buffers. Returns true on success and false on failure.
- bool Run() {
- {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_);
- return {{RUN_RESULT}};
- }
-
- // Returns the error message from the previous failed Run call.
- tensorflow::string error_msg() const { return {{ERROR_MSG}}; }
+ // Returns static data used to create an XlaCompiledCpuFunction.
+ static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
+ static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
+ XlaCompiledCpuFunction::StaticData* data =
+ new XlaCompiledCpuFunction::StaticData;
+ data->raw_function = {{ENTRY}};
+ data->arg_sizes = ArgSizes();
+ data->num_args = kNumArgs;
+ data->temp_sizes = TempSizes();
+ data->num_temps = kNumTemps;
+ data->result_index = kResultIndex;
+ data->requires_runtime_context = {{HAS_CONTEXT_ARG}};
+ data->arg_names = StaticArgNames();
+ data->result_names = StaticResultNames();
+ data->program_shape = StaticProgramShape();
+ return data;
+ }();
+ return *kStaticData;
+ }
+
+ {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
+ : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
+
+ {{CLASS}}(const {{CLASS}}&) = delete;
+ {{CLASS}}& operator=(const {{CLASS}}&) = delete;
// Arg methods for managing input buffers. Buffers are in row-major order.
// There is a set of methods for each positional argument, with the following
@@ -493,10 +506,6 @@ class {{CLASS}} {
// Returns a reference to the value of type T for positional argument N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
- //
- // void** args()
- // Returns an array of argument buffers, where args()[N] is the buffer for
- // positional argument N.
{{METHODS_ARG}}
// Result methods for managing output buffers. Buffers are in row-major order.
@@ -511,10 +520,6 @@ class {{CLASS}} {
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
- // void** results()
- // Returns an array of result buffers, where results()[N] is the buffer for
- // positional result N.
- //
// Unlike the arg methods, there is no set_resultN_data method. The result
// buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}
@@ -522,7 +527,7 @@ class {{CLASS}} {
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = {{TEMP_NUM}};
- // The 0-based index of the result in the temporary buffers.
+ // The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
// Byte size of each result / temporary buffer. There are kNumTemps entries.
@@ -531,14 +536,14 @@ class {{CLASS}} {
return kTempSizes;
}
- void* args_[kNumArgs];
- void* temps_[kNumTemps];
- void* alloc_args_ = nullptr;
- void* alloc_temps_ = nullptr;
- xla::ExecutableRunOptions run_options_;
-{{CONTEXT_MEMBER_VAR}}
+ // Array of names of each positional argument, terminated by nullptr.
+ static const char** StaticArgNames() {{ARG_NAMES_CODE}}
+
+ // Array of names of each positional result, terminated by nullptr.
+ static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
- TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}});
+ // Shape of the args and results.
+ static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}}
};
{{NS_END}}
@@ -550,22 +555,22 @@ class {{CLASS}} {
const std::vector<std::pair<string, string>> rewrites = {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
+ {"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
{"{{CLASS}}", opts.class_name},
- {"{{CONTEXT_INCLUDE}}\n", context_include},
- {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var},
- {"{{CONTEXT_SET_ARG}}\n", context_set_arg},
- {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool},
{"{{ENTRY}}", compile_result.entry_point},
- {"{{ERROR_MSG}}", error_msg},
+ {"{{HAS_CONTEXT_ARG}}",
+ compile_result.has_context_arg ? "true" : "false"},
+ {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
{"{{METHODS_ARG}}\n", methods_arg},
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{PROGRAM_SHAPE_CODE}}", program_shape_code},
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
- {"{{RUN_RESULT}}", run_result},
+ {"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},