diff options
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/aot/codegen.cc | 303 | ||||
-rw-r--r-- | tensorflow/compiler/aot/codegen.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/aot/codegen_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/aot/codegen_test_h.golden | 182 | ||||
-rw-r--r-- | tensorflow/compiler/aot/flags.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/aot/flags.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tests/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tests/tfcompile_test.cc | 72 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tfcompile.bzl | 11 | ||||
-rw-r--r-- | tensorflow/compiler/aot/tfcompile_main.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 55 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc | 88 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h | 223 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc | 217 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h | 87 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc | 133 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_executable.h | 16 |
17 files changed, 1154 insertions, 257 deletions
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index fc5c6ce58d..ae22f7edc4 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,10 +164,6 @@ string RewriteWithName(const string& name, string code, // Generate methods for args (inputs). Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { - *methods += R"( - void** args() { return args_; } - const void *const *args() const { return args_; } -)"; size_t num_args = ps.parameters_size(); if (compile_result.has_context_arg) { // If the compiled function needs a XlaLocalRuntimeContext* arg, it's @@ -184,21 +180,21 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); const string code = R"( void set_arg{{NAME}}_data(void* data) { - args_[{{I}}] = data; + set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { - return static_cast<{{TYPE}}*>(args_[{{I}}]); + return static_cast<{{TYPE}}*>(arg_data({{I}})); } {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } const {{TYPE}}* arg{{NAME}}_data() const { - return static_cast<const {{TYPE}}*>(args_[{{I}}]); + return static_cast<const {{TYPE}}*>(arg_data({{I}})); } const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const { return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -213,74 +209,33 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, Status GenResultMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { - // Non-tuple (i.e. single-result) case. - if (config.fetch_size() != 1) { - return errors::InvalidArgument( - "non-tuple result implies 1 fetch, but got ", config.fetch_size(), - " fetches"); - } - *methods += R"( - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } -)"; - std::vector<std::pair<string, string>> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites)); - const string code = R"( - {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>(temps_[kResultIndex]); - } - {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { - return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - temps_[kResultIndex])){{INDICES}}; - } - const {{TYPE}}* result{{NAME}}_data() const { - return static_cast<const {{TYPE}}*>(temps_[kResultIndex]); - } - const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { - return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( - temps_[kResultIndex])){{INDICES}}; + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and we rely on this to simplify code generation. + return errors::Internal("codegen requires the XLA result to be a tuple"); } -)"; - *methods += RewriteWithName("0", code, rewrites); - if (!config.fetch(0).name().empty()) { - *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites); - } - return Status::OK(); - } - // Tuple (i.e. multi-result) case. if (config.fetch_size() != ps.result().tuple_shapes_size()) { return errors::InvalidArgument("mismatch between fetch_size(", config.feed_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - *methods += R"( - void** results() { - return static_cast<void**>(temps_[kResultIndex]); - } - const void *const *results() const { - return static_cast<const void *const *>(temps_[kResultIndex]); - } -)"; for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector<std::pair<string, string>> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>( - static_cast<void**>(temps_[kResultIndex])[{{I}}]); + return static_cast<{{TYPE}}*>(result_data({{I}})); } {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } const {{TYPE}}* result{{NAME}}_data() const { - return static_cast<{{TYPE}}*>( - static_cast<void**>(temps_[kResultIndex])[{{I}}]); + return static_cast<const {{TYPE}}*>(result_data({{I}})); } const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>( - static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -291,6 +246,84 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generates code implementing {Arg,Result}Names(), where T is one of +// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string +// literal in the array, with nullptr terminating the array. +template <typename T> +string GenNameToIndexCode(const T& entries, bool generate) { + // No need for a static array if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // Determine when to stop. We stop emitting string literals after the last + // non-empty name. + int end = entries.size(); + for (int i = entries.size() - 1; i >= 0; --i) { + if (!entries[i].name().empty()) { + break; + } + end = i; + } + // Emit string literals up to the last non-empty name. + string code = "{\n static const char* kNames[] = {"; + for (int i = 0; i < end; ++i) { + if (i > 0) { + code += ", "; + } + code += "\""; + code += entries[i].name(); + code += "\""; + } + if (end > 0) { + code += ", "; + } + code += "nullptr};\n return kNames;\n }"; + return code; +} + +// Converts the given `str` into a comma-separated list of per-character values. +string StringToCharList(const string& str) { + string list; + for (const char c : str) { + if (!list.empty()) { + list += ","; + } + list += strings::StrCat(static_cast<int>(c)); + } + return list; +} + +string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) { + // No need for any static magic if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape.clear_parameter_names(); + const string proto_str = program_shape.SerializeAsString(); + // Embed the program shape as a serialized protobuf in the header file. + // + // TODO(toddw): This strategy will likely fail for larger protobufs, depending + // on the C++ compiler that is used. Figure out another solution if necessary. + string code = R"({ + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {{{PROTO_LIST}}}; + static constexpr int kProtoSize = {{PROTO_SIZE}}; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + })"; + str_util::ReplaceAllPairs( + &code, { + {"{{PROTO_LIST}}", StringToCharList(proto_str)}, + {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())}, + }); + return code; +} + Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { @@ -336,24 +369,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, const size_t temp_bytes_total = total_buffer_bytes(itemp.data(), itemp.size()); - // Create rewrite strings for the optional context arg. - string context_include; - string context_set_arg, context_set_thread_pool, context_member_var; - string run_result = "true"; - string error_msg = "tensorflow::string()"; - if (compile_result.has_context_arg) { - // NOTE: Extra spaces and newlines are used to ensure nice formatting. - context_include = - "#include " - "\"tensorflow/compiler/tf2xla/" - "xla_local_runtime_context.h\"\n"; - context_set_arg = " args_[kNumArgs-1] = &context_;\n"; - context_set_thread_pool = " context_.thread_pool = pool;\n"; - context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n"; - run_result = "!context_.error"; - error_msg = "context_.error_msg"; - } - // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { @@ -366,6 +381,19 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, ns_end += strings::StrCat("} // end namespace ", n, "\n"); } + // Generate metadata. + const string arg_names_code = + GenNameToIndexCode(config.feed(), opts.gen_name_to_index); + const string result_names_code = + GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); + const string include_xla_data_proto = + opts.gen_program_shape + ? + R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" + : ""; + const string program_shape_code = + GenProgramShapeCode(ps, opts.gen_program_shape); + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. *header = @@ -380,22 +408,23 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, #ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) -{{CONTEXT_INCLUDE}} -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +{{INCLUDE_XLA_DATA_PROTO}} +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); {{NS_START}} // {{CLASS}} represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // {{CLASS}} computation; // // ...set args using computation.argN methods @@ -411,9 +440,9 @@ extern "C" void {{ENTRY}}( // buffer allocation strategy. // // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: // {{PROGRAM_SHAPE}} @@ -423,7 +452,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} { +class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -434,47 +463,31 @@ class {{CLASS}} { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } -{{CONTEXT_SET_ARG}} - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~{{CLASS}}() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. - {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); -{{CONTEXT_SET_THREAD_POOL}} - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_); - return {{RUN_RESULT}}; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return {{ERROR_MSG}}; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = {{ENTRY}}; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + {{CLASS}}(const {{CLASS}}&) = delete; + {{CLASS}}& operator=(const {{CLASS}}&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -493,10 +506,6 @@ class {{CLASS}} { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. {{METHODS_ARG}} // Result methods for managing output buffers. Buffers are in row-major order. @@ -511,10 +520,6 @@ class {{CLASS}} { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} @@ -522,7 +527,7 @@ class {{CLASS}} { private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = {{RESULT_INDEX}}; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -531,14 +536,14 @@ class {{CLASS}} { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; -{{CONTEXT_MEMBER_VAR}} + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() {{ARG_NAMES_CODE}} + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() {{RESULT_NAMES_CODE}} - TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}}); + // Shape of the args and results. + static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}} }; {{NS_END}} @@ -550,22 +555,22 @@ class {{CLASS}} { const std::vector<std::pair<string, string>> rewrites = { {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, - {"{{CONTEXT_INCLUDE}}\n", context_include}, - {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var}, - {"{{CONTEXT_SET_ARG}}\n", context_set_arg}, - {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{ERROR_MSG}}", error_msg}, + {"{{HAS_CONTEXT_ARG}}", + compile_result.has_context_arg ? "true" : "false"}, + {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_CODE}}", program_shape_code}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, - {"{{RUN_RESULT}}", run_result}, + {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 740edd1e83..76dd0cc3cf 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -34,6 +34,12 @@ struct HeaderOpts { // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. std::vector<string> namespaces; + + // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. + bool gen_name_to_index = false; + + // If true, generate program shape data for the ProgramShape method. + bool gen_program_shape = false; }; // GenerateHeader uses the meta-information from compile_result to generate a diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 98cbd67e53..0f6114666f 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -127,6 +127,8 @@ TEST(GenerateHeader, Golden) { HeaderOpts opts; opts.class_name = "MyClass"; opts.namespaces = {"foo", "bar"}; + opts.gen_name_to_index = true; + opts.gen_program_shape = true; tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("feed0"); @@ -145,7 +147,8 @@ TEST(GenerateHeader, Golden) { xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), xla::ShapeUtil::MakeOpaqueShape(), }, - xla::ShapeUtil::MakeShape(xla::U32, {5, 6})); + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 01963c6df4..65f342ce27 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -9,24 +9,25 @@ #ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); namespace foo { namespace bar { // MyClass represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // MyClass computation; // // ...set args using computation.argN methods @@ -42,19 +43,19 @@ namespace bar { // buffer allocation strategy. // // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6] +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 // arg bytes aligned: 128 // temp bytes total: 126 // temp bytes aligned: 224 -class MyClass { +class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = 3; @@ -65,47 +66,31 @@ class MyClass { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } - args_[kNumArgs-1] = &context_; - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~MyClass() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. - MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - entry_point(temps_[kResultIndex], &run_options_, args_, temps_); - return !context_.error; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return context_.error_msg; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = entry_point; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = true; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + MyClass(const MyClass&) = delete; + MyClass& operator=(const MyClass&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -124,66 +109,59 @@ class MyClass { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. - - void** args() { return args_; } - const void *const *args() const { return args_; } void set_arg0_data(void* data) { - args_[0] = data; + set_arg_data(0, data); } float* arg0_data() { - return static_cast<float*>(args_[0]); + return static_cast<float*>(arg_data(0)); } float& arg0(size_t dim0, size_t dim1) { return (*static_cast<float(*)[1][2]>( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } const float* arg0_data() const { - return static_cast<const float*>(args_[0]); + return static_cast<const float*>(arg_data(0)); } const float& arg0(size_t dim0, size_t dim1) const { return (*static_cast<const float(*)[1][2]>( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } void set_arg_myfeed_data(void* data) { - args_[0] = data; + set_arg_data(0, data); } float* arg_myfeed_data() { - return static_cast<float*>(args_[0]); + return static_cast<float*>(arg_data(0)); } float& arg_myfeed(size_t dim0, size_t dim1) { return (*static_cast<float(*)[1][2]>( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } const float* arg_myfeed_data() const { - return static_cast<const float*>(args_[0]); + return static_cast<const float*>(arg_data(0)); } const float& arg_myfeed(size_t dim0, size_t dim1) const { return (*static_cast<const float(*)[1][2]>( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } void set_arg1_data(void* data) { - args_[1] = data; + set_arg_data(1, data); } tensorflow::int64* arg1_data() { - return static_cast<tensorflow::int64*>(args_[1]); + return static_cast<tensorflow::int64*>(arg_data(1)); } tensorflow::int64& arg1(size_t dim0, size_t dim1) { return (*static_cast<tensorflow::int64(*)[3][4]>( - args_[1]))[dim0][dim1]; + arg_data(1)))[dim0][dim1]; } const tensorflow::int64* arg1_data() const { - return static_cast<const tensorflow::int64*>(args_[1]); + return static_cast<const tensorflow::int64*>(arg_data(1)); } const tensorflow::int64& arg1(size_t dim0, size_t dim1) const { return (*static_cast<const tensorflow::int64(*)[3][4]>( - args_[1]))[dim0][dim1]; + arg_data(1)))[dim0][dim1]; } // Result methods for managing output buffers. Buffers are in row-major order. @@ -198,50 +176,43 @@ class MyClass { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } - tensorflow::uint32* result0_data() { - return static_cast<tensorflow::uint32*>(temps_[kResultIndex]); + return static_cast<tensorflow::uint32*>(result_data(0)); } tensorflow::uint32& result0(size_t dim0, size_t dim1) { return (*static_cast<tensorflow::uint32(*)[5][6]>( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } const tensorflow::uint32* result0_data() const { - return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]); + return static_cast<const tensorflow::uint32*>(result_data(0)); } const tensorflow::uint32& result0(size_t dim0, size_t dim1) const { return (*static_cast<const tensorflow::uint32(*)[5][6]>( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } tensorflow::uint32* result_myfetch_data() { - return static_cast<tensorflow::uint32*>(temps_[kResultIndex]); + return static_cast<tensorflow::uint32*>(result_data(0)); } tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) { return (*static_cast<tensorflow::uint32(*)[5][6]>( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } const tensorflow::uint32* result_myfetch_data() const { - return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]); + return static_cast<const tensorflow::uint32*>(result_data(0)); } const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const { return (*static_cast<const tensorflow::uint32(*)[5][6]>( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = 6; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = 5; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -250,14 +221,29 @@ class MyClass { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() { + static const char* kNames[] = {"myfeed", nullptr}; + return kNames; + } + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() { + static const char* kNames[] = {"myfetch", nullptr}; + return kNames; + } - TF_DISALLOW_COPY_AND_ASSIGN(MyClass); + // Shape of the args and results. + static const xla::ProgramShape* StaticProgramShape() { + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; + static constexpr int kProtoSize = 50; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + } }; } // end namespace bar diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 4e3998b682..5aff10346f 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -64,6 +64,10 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) { "namespaces are given, within the global namespace."}, {"out_object", &flags->out_object, "Output object file name."}, {"out_header", &flags->out_header, "Output header file name."}, + {"gen_name_to_index", &flags->gen_name_to_index, + "Generate name-to-index data for Lookup{Arg,Result}Index methods."}, + {"gen_program_shape", &flags->gen_program_shape, + "Generate program shape data for the ProgramShape method."}, }; flag_list->insert(flag_list->end(), tmp.begin(), tmp.end()); } diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index e11a0173fa..3246dbf95c 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -37,6 +37,10 @@ struct MainFlags { string cpp_class; string out_object; string out_header; + + // C++ codegen options + bool gen_name_to_index = false; + bool gen_program_shape = false; }; // Appends to flag_list a tensorflow::Flag for each field in MainFlags. diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index b0b1213a84..7dfd49cc3b 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -132,6 +132,7 @@ tf_library( cpp_class = "MatMulAndAddComp", graph = "test_graph_tfmatmulandadd.pb", tags = ["manual"], + tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) tf_library( @@ -156,6 +157,8 @@ tf_cc_test( ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", ":test_graph_tfsplits", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 07562e59c8..cfde5651c6 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -188,6 +190,23 @@ TEST(TFCompileTest, Gather) { EXPECT_FALSE(gather.Run()); EXPECT_EQ(gather.error_msg(), "Invalid index for gather"); } + + // Try a successful gather again, after the error, to ensure the error state + // is cleared. + { + const float params[4] = {1, 2, 3, 4}; + std::copy(params + 0, params + 4, gather.arg0_data()); + const int32 indices[2] = {1, 3}; + std::copy(indices + 0, indices + 2, gather.arg1_data()); + EXPECT_TRUE(gather.Run()); + EXPECT_EQ(gather.error_msg(), ""); + const float results[2] = {2, 4}; + for (int i = 0; i < 2; ++i) { + EXPECT_EQ(gather.result0(i), results[i]); + EXPECT_EQ(gather.result0_data()[i], results[i]); + } + EXPECT_EQ(gather.result0_data(), gather.results()[0]); + } } TEST(TFCompileTest, MatMul2) { @@ -421,6 +440,59 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, LookupNameIndex) { + // add doesn't have any names defined in its config. + AddComp add; + EXPECT_FALSE(add.HasNameIndices()); + + // muladd has names defined for all feeds and fetches. + MatMulAndAddComp muladd; + EXPECT_TRUE(muladd.HasNameIndices()); + + EXPECT_EQ(muladd.LookupArgIndex("x"), 0); + EXPECT_EQ(muladd.LookupArgIndex("y"), 1); + EXPECT_EQ(muladd.LookupArgIndex(""), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_hold"), -1); + EXPECT_EQ(muladd.LookupArgIndex("y_hold"), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_y_prod"), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_y_sum"), -1); + + EXPECT_EQ(muladd.LookupResultIndex("x_y_prod"), 0); + EXPECT_EQ(muladd.LookupResultIndex("x_y_sum"), 1); + EXPECT_EQ(muladd.LookupResultIndex(""), -1); + EXPECT_EQ(muladd.LookupResultIndex("x"), -1); + EXPECT_EQ(muladd.LookupResultIndex("y"), -1); + EXPECT_EQ(muladd.LookupResultIndex("x_hold"), -1); + EXPECT_EQ(muladd.LookupResultIndex("y_hold"), -1); +} + +TEST(TFCompileTest, ProgramShape) { + using xla::ShapeUtil; + const xla::Shape f32_2x2 = ShapeUtil::MakeShape(xla::F32, {2, 2}); + + // add doesn't have the program shape defined. + AddComp add; + ASSERT_TRUE(add.ProgramShape() == nullptr); + + // muladd has the program shape defined. + MatMulAndAddComp muladd; + const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + ASSERT_TRUE(muladd_shape != nullptr); + ASSERT_EQ(muladd_shape->parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); + + const xla::Shape& muladd_result = muladd_shape->result(); + ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); + ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); + const xla::Shape& muladd_result0 = + ShapeUtil::GetTupleElementShape(muladd_result, 0); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_result0, f32_2x2)); + const xla::Shape& muladd_result1 = + ShapeUtil::GetTupleElementShape(muladd_result, 1); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 608d461a4c..461a9315c5 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -167,6 +167,8 @@ def tf_library(name, graph, config, # The cc_library rule packaging up the header and object file, and needed # kernel implementations. + need_xla_data_proto = (tfcompile_flags and + tfcompile_flags.find("--gen_program_shape") != -1) native.cc_library( name=name, srcs=[object_file], @@ -177,11 +179,12 @@ def tf_library(name, graph, config, # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", "//tensorflow/core:framework_lite", - ] + (include_standard_runtime_deps and [ + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the proto. + "//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32", "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index cc499c3284..6ab3d47418 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -94,6 +94,8 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, StringPiece(obj.data(), obj.size()))); HeaderOpts header_opts; + header_opts.gen_name_to_index = flags.gen_name_to_index; + header_opts.gen_program_shape = flags.gen_program_shape; if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0769b13718..08f2249e0d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -59,6 +59,41 @@ cc_library( ) cc_library( + name = "xla_compiled_cpu_function", + srcs = ["xla_compiled_cpu_function.cc"], + hdrs = ["xla_compiled_cpu_function.h"], + visibility = ["//visibility:public"], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + +cc_library( + name = "xla_jit_compiled_cpu_function", + srcs = ["xla_jit_compiled_cpu_function.cc"], + hdrs = ["xla_jit_compiled_cpu_function.h"], + visibility = ["//visibility:public"], + deps = [ + ":tf2xla", + ":tf2xla_proto", + ":xla_compiled_cpu_function", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service/cpu:cpu_executable", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( name = "xla_compiler", srcs = [ "xla_compilation_device.cc", @@ -179,6 +214,26 @@ tf_cc_test( ) tf_cc_test( + name = "xla_jit_compiled_cpu_function_test", + srcs = ["xla_jit_compiled_cpu_function_test.cc"], + deps = [ + ":tf2xla_proto", + ":xla_jit_compiled_cpu_function", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], deps = [ diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc new file mode 100644 index 0000000000..b5c17c5273 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" + +#include <cassert> +#include "tensorflow/compiler/aot/runtime.h" + +namespace tensorflow { + +XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, + AllocMode alloc_mode) + : raw_function_(static_data.raw_function), + result_index_(static_data.result_index), + args_(new void*[static_data.num_args]), + temps_(new void*[static_data.num_temps]), + arg_names_(static_data.arg_names), + result_names_(static_data.result_names), + program_shape_(static_data.program_shape) { + // Allocate arg and temp buffers. + if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + static_data.arg_sizes, static_data.num_args, args_, + /*annotate_initialized=*/false); + } + alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + static_data.temp_sizes, static_data.num_temps, temps_, + /*annotate_initialized=*/true); + + // The runtime context is always the last arg, if it is required. + if (static_data.requires_runtime_context) { + args_[static_data.num_args - 1] = &context_; + } +} + +XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { + tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); + tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + delete[] args_; + delete[] temps_; +} + +namespace { + +// Linear search through `names` looking for a match with `name`. Returns -1 if +// the name isn't found, or is empty. +// +// REQUIRES: `names` is a nullptr-terminated array. +int LookupNameIndex(const string& name, const char** names) { + // Hitting this assert means that there is no name-to-index data available; + // for AOT try the setting the tfcompile --gen_name_to_index flag. + assert(names != nullptr); + + constexpr int kNotFound = -1; + if (name.empty()) { + return kNotFound; + } + for (int index = 0; names[index] != nullptr; ++index) { + if (name == names[index]) { + return index; + } + } + return kNotFound; +} + +} // namespace + +int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const { + return LookupNameIndex(name, arg_names_); +} + +int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { + return LookupNameIndex(name, result_names_); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h new file mode 100644 index 0000000000..01e6b4c071 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ + +#include <functional> +#include <string> + +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/types.h" + +// Forward-declare, rather than include, to reduce code size for users that +// never use this functionality. +namespace xla { +class ProgramShape; +} + +namespace tensorflow { + +// Represents a function compiled by XLA, produced via either JIT or AOT. +// +// The Run method invokes the actual computation, with inputs read from arg +// buffers, and outputs written to result buffers. Each Run call may also use a +// set of temporary buffers for the computation. +// +// By default each instance of this class manages its own arg, result and temp +// buffers. The AllocMode constructor parameter may be used to modify the buffer +// allocation strategy. +// +// Under the default allocation strategy, this class is thread-compatible: +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. +class XlaCompiledCpuFunction { + public: + // Type of the raw function, produced by either JIT or AOT. + // + // TODO(toddw): Add support for hlo profiling, and replace std::function with + // a raw function pointer, for some codesize savings. + using RawFunction = std::function<void( + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps)>; + + // StaticData represents the state necessary to run an XLA-compiled + // function. For JIT this is backed by data in XlaCompiledCpuFunctionJit; for + // AOT this is backed by data compiled into the object file. + struct StaticData { + // The raw function to call. + RawFunction raw_function; + + // Cardinality and sizes of arg and temp buffers. + const intptr_t* arg_sizes = nullptr; + size_t num_args = 0; + const intptr_t* temp_sizes = nullptr; + size_t num_temps = 0; + + // The 0-based index of the result tuple, in the temp buffers. + size_t result_index = 0; + + // Is the final arg XlaLocalRuntimeContext? + bool requires_runtime_context = false; + + // [Optional] Arrays of arg and result names. These are arrays of C-style + // strings, where the array is terminated by nullptr. + const char** arg_names = nullptr; + const char** result_names = nullptr; + + // [Optional] Arg and result shapes. + const xla::ProgramShape* program_shape = nullptr; + }; + + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + // Allocate all buffers - args, results and temps. + ARGS_RESULTS_AND_TEMPS, + + // Only allocate result and temp buffers. + // Use set_arg_data to set argument buffers before Run is called. + RESULTS_AND_TEMPS_ONLY, + }; + + XlaCompiledCpuFunction( + const StaticData& static_data, + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + virtual ~XlaCompiledCpuFunction(); + + XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; + XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; + + // Sets the intra-op thread pool used to run individual ops concurrently. + void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + run_options_.set_intra_op_thread_pool(pool); + context_.thread_pool = pool; + } + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run() { + context_.error = false; + context_.error_msg.clear(); + raw_function_(temps_[result_index_], &run_options_, + const_cast<const void**>(args_), temps_); + return !context_.error; + } + + // Returns the error message from the previous failed Run call. + const string& error_msg() const { return context_.error_msg; } + + // ------------------------------ + // Arg methods for managing input buffers. Buffers are in row-major order. + + // Returns the underlying array of argument buffers, where args()[I] is the + // buffer for the positional argument at index I. + void** args() { return args_; } + const void* const* args() const { return args_; } + + // Returns the buffer for the positional argument at the given `index`. + void* arg_data(size_t index) { return args_[index]; } + const void* arg_data(size_t index) const { return args_[index]; } + + // Sets the buffer for the positional argument at the given `index` to `data`. + // Must be called before Run to have an effect. May be called under any + // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be + // called for each positional argument, in order to set the argument buffers. + // + // Allocated memory must be aligned to the size specified by + // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in + // tensorflow/compiler/aot/runtime.h to ensure correct alignment. + // + // If StaticData.requires_runtime_context==true, the final argument is an + // XlaLocalRuntimeContext, which is managed internally by this class, and + // should not be changed. + // + // Aliasing of argument and result buffers is not allowed, and results in + // undefined behavior. + void set_arg_data(size_t index, void* data) { args_[index] = data; } + + // ------------------------------ + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. Unlike the arg methods, + // there is no set_resultN_data method. The result buffers are managed + // internally, and may change after each call to Run. + + // Returns the underlying array of result buffers, where results()[I] is the + // buffer for the positional result at index I. + void** results() { return static_cast<void**>(temps_[result_index_]); } + const void* const* results() const { + return static_cast<const void* const*>(temps_[result_index_]); + } + + // Returns the buffer for the positional result at the given `index`. + void* result_data(size_t index) { return results()[index]; } + const void* result_data(size_t index) const { return results()[index]; } + + // ------------------------------ + // Methods for extracting optional metadata. + + // Returns true iff data is available for the Lookup{Arg,Result}Index methods. + // E.g. the data might not be compiled into the binary for AOT. + bool HasNameIndices() const { + return arg_names_ != nullptr && result_names_ != nullptr; + } + + // Returns the 0-based index for the argument with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupArgIndex(const string& name) const; + + // Returns the 0-based index for the result with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupResultIndex(const string& name) const; + + // Returns the shape of the args and results. May return nullptr if the + // program shape isn't available. + const xla::ProgramShape* ProgramShape() const { return program_shape_; } + + private: + const RawFunction raw_function_; + const size_t result_index_; + + // Arrays of argument and temp buffers; entries in args_ may be overwritten by + // the user. + void** args_ = nullptr; + void** temps_ = nullptr; + + // Backing memory for individual arg and temp buffers. + void* alloc_args_ = nullptr; + void* alloc_temps_ = nullptr; + + // Options and context passed to the compiled function. + xla::ExecutableRunOptions run_options_; + tensorflow::XlaLocalRuntimeContext context_; + + // Optional metadata. + const char** arg_names_ = nullptr; + const char** result_names_ = nullptr; + const xla::ProgramShape* program_shape_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc new file mode 100644 index 0000000000..1dd454ea8d --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -0,0 +1,217 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" + +#include <memory> +#include <vector> + +#include "tensorflow/compiler/tf2xla/tf2xla.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +// Returns a vector of positional argument buffer sizes. +xla::StatusOr<std::vector<intptr_t>> ComputeArgSizes( + const xla::ProgramShape& program_shape, bool requires_runtime_context) { + std::vector<intptr_t> arg_sizes; + const size_t num_args = program_shape.parameters_size(); + arg_sizes.reserve(num_args); + for (int i = 0; i < num_args; ++i) { + const xla::Shape& arg_shape = program_shape.parameters(i); + if (i == num_args - 1 && requires_runtime_context) { + // If the compiled function needs an XlaLocalRuntimeContext* arg, it's + // always last, and must be represented as an opaque type. + const xla::PrimitiveType type = arg_shape.element_type(); + if (type != xla::OPAQUE) { + return errors::InvalidArgument( + "expected final context arg to be opaque, but got type: ", + xla::PrimitiveType_Name(type), ", from program shape: ", + xla::ShapeUtil::HumanString(program_shape)); + } + arg_sizes.push_back(-1); + } else { + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); + } + } + return std::move(arg_sizes); +} + +// Returns a vector of positional temporary buffer sizes. +xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes( + const xla::BufferAssignment& buffer_assignment) { + const std::vector<xla::BufferAllocation>& allocations = + buffer_assignment.Allocations(); + std::vector<intptr_t> temp_sizes; + temp_sizes.reserve(allocations.size()); + for (const xla::BufferAllocation& allocation : allocations) { + // Callers don't allocate temporary buffers for parameters. Nor for + // thread-local buffers, which are lowered to alloca. + if (allocation.is_entry_computation_parameter() || + allocation.is_thread_local()) { + temp_sizes.push_back(-1); + } else { + temp_sizes.push_back(allocation.size()); + } + } + return std::move(temp_sizes); +} + +// Returns the index of the result in the temp buffers. +xla::StatusOr<size_t> ComputeResultIndex( + const xla::BufferAssignment& buffer_assignment) { + TF_ASSIGN_OR_RETURN(const xla::BufferAllocation::Slice result_slice, + buffer_assignment.GetUniqueTopLevelOutputSlice()); + return result_slice.index(); +} + +// Adapt ComputeFunctionType, which includes a final profile_counters arg, to +// RawFunction, which doesn't include that final arg. +// +// TODO(toddw): Change RawFunction and AOT to also pass the final +// profile_counters arg, and remove this adapter. +XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( + xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { + return [compute_function](void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps) { + return compute_function(result, run_options, args, temps, + /*profile_counters=*/nullptr); + }; +} + +// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold +// the actual strings in nonempty_names, and hold arrays of pointers in +// name_ptrs, terminated by a nullptr entry. +template <typename T> +void CollectNames(const T& entries, std::vector<string>* nonempty_names, + std::vector<const char*>* name_ptrs) { + // First collect `nonempty_names`, to ensure the underlying strings won't + // change out from under us. + for (const auto& entry : entries) { + const string& name = entry.name(); + if (!name.empty()) { + nonempty_names->push_back(name); + } + } + // Now set `name_ptrs` pointing to the strings in `nonempty_names`. + name_ptrs->reserve(entries.size() + 1); // +1 for nullptr array terminator + size_t nonempty_index = 0; + for (const auto& entry : entries) { + const string& name = entry.name(); + if (!name.empty()) { + name_ptrs->push_back(nonempty_names->at(nonempty_index).c_str()); + ++nonempty_index; + } else { + name_ptrs->push_back(""); + } + } + name_ptrs->push_back(nullptr); // array terminator +} + +} // namespace + +/*static*/ xla::StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>> +XlaJitCompiledCpuFunction::Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options) { + // Convert the graph_def into an xla::Computation. + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + xla::ClientLibrary::GetOrCreateLocalClient()); + xla::Computation computation; + bool requires_runtime_context; + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( + graph_def, config, client, &computation, &requires_runtime_context)); + + // Get and verify the program shape. + TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> program_shape, + client->GetComputationShape(computation)); + if (program_shape->result().element_type() != xla::TUPLE) { + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and XlaCompiledCpuFunction relies on this for simpler + // calling semantics. + return errors::Internal( + "XlaJitCompiledCpuFunction requires the XLA result to be a tuple"); + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape->clear_parameter_names(); + + // Compute arg shapes, needed to compile the executable. + std::vector<const xla::Shape*> arg_shapes; + arg_shapes.reserve(program_shape->parameters_size()); + for (int i = 0; i < program_shape->parameters_size(); ++i) { + arg_shapes.push_back(&program_shape->parameters(i)); + } + + // Compile the executable. The static_cast to the CpuExecutable subclass is + // necessary since the raw function and buffer assignments are only available + // there. + TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable, + client->Compile(computation, arg_shapes, build_options)); + const xla::cpu::CpuExecutable* cpu_executable = + static_cast<xla::cpu::CpuExecutable*>(executable->executable()); + XlaCompiledCpuFunction::RawFunction raw_function = + RawFunctionAdapter(cpu_executable->compute_function()); + const xla::BufferAssignment& buffer_assignment = + cpu_executable->buffer_assignment(); + + // Compute buffer sizes and the result index, needed to run the raw function. + TF_ASSIGN_OR_RETURN( + std::vector<intptr_t> arg_sizes, + ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector<intptr_t> temp_sizes, + ComputeTempSizes(buffer_assignment)); + TF_ASSIGN_OR_RETURN(size_t result_index, + ComputeResultIndex(buffer_assignment)); + + std::unique_ptr<XlaJitCompiledCpuFunction> jit_unique_ptr( + new XlaJitCompiledCpuFunction); + XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); + jit->executable_ = std::move(executable); + jit->arg_sizes_ = std::move(arg_sizes); + jit->temp_sizes_ = std::move(temp_sizes); + jit->program_shape_ = std::move(program_shape); + jit->static_data_.raw_function = std::move(raw_function); + jit->static_data_.arg_sizes = jit->arg_sizes_.data(); + jit->static_data_.num_args = jit->arg_sizes_.size(); + jit->static_data_.temp_sizes = jit->temp_sizes_.data(); + jit->static_data_.num_temps = jit->temp_sizes_.size(); + jit->static_data_.result_index = result_index; + jit->static_data_.requires_runtime_context = requires_runtime_context; + // Optional metadata is collected and set below. + CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); + CollectNames(config.fetch(), &jit->nonempty_result_names_, + &jit->result_names_); + jit->static_data_.arg_names = jit->arg_names_.data(); + jit->static_data_.result_names = jit->result_names_.data(); + jit->static_data_.program_shape = jit->program_shape_.get(); + return std::move(jit_unique_ptr); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h new file mode 100644 index 0000000000..af307ae4ef --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ + +#include <memory> +#include <vector> + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the result of JIT compilation by XLA down to a function. This +// class holds the state necessary to create XlaCompiledCpuFunction instances, +// which are used to actually invoke the compiled computation. +// +// XlaJitCompiledCpuFunction must outlive the XlaCompiledCpuFunctions that are +// created from it. It holds state shared by all of the functions, including the +// JIT-compiled function itself, along with buffer sizes and other metadata +// necessary for execution. +class XlaJitCompiledCpuFunction { + public: + // Compile a tensorflow::GraphDef into an XlaJitCompiledCpuFunction. The given + // `config` specifies the portion of the graph to compile, via feeds and + // fetches. Each feed is a positional input argument for the compiled + // function, while each fetch is a positional output argument. + static xla::StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>> Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options); + + XlaJitCompiledCpuFunction(const XlaJitCompiledCpuFunction&) = delete; + XlaJitCompiledCpuFunction& operator=(const XlaJitCompiledCpuFunction&) = + delete; + + // Returns static data used to create an XlaCompiledCpuFunction instance, + // which represents the JIT-compiled function. The static data is unchanging + // across each instance. + const XlaCompiledCpuFunction::StaticData& StaticData() const { + return static_data_; + } + + private: + XlaJitCompiledCpuFunction() {} + + // The executable holds the underlying function. + std::unique_ptr<xla::LocalExecutable> executable_; + + // The static data is backed by the rest of the state in this class. + XlaCompiledCpuFunction::StaticData static_data_; + + // The backing arrays of arg and temp buffer sizes. + std::vector<intptr_t> arg_sizes_; + std::vector<intptr_t> temp_sizes_; + + // The backing arrays of arg and result names. We hold the actual strings in + // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static + // data to refer to. + std::vector<string> nonempty_arg_names_; + std::vector<string> nonempty_result_names_; + std::vector<const char*> arg_names_; + std::vector<const char*> result_names_; + + // The backing data for the program shape. + std::unique_ptr<const xla::ProgramShape> program_shape_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc new file mode 100644 index 0000000000..5bee68eefc --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -0,0 +1,133 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +AttrValue TypeAttrValue(DataType type) { + AttrValue attr_value; + SetAttrValue(type, &attr_value); + return attr_value; +} + +GraphDef SumGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* sum = graph_def.add_node(); + sum->set_name("sum"); + sum->set_op("Add"); + sum->add_input("x"); + sum->add_input("y"); + (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32); + return graph_def; +} + +tf2xla::Config SumConfig() { + tf2xla::Config config; + tf2xla::Feed* x = config.add_feed(); + x->mutable_id()->set_node_name("x"); + x->set_name("x_name"); + tf2xla::Feed* y = config.add_feed(); + y->mutable_id()->set_node_name("y"); + y->set_name("y_name"); + tf2xla::Fetch* sum = config.add_fetch(); + sum->mutable_id()->set_node_name("sum"); + sum->set_name("sum_name"); + return config; +} + +TEST(XlaJitCompiledCpuFunction, Sum) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<XlaJitCompiledCpuFunction> jit, + XlaJitCompiledCpuFunction::Compile(graph_def, config, + xla::ExecutableBuildOptions())); + XlaCompiledCpuFunction function(jit->StaticData()); + + // Run the function and check results. + *static_cast<int32*>(function.arg_data(0)) = 10; + *static_cast<int32*>(function.arg_data(1)) = 32; + EXPECT_TRUE(function.Run()); + EXPECT_EQ(function.error_msg(), ""); + EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 42); + + // Run the function again. + *static_cast<int32*>(function.arg_data(0)) = 100; + *static_cast<int32*>(function.arg_data(1)) = 320; + EXPECT_TRUE(function.Run()); + EXPECT_EQ(function.error_msg(), ""); + EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), 420); + + // Check name to index lookups. + EXPECT_TRUE(function.HasNameIndices()); + + EXPECT_EQ(function.LookupArgIndex("x_name"), 0); + EXPECT_EQ(function.LookupArgIndex("y_name"), 1); + EXPECT_EQ(function.LookupArgIndex(""), -1); + EXPECT_EQ(function.LookupArgIndex("x"), -1); + EXPECT_EQ(function.LookupArgIndex("y"), -1); + EXPECT_EQ(function.LookupArgIndex("sum"), -1); + EXPECT_EQ(function.LookupArgIndex("sum_name"), -1); + + EXPECT_EQ(function.LookupResultIndex("sum_name"), 0); + EXPECT_EQ(function.LookupResultIndex(""), -1); + EXPECT_EQ(function.LookupResultIndex("x"), -1); + EXPECT_EQ(function.LookupResultIndex("y"), -1); + EXPECT_EQ(function.LookupResultIndex("sum"), -1); + EXPECT_EQ(function.LookupResultIndex("x_name"), -1); + EXPECT_EQ(function.LookupResultIndex("y_name"), -1); + + // Check program shape. + using xla::ShapeUtil; + const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); + const xla::ProgramShape* program_shape = function.ProgramShape(); + ASSERT_TRUE(program_shape != nullptr); + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32)); + + const xla::Shape& result = program_shape->result(); + ASSERT_EQ(result.element_type(), xla::TUPLE); + ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1); + const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0); + EXPECT_TRUE(ShapeUtil::Compatible(result0, s32)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 0d68aa7399..238bc9b46a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -87,6 +87,17 @@ class CpuExecutable : public Executable { std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override; + // Type of the computation function we expect in the JIT. + using ComputeFunctionType = void (*)( + void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/); + + const ComputeFunctionType& compute_function() const { + return compute_function_; + } + + const BufferAssignment& buffer_assignment() const { return *assignment_; } + private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer @@ -129,11 +140,6 @@ class CpuExecutable : public Executable { // positives. string ir_module_string_; - // Type of the computation function we expect in the JIT. - // void function(void* result, const void* run_options, - // const void** args_array, void** temps_array) - using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - uint64*); ComputeFunctionType compute_function_; // Entry function name for the computation. |