diff options
author | 2017-10-02 23:33:20 -0700 | |
---|---|---|
committer | 2017-10-02 23:37:06 -0700 | |
commit | 263d025fb6dee974eefb30a51372188fb856d6cc (patch) | |
tree | b32ec04077368f45fbf31da8852b4fe072611e45 /tensorflow/compiler/aot/codegen.cc | |
parent | 955c525d416c163c9dd857e637b0476b112b0ea0 (diff) |
Add XlaCompiledCpuFunction, a lightweight API for calling XLA computations that
are compiled down to functions. The API is based on a generic form of the
original AOT auto-generated header.
For AOT (tfcompile), this API has been slotted into the auto-generated header.
For JIT, a new XlaJitCompiledCpuFunction class has been added, which compiles a
tensorflow::GraphDef and allows the user to create XlaCompiledCpuFunction objects.
XlaCompiledCpuFunction contains optional metadata: mappings from arg/result names
to their indices, and the program shape. This data is always available via JIT,
but is only provided via AOT if the tfcompile --gen_name_to_index and
--gen_program_shape flags are set. We don't enable these by default for AOT, to
keep binary sizes smaller; the ProgramShape proto pulls in lots of code, and the
serialized shape may also be large.
PiperOrigin-RevId: 170811579
Diffstat (limited to 'tensorflow/compiler/aot/codegen.cc')
-rw-r--r-- | tensorflow/compiler/aot/codegen.cc | 303 |
1 file changed, 154 insertions(+), 149 deletions(-)
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index fc5c6ce58d..ae22f7edc4 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,10 +164,6 @@ string RewriteWithName(const string& name, string code, // Generate methods for args (inputs). Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { - *methods += R"( - void** args() { return args_; } - const void *const *args() const { return args_; } -)"; size_t num_args = ps.parameters_size(); if (compile_result.has_context_arg) { // If the compiled function needs a XlaLocalRuntimeContext* arg, it's @@ -184,21 +180,21 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); const string code = R"( void set_arg{{NAME}}_data(void* data) { - args_[{{I}}] = data; + set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { - return static_cast<{{TYPE}}*>(args_[{{I}}]); + return static_cast<{{TYPE}}*>(arg_data({{I}})); } {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } const {{TYPE}}* arg{{NAME}}_data() const { - return static_cast<const {{TYPE}}*>(args_[{{I}}]); + return static_cast<const {{TYPE}}*>(arg_data({{I}})); } const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const { return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -213,74 +209,33 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, Status GenResultMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { - // Non-tuple (i.e. single-result) case. 
- if (config.fetch_size() != 1) { - return errors::InvalidArgument( - "non-tuple result implies 1 fetch, but got ", config.fetch_size(), - " fetches"); - } - *methods += R"( - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } -)"; - std::vector<std::pair<string, string>> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites)); - const string code = R"( - {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>(temps_[kResultIndex]); - } - {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { - return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - temps_[kResultIndex])){{INDICES}}; - } - const {{TYPE}}* result{{NAME}}_data() const { - return static_cast<const {{TYPE}}*>(temps_[kResultIndex]); - } - const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { - return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( - temps_[kResultIndex])){{INDICES}}; + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and we rely on this to simplify code generation. + return errors::Internal("codegen requires the XLA result to be a tuple"); } -)"; - *methods += RewriteWithName("0", code, rewrites); - if (!config.fetch(0).name().empty()) { - *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites); - } - return Status::OK(); - } - // Tuple (i.e. multi-result) case. 
if (config.fetch_size() != ps.result().tuple_shapes_size()) { return errors::InvalidArgument("mismatch between fetch_size(", config.feed_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - *methods += R"( - void** results() { - return static_cast<void**>(temps_[kResultIndex]); - } - const void *const *results() const { - return static_cast<const void *const *>(temps_[kResultIndex]); - } -)"; for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector<std::pair<string, string>> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>( - static_cast<void**>(temps_[kResultIndex])[{{I}}]); + return static_cast<{{TYPE}}*>(result_data({{I}})); } {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } const {{TYPE}}* result{{NAME}}_data() const { - return static_cast<{{TYPE}}*>( - static_cast<void**>(temps_[kResultIndex])[{{I}}]); + return static_cast<const {{TYPE}}*>(result_data({{I}})); } const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( - static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -291,6 +246,84 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generates code implementing {Arg,Result}Names(), where T is one of +// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string +// literal in the array, with nullptr terminating the array. +template <typename T> +string GenNameToIndexCode(const T& entries, bool generate) { + // No need for a static array if we're not supposed to generate the data. 
+ if (!generate) { + return "{\n return nullptr;\n }"; + } + // Determine when to stop. We stop emitting string literals after the last + // non-empty name. + int end = entries.size(); + for (int i = entries.size() - 1; i >= 0; --i) { + if (!entries[i].name().empty()) { + break; + } + end = i; + } + // Emit string literals up to the last non-empty name. + string code = "{\n static const char* kNames[] = {"; + for (int i = 0; i < end; ++i) { + if (i > 0) { + code += ", "; + } + code += "\""; + code += entries[i].name(); + code += "\""; + } + if (end > 0) { + code += ", "; + } + code += "nullptr};\n return kNames;\n }"; + return code; +} + +// Converts the given `str` into a comma-separated list of per-character values. +string StringToCharList(const string& str) { + string list; + for (const char c : str) { + if (!list.empty()) { + list += ","; + } + list += strings::StrCat(static_cast<int>(c)); + } + return list; +} + +string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) { + // No need for any static magic if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape.clear_parameter_names(); + const string proto_str = program_shape.SerializeAsString(); + // Embed the program shape as a serialized protobuf in the header file. + // + // TODO(toddw): This strategy will likely fail for larger protobufs, depending + // on the C++ compiler that is used. Figure out another solution if necessary. 
+ string code = R"({ + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {{{PROTO_LIST}}}; + static constexpr int kProtoSize = {{PROTO_SIZE}}; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + })"; + str_util::ReplaceAllPairs( + &code, { + {"{{PROTO_LIST}}", StringToCharList(proto_str)}, + {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())}, + }); + return code; +} + Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { @@ -336,24 +369,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, const size_t temp_bytes_total = total_buffer_bytes(itemp.data(), itemp.size()); - // Create rewrite strings for the optional context arg. - string context_include; - string context_set_arg, context_set_thread_pool, context_member_var; - string run_result = "true"; - string error_msg = "tensorflow::string()"; - if (compile_result.has_context_arg) { - // NOTE: Extra spaces and newlines are used to ensure nice formatting. - context_include = - "#include " - "\"tensorflow/compiler/tf2xla/" - "xla_local_runtime_context.h\"\n"; - context_set_arg = " args_[kNumArgs-1] = &context_;\n"; - context_set_thread_pool = " context_.thread_pool = pool;\n"; - context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n"; - run_result = "!context_.error"; - error_msg = "context_.error_msg"; - } - // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { @@ -366,6 +381,19 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, ns_end += strings::StrCat("} // end namespace ", n, "\n"); } + // Generate metadata. 
+ const string arg_names_code = + GenNameToIndexCode(config.feed(), opts.gen_name_to_index); + const string result_names_code = + GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); + const string include_xla_data_proto = + opts.gen_program_shape + ? + R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" + : ""; + const string program_shape_code = + GenProgramShapeCode(ps, opts.gen_program_shape); + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. *header = @@ -380,22 +408,23 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, #ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) -{{CONTEXT_INCLUDE}} -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +{{INCLUDE_XLA_DATA_PROTO}} +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); {{NS_START}} // {{CLASS}} represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // {{CLASS}} computation; // // ...set args using computation.argN methods @@ -411,9 +440,9 @@ extern "C" void {{ENTRY}}( // buffer allocation strategy. 
// // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: // {{PROGRAM_SHAPE}} @@ -423,7 +452,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} { +class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -434,47 +463,31 @@ class {{CLASS}} { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } -{{CONTEXT_SET_ARG}} - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~{{CLASS}}() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. 
- {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); -{{CONTEXT_SET_THREAD_POOL}} - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_); - return {{RUN_RESULT}}; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return {{ERROR_MSG}}; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = {{ENTRY}}; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + {{CLASS}}(const {{CLASS}}&) = delete; + {{CLASS}}& operator=(const {{CLASS}}&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -493,10 +506,6 @@ class {{CLASS}} { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. 
{{METHODS_ARG}} // Result methods for managing output buffers. Buffers are in row-major order. @@ -511,10 +520,6 @@ class {{CLASS}} { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} @@ -522,7 +527,7 @@ class {{CLASS}} { private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = {{RESULT_INDEX}}; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -531,14 +536,14 @@ class {{CLASS}} { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; -{{CONTEXT_MEMBER_VAR}} + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() {{ARG_NAMES_CODE}} + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() {{RESULT_NAMES_CODE}} - TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}}); + // Shape of the args and results. 
+ static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}} }; {{NS_END}} @@ -550,22 +555,22 @@ class {{CLASS}} { const std::vector<std::pair<string, string>> rewrites = { {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, - {"{{CONTEXT_INCLUDE}}\n", context_include}, - {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var}, - {"{{CONTEXT_SET_ARG}}\n", context_set_arg}, - {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{ERROR_MSG}}", error_msg}, + {"{{HAS_CONTEXT_ARG}}", + compile_result.has_context_arg ? "true" : "false"}, + {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_CODE}}", program_shape_code}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, - {"{{RUN_RESULT}}", run_result}, + {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, |