aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-02 23:33:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-02 23:37:06 -0700
commit263d025fb6dee974eefb30a51372188fb856d6cc (patch)
treeb32ec04077368f45fbf31da8852b4fe072611e45 /tensorflow/compiler
parent955c525d416c163c9dd857e637b0476b112b0ea0 (diff)
Add XlaCompiledFunction, a lightweight API for calling XLA computations that are
compiled down to functions. The API is based on a generic form of the original AOT auto-generated header. For AOT (tfcompile), this API has been slotted into the auto-generated header. For JIT, a new XlaJitCompiledCpuFunction class has been added, which compiles a tensorflow::GraphDef and allows the user to create XlaCompiledCpuFunction objects. XlaCompiledCpuFunction contains optional metadata: mappings from arg/result names to their index, and the program shape. This data is always available via JIT, but only provided via AOT if the tfcompile --gen_name_to_index and --gen_program_shape flags are set. We don't enable it by default for AOT to keep binary sizes smaller; the ProgramShape proto pulls in lots of code, and may also be large. PiperOrigin-RevId: 170811579
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/aot/codegen.cc303
-rw-r--r--tensorflow/compiler/aot/codegen.h6
-rw-r--r--tensorflow/compiler/aot/codegen_test.cc5
-rw-r--r--tensorflow/compiler/aot/codegen_test_h.golden182
-rw-r--r--tensorflow/compiler/aot/flags.cc4
-rw-r--r--tensorflow/compiler/aot/flags.h4
-rw-r--r--tensorflow/compiler/aot/tests/BUILD3
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc72
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl11
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc2
-rw-r--r--tensorflow/compiler/tf2xla/BUILD55
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc88
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h223
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc217
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h87
-rw-r--r--tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc133
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h16
17 files changed, 1154 insertions, 257 deletions
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index fc5c6ce58d..ae22f7edc4 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -164,10 +164,6 @@ string RewriteWithName(const string& name, string code,
// Generate methods for args (inputs).
Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
const CompileResult& compile_result, string* methods) {
- *methods += R"(
- void** args() { return args_; }
- const void *const *args() const { return args_; }
-)";
size_t num_args = ps.parameters_size();
if (compile_result.has_context_arg) {
// If the compiled function needs a XlaLocalRuntimeContext* arg, it's
@@ -184,21 +180,21 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
const string code = R"(
void set_arg{{NAME}}_data(void* data) {
- args_[{{I}}] = data;
+ set_arg_data({{I}}, data);
}
{{TYPE}}* arg{{NAME}}_data() {
- return static_cast<{{TYPE}}*>(args_[{{I}}]);
+ return static_cast<{{TYPE}}*>(arg_data({{I}}));
}
{{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
- args_[{{I}}])){{INDICES}};
+ arg_data({{I}}))){{INDICES}};
}
const {{TYPE}}* arg{{NAME}}_data() const {
- return static_cast<const {{TYPE}}*>(args_[{{I}}]);
+ return static_cast<const {{TYPE}}*>(arg_data({{I}}));
}
const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
- args_[{{I}}])){{INDICES}};
+ arg_data({{I}}))){{INDICES}};
}
)";
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
@@ -213,74 +209,33 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
Status GenResultMethods(const tf2xla::Config& config,
const xla::ProgramShape& ps, string* methods) {
if (ps.result().element_type() != xla::TUPLE) {
- // Non-tuple (i.e. single-result) case.
- if (config.fetch_size() != 1) {
- return errors::InvalidArgument(
- "non-tuple result implies 1 fetch, but got ", config.fetch_size(),
- " fetches");
- }
- *methods += R"(
- void** results() { return temps_ + kResultIndex; }
- const void *const *results() const { return temps_ + kResultIndex; }
-)";
- std::vector<std::pair<string, string>> rewrites;
- TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites));
- const string code = R"(
- {{TYPE}}* result{{NAME}}_data() {
- return static_cast<{{TYPE}}*>(temps_[kResultIndex]);
- }
- {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
- return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
- temps_[kResultIndex])){{INDICES}};
- }
- const {{TYPE}}* result{{NAME}}_data() const {
- return static_cast<const {{TYPE}}*>(temps_[kResultIndex]);
- }
- const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
- return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
- temps_[kResultIndex])){{INDICES}};
+ // The XlaCompiler we use to build the xla computation always generates a
+ // tuple result, and we rely on this to simplify code generation.
+ return errors::Internal("codegen requires the XLA result to be a tuple");
}
-)";
- *methods += RewriteWithName("0", code, rewrites);
- if (!config.fetch(0).name().empty()) {
- *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites);
- }
- return Status::OK();
- }
- // Tuple (i.e. multi-result) case.
if (config.fetch_size() != ps.result().tuple_shapes_size()) {
return errors::InvalidArgument("mismatch between fetch_size(",
config.feed_size(), ") and tuple_size(",
ps.result().tuple_shapes_size(), ")");
}
- *methods += R"(
- void** results() {
- return static_cast<void**>(temps_[kResultIndex]);
- }
- const void *const *results() const {
- return static_cast<const void *const *>(temps_[kResultIndex]);
- }
-)";
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
- return static_cast<{{TYPE}}*>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}]);
+ return static_cast<{{TYPE}}*>(result_data({{I}}));
}
{{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
+ result_data({{I}}))){{INDICES}};
}
const {{TYPE}}* result{{NAME}}_data() const {
- return static_cast<{{TYPE}}*>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}]);
+ return static_cast<const {{TYPE}}*>(result_data({{I}}));
}
const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
- static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
+ result_data({{I}}))){{INDICES}};
}
)";
*methods += RewriteWithName(strings::StrCat(i), code, rewrites);
@@ -291,6 +246,84 @@ Status GenResultMethods(const tf2xla::Config& config,
return Status::OK();
}
+// Generates code implementing {Arg,Result}Names(), where T is one of
+// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
+// literal in the array, with nullptr terminating the array.
+template <typename T>
+string GenNameToIndexCode(const T& entries, bool generate) {
+ // No need for a static array if we're not supposed to generate the data.
+ if (!generate) {
+ return "{\n return nullptr;\n }";
+ }
+ // Determine when to stop. We stop emitting string literals after the last
+ // non-empty name.
+ int end = entries.size();
+ for (int i = entries.size() - 1; i >= 0; --i) {
+ if (!entries[i].name().empty()) {
+ break;
+ }
+ end = i;
+ }
+ // Emit string literals up to the last non-empty name.
+ string code = "{\n static const char* kNames[] = {";
+ for (int i = 0; i < end; ++i) {
+ if (i > 0) {
+ code += ", ";
+ }
+ code += "\"";
+ code += entries[i].name();
+ code += "\"";
+ }
+ if (end > 0) {
+ code += ", ";
+ }
+ code += "nullptr};\n return kNames;\n }";
+ return code;
+}
+
+// Converts the given `str` into a comma-separated list of per-character values.
+string StringToCharList(const string& str) {
+ string list;
+ for (const char c : str) {
+ if (!list.empty()) {
+ list += ",";
+ }
+ list += strings::StrCat(static_cast<int>(c));
+ }
+ return list;
+}
+
+string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) {
+ // No need for any static magic if we're not supposed to generate the data.
+ if (!generate) {
+ return "{\n return nullptr;\n }";
+ }
+ // The parameter names are currently meaningless, and redundant with the rest
+ // of our metadata, so clear them out to avoid confusion and save space.
+ program_shape.clear_parameter_names();
+ const string proto_str = program_shape.SerializeAsString();
+ // Embed the program shape as a serialized protobuf in the header file.
+ //
+ // TODO(toddw): This strategy will likely fail for larger protobufs, depending
+ // on the C++ compiler that is used. Figure out another solution if necessary.
+ string code = R"({
+ static const xla::ProgramShape* kShape = []() {
+ static const char kProto[] = {{{PROTO_LIST}}};
+ static constexpr int kProtoSize = {{PROTO_SIZE}};
+ xla::ProgramShape* shape = new xla::ProgramShape;
+ shape->ParseFromArray(kProto, kProtoSize);
+ return shape;
+ }();
+ return kShape;
+ })";
+ str_util::ReplaceAllPairs(
+ &code, {
+ {"{{PROTO_LIST}}", StringToCharList(proto_str)},
+ {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())},
+ });
+ return code;
+}
+
Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
for (const tf2xla::Feed& feed : config.feed()) {
if (!feed.name().empty()) {
@@ -336,24 +369,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
const size_t temp_bytes_total =
total_buffer_bytes(itemp.data(), itemp.size());
- // Create rewrite strings for the optional context arg.
- string context_include;
- string context_set_arg, context_set_thread_pool, context_member_var;
- string run_result = "true";
- string error_msg = "tensorflow::string()";
- if (compile_result.has_context_arg) {
- // NOTE: Extra spaces and newlines are used to ensure nice formatting.
- context_include =
- "#include "
- "\"tensorflow/compiler/tf2xla/"
- "xla_local_runtime_context.h\"\n";
- context_set_arg = " args_[kNumArgs-1] = &context_;\n";
- context_set_thread_pool = " context_.thread_pool = pool;\n";
- context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n";
- run_result = "!context_.error";
- error_msg = "context_.error_msg";
- }
-
// Create rewrite strings for namespace start and end.
string ns_start;
for (const string& n : opts.namespaces) {
@@ -366,6 +381,19 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
ns_end += strings::StrCat("} // end namespace ", n, "\n");
}
+ // Generate metadata.
+ const string arg_names_code =
+ GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
+ const string result_names_code =
+ GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
+ const string include_xla_data_proto =
+ opts.gen_program_shape
+ ?
+ R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
+ : "";
+ const string program_shape_code =
+ GenProgramShapeCode(ps, opts.gen_program_shape);
+
// Use a poor-man's text templating mechanism; first populate the full header
// with placeholder tokens, and then rewrite the tokens with real values.
*header =
@@ -380,22 +408,23 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config,
#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
-{{CONTEXT_INCLUDE}}
-#include "tensorflow/compiler/aot/runtime.h"
-#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/core/platform/macros.h"
+{{INCLUDE_XLA_DATA_PROTO}}
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { struct ThreadPoolDevice; }
+namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
- void* result, xla::ExecutableRunOptions* run_options,
- void** args, void** temps);
+ void* result, const xla::ExecutableRunOptions* run_options,
+ const void** args, void** temps);
{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
-// TensorFlow graph, now compiled into executable code. Usage example:
+// TensorFlow graph, now compiled into executable code. This extends the generic
+// XlaCompiledCpuFunction class with statically type-safe arg and result
+// methods. Usage example:
//
// {{CLASS}} computation;
// // ...set args using computation.argN methods
@@ -411,9 +440,9 @@ extern "C" void {{ENTRY}}(
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
-// o Calls to non-const methods require exclusive access to the object.
-// o Concurrent calls to const methods are OK, if those calls are made while
-// it is guaranteed that no thread may call a non-const method.
+// o Calls to non-const methods require exclusive access to the object.
+// o Concurrent calls to const methods are OK, if those calls are made while it
+// is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
// {{PROGRAM_SHAPE}}
@@ -423,7 +452,7 @@ extern "C" void {{ENTRY}}(
// arg bytes aligned: {{ARG_BYTES_ALIGNED}}
// temp bytes total: {{TEMP_BYTES_TOTAL}}
// temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
-class {{CLASS}} {
+class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = {{ARG_NUM}};
@@ -434,47 +463,31 @@ class {{CLASS}} {
return kArgSizes;
}
- // AllocMode controls the buffer allocation mode.
- enum class AllocMode {
- // Allocate all buffers - args, results and temps.
- ARGS_RESULTS_AND_TEMPS,
-
- // Only allocate result and temp buffers.
- // Use set_argN_data to set argument buffers before Run is called.
- RESULTS_AND_TEMPS_ONLY,
- };
-
- {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
- if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
- alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
- ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
- }
-{{CONTEXT_SET_ARG}}
- alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
- TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
- }
-
- ~{{CLASS}}() {
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
- }
-
- // Sets the thread pool to use during the Run call.
- {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
- run_options_.set_intra_op_thread_pool(pool);
-{{CONTEXT_SET_THREAD_POOL}}
- return *this;
- }
-
- // Runs the computation, with inputs read from arg buffers, and outputs
- // written to result buffers. Returns true on success and false on failure.
- bool Run() {
- {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_);
- return {{RUN_RESULT}};
- }
-
- // Returns the error message from the previous failed Run call.
- tensorflow::string error_msg() const { return {{ERROR_MSG}}; }
+ // Returns static data used to create an XlaCompiledCpuFunction.
+ static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
+ static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
+ XlaCompiledCpuFunction::StaticData* data =
+ new XlaCompiledCpuFunction::StaticData;
+ data->raw_function = {{ENTRY}};
+ data->arg_sizes = ArgSizes();
+ data->num_args = kNumArgs;
+ data->temp_sizes = TempSizes();
+ data->num_temps = kNumTemps;
+ data->result_index = kResultIndex;
+ data->requires_runtime_context = {{HAS_CONTEXT_ARG}};
+ data->arg_names = StaticArgNames();
+ data->result_names = StaticResultNames();
+ data->program_shape = StaticProgramShape();
+ return data;
+ }();
+ return *kStaticData;
+ }
+
+ {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
+ : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
+
+ {{CLASS}}(const {{CLASS}}&) = delete;
+ {{CLASS}}& operator=(const {{CLASS}}&) = delete;
// Arg methods for managing input buffers. Buffers are in row-major order.
// There is a set of methods for each positional argument, with the following
@@ -493,10 +506,6 @@ class {{CLASS}} {
// Returns a reference to the value of type T for positional argument N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
- //
- // void** args()
- // Returns an array of argument buffers, where args()[N] is the buffer for
- // positional argument N.
{{METHODS_ARG}}
// Result methods for managing output buffers. Buffers are in row-major order.
@@ -511,10 +520,6 @@ class {{CLASS}} {
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
- // void** results()
- // Returns an array of result buffers, where results()[N] is the buffer for
- // positional result N.
- //
// Unlike the arg methods, there is no set_resultN_data method. The result
// buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}
@@ -522,7 +527,7 @@ class {{CLASS}} {
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = {{TEMP_NUM}};
- // The 0-based index of the result in the temporary buffers.
+ // The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = {{RESULT_INDEX}};
// Byte size of each result / temporary buffer. There are kNumTemps entries.
@@ -531,14 +536,14 @@ class {{CLASS}} {
return kTempSizes;
}
- void* args_[kNumArgs];
- void* temps_[kNumTemps];
- void* alloc_args_ = nullptr;
- void* alloc_temps_ = nullptr;
- xla::ExecutableRunOptions run_options_;
-{{CONTEXT_MEMBER_VAR}}
+ // Array of names of each positional argument, terminated by nullptr.
+ static const char** StaticArgNames() {{ARG_NAMES_CODE}}
+
+ // Array of names of each positional result, terminated by nullptr.
+ static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
- TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}});
+ // Shape of the args and results.
+ static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}}
};
{{NS_END}}
@@ -550,22 +555,22 @@ class {{CLASS}} {
const std::vector<std::pair<string, string>> rewrites = {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
+ {"{{ARG_NAMES_CODE}}", arg_names_code},
{"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
{"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
{"{{CLASS}}", opts.class_name},
- {"{{CONTEXT_INCLUDE}}\n", context_include},
- {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var},
- {"{{CONTEXT_SET_ARG}}\n", context_set_arg},
- {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool},
{"{{ENTRY}}", compile_result.entry_point},
- {"{{ERROR_MSG}}", error_msg},
+ {"{{HAS_CONTEXT_ARG}}",
+ compile_result.has_context_arg ? "true" : "false"},
+ {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
{"{{METHODS_ARG}}\n", methods_arg},
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{PROGRAM_SHAPE_CODE}}", program_shape_code},
{"{{RESULT_INDEX}}", strings::StrCat(result_index)},
- {"{{RUN_RESULT}}", run_result},
+ {"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
{"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
index 740edd1e83..76dd0cc3cf 100644
--- a/tensorflow/compiler/aot/codegen.h
+++ b/tensorflow/compiler/aot/codegen.h
@@ -34,6 +34,12 @@ struct HeaderOpts {
// Namespaces specifies a list of C++ namespaces to add to the generated
// header. If empty, all symbols will be in the global namespace.
std::vector<string> namespaces;
+
+ // If true, generate name-to-index data for Lookup{Arg,Result}Index methods.
+ bool gen_name_to_index = false;
+
+ // If true, generate program shape data for the ProgramShape method.
+ bool gen_program_shape = false;
};
// GenerateHeader uses the meta-information from compile_result to generate a
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 98cbd67e53..0f6114666f 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -127,6 +127,8 @@ TEST(GenerateHeader, Golden) {
HeaderOpts opts;
opts.class_name = "MyClass";
opts.namespaces = {"foo", "bar"};
+ opts.gen_name_to_index = true;
+ opts.gen_program_shape = true;
tf2xla::Config config;
tf2xla::Feed* feed = config.add_feed();
feed->mutable_id()->set_node_name("feed0");
@@ -145,7 +147,8 @@ TEST(GenerateHeader, Golden) {
xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
xla::ShapeUtil::MakeOpaqueShape(),
},
- xla::ShapeUtil::MakeShape(xla::U32, {5, 6}));
+ xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
compile_result.has_context_arg = true;
compile_result.entry_point = "entry_point";
compile_result.pointer_size = 8;
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index 01963c6df4..65f342ce27 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -9,24 +9,25 @@
#ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
-#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
-#include "tensorflow/compiler/aot/runtime.h"
-#include "tensorflow/compiler/xla/executable_run_options.h"
-#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"
namespace Eigen { struct ThreadPoolDevice; }
+namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void entry_point(
- void* result, xla::ExecutableRunOptions* run_options,
- void** args, void** temps);
+ void* result, const xla::ExecutableRunOptions* run_options,
+ const void** args, void** temps);
namespace foo {
namespace bar {
// MyClass represents a computation previously specified in a
-// TensorFlow graph, now compiled into executable code. Usage example:
+// TensorFlow graph, now compiled into executable code. This extends the generic
+// XlaCompiledCpuFunction class with statically type-safe arg and result
+// methods. Usage example:
//
// MyClass computation;
// // ...set args using computation.argN methods
@@ -42,19 +43,19 @@ namespace bar {
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
-// o Calls to non-const methods require exclusive access to the object.
-// o Concurrent calls to const methods are OK, if those calls are made while
-// it is guaranteed that no thread may call a non-const method.
+// o Calls to non-const methods require exclusive access to the object.
+// o Concurrent calls to const methods are OK, if those calls are made while it
+// is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
-// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6]
+// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6])
//
// Memory stats:
// arg bytes total: 104
// arg bytes aligned: 128
// temp bytes total: 126
// temp bytes aligned: 224
-class MyClass {
+class MyClass : public tensorflow::XlaCompiledCpuFunction {
public:
// Number of input arguments for the compiled computation.
static constexpr size_t kNumArgs = 3;
@@ -65,47 +66,31 @@ class MyClass {
return kArgSizes;
}
- // AllocMode controls the buffer allocation mode.
- enum class AllocMode {
- // Allocate all buffers - args, results and temps.
- ARGS_RESULTS_AND_TEMPS,
-
- // Only allocate result and temp buffers.
- // Use set_argN_data to set argument buffers before Run is called.
- RESULTS_AND_TEMPS_ONLY,
- };
-
- MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
- if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
- alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
- ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
- }
- args_[kNumArgs-1] = &context_;
- alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
- TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
- }
-
- ~MyClass() {
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
- tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
- }
-
- // Sets the thread pool to use during the Run call.
- MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
- run_options_.set_intra_op_thread_pool(pool);
- context_.thread_pool = pool;
- return *this;
- }
-
- // Runs the computation, with inputs read from arg buffers, and outputs
- // written to result buffers. Returns true on success and false on failure.
- bool Run() {
- entry_point(temps_[kResultIndex], &run_options_, args_, temps_);
- return !context_.error;
- }
-
- // Returns the error message from the previous failed Run call.
- tensorflow::string error_msg() const { return context_.error_msg; }
+ // Returns static data used to create an XlaCompiledCpuFunction.
+ static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
+ static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
+ XlaCompiledCpuFunction::StaticData* data =
+ new XlaCompiledCpuFunction::StaticData;
+ data->raw_function = entry_point;
+ data->arg_sizes = ArgSizes();
+ data->num_args = kNumArgs;
+ data->temp_sizes = TempSizes();
+ data->num_temps = kNumTemps;
+ data->result_index = kResultIndex;
+ data->requires_runtime_context = true;
+ data->arg_names = StaticArgNames();
+ data->result_names = StaticResultNames();
+ data->program_shape = StaticProgramShape();
+ return data;
+ }();
+ return *kStaticData;
+ }
+
+ MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
+ : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
+
+ MyClass(const MyClass&) = delete;
+ MyClass& operator=(const MyClass&) = delete;
// Arg methods for managing input buffers. Buffers are in row-major order.
// There is a set of methods for each positional argument, with the following
@@ -124,66 +109,59 @@ class MyClass {
// Returns a reference to the value of type T for positional argument N,
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
- //
- // void** args()
- // Returns an array of argument buffers, where args()[N] is the buffer for
- // positional argument N.
-
- void** args() { return args_; }
- const void *const *args() const { return args_; }
void set_arg0_data(void* data) {
- args_[0] = data;
+ set_arg_data(0, data);
}
float* arg0_data() {
- return static_cast<float*>(args_[0]);
+ return static_cast<float*>(arg_data(0));
}
float& arg0(size_t dim0, size_t dim1) {
return (*static_cast<float(*)[1][2]>(
- args_[0]))[dim0][dim1];
+ arg_data(0)))[dim0][dim1];
}
const float* arg0_data() const {
- return static_cast<const float*>(args_[0]);
+ return static_cast<const float*>(arg_data(0));
}
const float& arg0(size_t dim0, size_t dim1) const {
return (*static_cast<const float(*)[1][2]>(
- args_[0]))[dim0][dim1];
+ arg_data(0)))[dim0][dim1];
}
void set_arg_myfeed_data(void* data) {
- args_[0] = data;
+ set_arg_data(0, data);
}
float* arg_myfeed_data() {
- return static_cast<float*>(args_[0]);
+ return static_cast<float*>(arg_data(0));
}
float& arg_myfeed(size_t dim0, size_t dim1) {
return (*static_cast<float(*)[1][2]>(
- args_[0]))[dim0][dim1];
+ arg_data(0)))[dim0][dim1];
}
const float* arg_myfeed_data() const {
- return static_cast<const float*>(args_[0]);
+ return static_cast<const float*>(arg_data(0));
}
const float& arg_myfeed(size_t dim0, size_t dim1) const {
return (*static_cast<const float(*)[1][2]>(
- args_[0]))[dim0][dim1];
+ arg_data(0)))[dim0][dim1];
}
void set_arg1_data(void* data) {
- args_[1] = data;
+ set_arg_data(1, data);
}
tensorflow::int64* arg1_data() {
- return static_cast<tensorflow::int64*>(args_[1]);
+ return static_cast<tensorflow::int64*>(arg_data(1));
}
tensorflow::int64& arg1(size_t dim0, size_t dim1) {
return (*static_cast<tensorflow::int64(*)[3][4]>(
- args_[1]))[dim0][dim1];
+ arg_data(1)))[dim0][dim1];
}
const tensorflow::int64* arg1_data() const {
- return static_cast<const tensorflow::int64*>(args_[1]);
+ return static_cast<const tensorflow::int64*>(arg_data(1));
}
const tensorflow::int64& arg1(size_t dim0, size_t dim1) const {
return (*static_cast<const tensorflow::int64(*)[3][4]>(
- args_[1]))[dim0][dim1];
+ arg_data(1)))[dim0][dim1];
}
// Result methods for managing output buffers. Buffers are in row-major order.
@@ -198,50 +176,43 @@ class MyClass {
// with dim indices specifying which value. No bounds checking is performed
// on dim indices.
//
- // void** results()
- // Returns an array of result buffers, where results()[N] is the buffer for
- // positional result N.
- //
// Unlike the arg methods, there is no set_resultN_data method. The result
// buffers are managed internally, and may change after each call to Run.
- void** results() { return temps_ + kResultIndex; }
- const void *const *results() const { return temps_ + kResultIndex; }
-
tensorflow::uint32* result0_data() {
- return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
+ return static_cast<tensorflow::uint32*>(result_data(0));
}
tensorflow::uint32& result0(size_t dim0, size_t dim1) {
return (*static_cast<tensorflow::uint32(*)[5][6]>(
- temps_[kResultIndex]))[dim0][dim1];
+ result_data(0)))[dim0][dim1];
}
const tensorflow::uint32* result0_data() const {
- return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
+ return static_cast<const tensorflow::uint32*>(result_data(0));
}
const tensorflow::uint32& result0(size_t dim0, size_t dim1) const {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
- temps_[kResultIndex]))[dim0][dim1];
+ result_data(0)))[dim0][dim1];
}
tensorflow::uint32* result_myfetch_data() {
- return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
+ return static_cast<tensorflow::uint32*>(result_data(0));
}
tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) {
return (*static_cast<tensorflow::uint32(*)[5][6]>(
- temps_[kResultIndex]))[dim0][dim1];
+ result_data(0)))[dim0][dim1];
}
const tensorflow::uint32* result_myfetch_data() const {
- return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
+ return static_cast<const tensorflow::uint32*>(result_data(0));
}
const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const {
return (*static_cast<const tensorflow::uint32(*)[5][6]>(
- temps_[kResultIndex]))[dim0][dim1];
+ result_data(0)))[dim0][dim1];
}
private:
// Number of result and temporary buffers for the compiled computation.
static constexpr size_t kNumTemps = 6;
- // The 0-based index of the result in the temporary buffers.
+ // The 0-based index of the result tuple in the temporary buffers.
static constexpr size_t kResultIndex = 5;
// Byte size of each result / temporary buffer. There are kNumTemps entries.
@@ -250,14 +221,29 @@ class MyClass {
return kTempSizes;
}
- void* args_[kNumArgs];
- void* temps_[kNumTemps];
- void* alloc_args_ = nullptr;
- void* alloc_temps_ = nullptr;
- xla::ExecutableRunOptions run_options_;
- tensorflow::XlaLocalRuntimeContext context_;
+ // Array of names of each positional argument, terminated by nullptr.
+ static const char** StaticArgNames() {
+ static const char* kNames[] = {"myfeed", nullptr};
+ return kNames;
+ }
+
+ // Array of names of each positional result, terminated by nullptr.
+ static const char** StaticResultNames() {
+ static const char* kNames[] = {"myfetch", nullptr};
+ return kNames;
+ }
- TF_DISALLOW_COPY_AND_ASSIGN(MyClass);
+ // Shape of the args and results.
+ static const xla::ProgramShape* StaticProgramShape() {
+ static const xla::ProgramShape* kShape = []() {
+ static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0};
+ static constexpr int kProtoSize = 50;
+ xla::ProgramShape* shape = new xla::ProgramShape;
+ shape->ParseFromArray(kProto, kProtoSize);
+ return shape;
+ }();
+ return kShape;
+ }
};
} // end namespace bar
diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc
index 4e3998b682..5aff10346f 100644
--- a/tensorflow/compiler/aot/flags.cc
+++ b/tensorflow/compiler/aot/flags.cc
@@ -64,6 +64,10 @@ void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
"namespaces are given, within the global namespace."},
{"out_object", &flags->out_object, "Output object file name."},
{"out_header", &flags->out_header, "Output header file name."},
+ {"gen_name_to_index", &flags->gen_name_to_index,
+ "Generate name-to-index data for Lookup{Arg,Result}Index methods."},
+ {"gen_program_shape", &flags->gen_program_shape,
+ "Generate program shape data for the ProgramShape method."},
};
flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
}
diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h
index e11a0173fa..3246dbf95c 100644
--- a/tensorflow/compiler/aot/flags.h
+++ b/tensorflow/compiler/aot/flags.h
@@ -37,6 +37,10 @@ struct MainFlags {
string cpp_class;
string out_object;
string out_header;
+
+ // C++ codegen options
+ bool gen_name_to_index = false;
+ bool gen_program_shape = false;
};
// Appends to flag_list a tensorflow::Flag for each field in MainFlags.
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index b0b1213a84..7dfd49cc3b 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -132,6 +132,7 @@ tf_library(
cpp_class = "MatMulAndAddComp",
graph = "test_graph_tfmatmulandadd.pb",
tags = ["manual"],
+ tfcompile_flags = "--gen_name_to_index --gen_program_shape",
)
tf_library(
@@ -156,6 +157,8 @@ tf_cc_test(
":test_graph_tfmatmul",
":test_graph_tfmatmulandadd",
":test_graph_tfsplits",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 07562e59c8..cfde5651c6 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -188,6 +190,23 @@ TEST(TFCompileTest, Gather) {
EXPECT_FALSE(gather.Run());
EXPECT_EQ(gather.error_msg(), "Invalid index for gather");
}
+
+ // Try a successful gather again, after the error, to ensure the error state
+ // is cleared.
+ {
+ const float params[4] = {1, 2, 3, 4};
+ std::copy(params + 0, params + 4, gather.arg0_data());
+ const int32 indices[2] = {1, 3};
+ std::copy(indices + 0, indices + 2, gather.arg1_data());
+ EXPECT_TRUE(gather.Run());
+ EXPECT_EQ(gather.error_msg(), "");
+ const float results[2] = {2, 4};
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ(gather.result0(i), results[i]);
+ EXPECT_EQ(gather.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(gather.result0_data(), gather.results()[0]);
+ }
}
TEST(TFCompileTest, MatMul2) {
@@ -421,6 +440,59 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, LookupNameIndex) {
+ // add doesn't have any names defined in its config.
+ AddComp add;
+ EXPECT_FALSE(add.HasNameIndices());
+
+ // muladd has names defined for all feeds and fetches.
+ MatMulAndAddComp muladd;
+ EXPECT_TRUE(muladd.HasNameIndices());
+
+ EXPECT_EQ(muladd.LookupArgIndex("x"), 0);
+ EXPECT_EQ(muladd.LookupArgIndex("y"), 1);
+ EXPECT_EQ(muladd.LookupArgIndex(""), -1);
+ EXPECT_EQ(muladd.LookupArgIndex("x_hold"), -1);
+ EXPECT_EQ(muladd.LookupArgIndex("y_hold"), -1);
+ EXPECT_EQ(muladd.LookupArgIndex("x_y_prod"), -1);
+ EXPECT_EQ(muladd.LookupArgIndex("x_y_sum"), -1);
+
+ EXPECT_EQ(muladd.LookupResultIndex("x_y_prod"), 0);
+ EXPECT_EQ(muladd.LookupResultIndex("x_y_sum"), 1);
+ EXPECT_EQ(muladd.LookupResultIndex(""), -1);
+ EXPECT_EQ(muladd.LookupResultIndex("x"), -1);
+ EXPECT_EQ(muladd.LookupResultIndex("y"), -1);
+ EXPECT_EQ(muladd.LookupResultIndex("x_hold"), -1);
+ EXPECT_EQ(muladd.LookupResultIndex("y_hold"), -1);
+}
+
+TEST(TFCompileTest, ProgramShape) {
+ using xla::ShapeUtil;
+ const xla::Shape f32_2x2 = ShapeUtil::MakeShape(xla::F32, {2, 2});
+
+ // add doesn't have the program shape defined.
+ AddComp add;
+ ASSERT_TRUE(add.ProgramShape() == nullptr);
+
+ // muladd has the program shape defined.
+ MatMulAndAddComp muladd;
+ const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
+ ASSERT_TRUE(muladd_shape != nullptr);
+ ASSERT_EQ(muladd_shape->parameters_size(), 2);
+ EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
+ EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
+
+ const xla::Shape& muladd_result = muladd_shape->result();
+ ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
+ ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
+ const xla::Shape& muladd_result0 =
+ ShapeUtil::GetTupleElementShape(muladd_result, 0);
+ EXPECT_TRUE(ShapeUtil::Compatible(muladd_result0, f32_2x2));
+ const xla::Shape& muladd_result1 =
+ ShapeUtil::GetTupleElementShape(muladd_result, 1);
+ EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2));
+}
+
} // namespace
} // namespace tfcompile
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 608d461a4c..461a9315c5 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -167,6 +167,8 @@ def tf_library(name, graph, config,
# The cc_library rule packaging up the header and object file, and needed
# kernel implementations.
+ need_xla_data_proto = (tfcompile_flags and
+ tfcompile_flags.find("--gen_program_shape") != -1)
native.cc_library(
name=name,
srcs=[object_file],
@@ -177,11 +179,12 @@ def tf_library(name, graph, config,
# These deps are required by all tf_library targets even if
# include_standard_runtime_deps is False. Without them, the
# generated code will fail to compile.
- "//tensorflow/compiler/aot:runtime",
- "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
- "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function",
"//tensorflow/core:framework_lite",
- ] + (include_standard_runtime_deps and [
+ ] + (need_xla_data_proto and [
+ # If we're generating the program shape, we must depend on the proto.
+ "//tensorflow/compiler/xla:xla_data_proto",
+ ] or []) + (include_standard_runtime_deps and [
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32",
"//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64",
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
index cc499c3284..6ab3d47418 100644
--- a/tensorflow/compiler/aot/tfcompile_main.cc
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -94,6 +94,8 @@ Status Main(const MainFlags& flags) {
TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
StringPiece(obj.data(), obj.size())));
HeaderOpts header_opts;
+ header_opts.gen_name_to_index = flags.gen_name_to_index;
+ header_opts.gen_program_shape = flags.gen_program_shape;
if (flags.cpp_class.empty()) {
return errors::InvalidArgument("Must specify --cpp_class");
}
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 0769b13718..08f2249e0d 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -59,6 +59,41 @@ cc_library(
)
cc_library(
+ name = "xla_compiled_cpu_function",
+ srcs = ["xla_compiled_cpu_function.cc"],
+ hdrs = ["xla_compiled_cpu_function.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ # Keep dependencies to a minimum here; this library is used in every AOT
+ # binary produced by tfcompile.
+ "//tensorflow/compiler/aot:runtime",
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:framework_lite",
+ ],
+)
+
+cc_library(
+ name = "xla_jit_compiled_cpu_function",
+ srcs = ["xla_jit_compiled_cpu_function.cc"],
+ hdrs = ["xla_jit_compiled_cpu_function.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tf2xla",
+ ":tf2xla_proto",
+ ":xla_compiled_cpu_function",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service/cpu:cpu_executable",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
name = "xla_compiler",
srcs = [
"xla_compilation_device.cc",
@@ -179,6 +214,26 @@ tf_cc_test(
)
tf_cc_test(
+ name = "xla_jit_compiled_cpu_function_test",
+ srcs = ["xla_jit_compiled_cpu_function_test.cc"],
+ deps = [
+ ":tf2xla_proto",
+ ":xla_jit_compiled_cpu_function",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
name = "xla_compiler_test",
srcs = ["xla_compiler_test.cc"],
deps = [
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
new file mode 100644
index 0000000000..b5c17c5273
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -0,0 +1,88 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+
+#include <cassert>
+#include "tensorflow/compiler/aot/runtime.h"
+
+namespace tensorflow {
+
+XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
+ AllocMode alloc_mode)
+ : raw_function_(static_data.raw_function),
+ result_index_(static_data.result_index),
+ args_(new void*[static_data.num_args]),
+ temps_(new void*[static_data.num_temps]),
+ arg_names_(static_data.arg_names),
+ result_names_(static_data.result_names),
+ program_shape_(static_data.program_shape) {
+ // Allocate arg and temp buffers.
+ if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
+ alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ static_data.arg_sizes, static_data.num_args, args_,
+ /*annotate_initialized=*/false);
+ }
+ alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ static_data.temp_sizes, static_data.num_temps, temps_,
+ /*annotate_initialized=*/true);
+
+ // The runtime context is always the last arg, if it is required.
+ if (static_data.requires_runtime_context) {
+ args_[static_data.num_args - 1] = &context_;
+ }
+}
+
+XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
+ tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
+ tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
+ delete[] args_;
+ delete[] temps_;
+}
+
+namespace {
+
+// Linear search through `names` looking for a match with `name`. Returns -1 if
+// the name isn't found, or is empty.
+//
+// REQUIRES: `names` is a nullptr-terminated array.
+int LookupNameIndex(const string& name, const char** names) {
+ // Hitting this assert means that there is no name-to-index data available;
+  // for AOT try setting the tfcompile --gen_name_to_index flag.
+ assert(names != nullptr);
+
+ constexpr int kNotFound = -1;
+ if (name.empty()) {
+ return kNotFound;
+ }
+ for (int index = 0; names[index] != nullptr; ++index) {
+ if (name == names[index]) {
+ return index;
+ }
+ }
+ return kNotFound;
+}
+
+} // namespace
+
+int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const {
+ return LookupNameIndex(name, arg_names_);
+}
+
+int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
+ return LookupNameIndex(name, result_names_);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
new file mode 100644
index 0000000000..01e6b4c071
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
+
+#include <functional>
+#include <string>
+
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/types.h"
+
+// Forward-declare, rather than include, to reduce code size for users that
+// never use this functionality.
+namespace xla {
+class ProgramShape;
+}
+
+namespace tensorflow {
+
+// Represents a function compiled by XLA, produced via either JIT or AOT.
+//
+// The Run method invokes the actual computation, with inputs read from arg
+// buffers, and outputs written to result buffers. Each Run call may also use a
+// set of temporary buffers for the computation.
+//
+// By default each instance of this class manages its own arg, result and temp
+// buffers. The AllocMode constructor parameter may be used to modify the buffer
+// allocation strategy.
+//
+// Under the default allocation strategy, this class is thread-compatible:
+// o Calls to non-const methods require exclusive access to the object.
+// o Concurrent calls to const methods are OK, if those calls are made while it
+// is guaranteed that no thread may call a non-const method.
+class XlaCompiledCpuFunction {
+ public:
+ // Type of the raw function, produced by either JIT or AOT.
+ //
+ // TODO(toddw): Add support for hlo profiling, and replace std::function with
+ // a raw function pointer, for some codesize savings.
+ using RawFunction = std::function<void(
+ void* result, const xla::ExecutableRunOptions* run_options,
+ const void** args, void** temps)>;
+
+ // StaticData represents the state necessary to run an XLA-compiled
+ // function. For JIT this is backed by data in XlaCompiledCpuFunctionJit; for
+ // AOT this is backed by data compiled into the object file.
+ struct StaticData {
+ // The raw function to call.
+ RawFunction raw_function;
+
+ // Cardinality and sizes of arg and temp buffers.
+ const intptr_t* arg_sizes = nullptr;
+ size_t num_args = 0;
+ const intptr_t* temp_sizes = nullptr;
+ size_t num_temps = 0;
+
+ // The 0-based index of the result tuple, in the temp buffers.
+ size_t result_index = 0;
+
+ // Is the final arg XlaLocalRuntimeContext?
+ bool requires_runtime_context = false;
+
+ // [Optional] Arrays of arg and result names. These are arrays of C-style
+ // strings, where the array is terminated by nullptr.
+ const char** arg_names = nullptr;
+ const char** result_names = nullptr;
+
+ // [Optional] Arg and result shapes.
+ const xla::ProgramShape* program_shape = nullptr;
+ };
+
+ // AllocMode controls the buffer allocation mode.
+ enum class AllocMode {
+ // Allocate all buffers - args, results and temps.
+ ARGS_RESULTS_AND_TEMPS,
+
+ // Only allocate result and temp buffers.
+ // Use set_arg_data to set argument buffers before Run is called.
+ RESULTS_AND_TEMPS_ONLY,
+ };
+
+ XlaCompiledCpuFunction(
+ const StaticData& static_data,
+ AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS);
+ virtual ~XlaCompiledCpuFunction();
+
+ XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
+ XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete;
+
+ // Sets the intra-op thread pool used to run individual ops concurrently.
+ void set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
+ run_options_.set_intra_op_thread_pool(pool);
+ context_.thread_pool = pool;
+ }
+
+ // Runs the computation, with inputs read from arg buffers, and outputs
+ // written to result buffers. Returns true on success and false on failure.
+ bool Run() {
+ context_.error = false;
+ context_.error_msg.clear();
+ raw_function_(temps_[result_index_], &run_options_,
+ const_cast<const void**>(args_), temps_);
+ return !context_.error;
+ }
+
+ // Returns the error message from the previous failed Run call.
+ const string& error_msg() const { return context_.error_msg; }
+
+ // ------------------------------
+ // Arg methods for managing input buffers. Buffers are in row-major order.
+
+ // Returns the underlying array of argument buffers, where args()[I] is the
+ // buffer for the positional argument at index I.
+ void** args() { return args_; }
+ const void* const* args() const { return args_; }
+
+ // Returns the buffer for the positional argument at the given `index`.
+ void* arg_data(size_t index) { return args_[index]; }
+ const void* arg_data(size_t index) const { return args_[index]; }
+
+ // Sets the buffer for the positional argument at the given `index` to `data`.
+ // Must be called before Run to have an effect. May be called under any
+ // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be
+ // called for each positional argument, in order to set the argument buffers.
+ //
+ // Allocated memory must be aligned to the size specified by
+ // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in
+ // tensorflow/compiler/aot/runtime.h to ensure correct alignment.
+ //
+ // If StaticData.requires_runtime_context==true, the final argument is an
+ // XlaLocalRuntimeContext, which is managed internally by this class, and
+ // should not be changed.
+ //
+ // Aliasing of argument and result buffers is not allowed, and results in
+ // undefined behavior.
+ void set_arg_data(size_t index, void* data) { args_[index] = data; }
+
+ // ------------------------------
+ // Result methods for managing output buffers. Buffers are in row-major order.
+ // Must only be called after a successful Run call. Unlike the arg methods,
+ // there is no set_resultN_data method. The result buffers are managed
+ // internally, and may change after each call to Run.
+
+ // Returns the underlying array of result buffers, where results()[I] is the
+ // buffer for the positional result at index I.
+ void** results() { return static_cast<void**>(temps_[result_index_]); }
+ const void* const* results() const {
+ return static_cast<const void* const*>(temps_[result_index_]);
+ }
+
+ // Returns the buffer for the positional result at the given `index`.
+ void* result_data(size_t index) { return results()[index]; }
+ const void* result_data(size_t index) const { return results()[index]; }
+
+ // ------------------------------
+ // Methods for extracting optional metadata.
+
+ // Returns true iff data is available for the Lookup{Arg,Result}Index methods.
+ // E.g. the data might not be compiled into the binary for AOT.
+ bool HasNameIndices() const {
+ return arg_names_ != nullptr && result_names_ != nullptr;
+ }
+
+ // Returns the 0-based index for the argument with the given `name`.
+ // Returns -1 if the name wasn't found, or data isn't available.
+ //
+ // The index remains constant for every instance of XlaCompiledCpuFunction
+ // generated from the same static data, and might not be cheap to determine.
+ // Recommended usage is to capture this in a variable for re-use.
+ int LookupArgIndex(const string& name) const;
+
+ // Returns the 0-based index for the result with the given `name`.
+ // Returns -1 if the name wasn't found, or data isn't available.
+ //
+ // The index remains constant for every instance of XlaCompiledCpuFunction
+ // generated from the same static data, and might not be cheap to determine.
+ // Recommended usage is to capture this in a variable for re-use.
+ int LookupResultIndex(const string& name) const;
+
+ // Returns the shape of the args and results. May return nullptr if the
+ // program shape isn't available.
+ const xla::ProgramShape* ProgramShape() const { return program_shape_; }
+
+ private:
+ const RawFunction raw_function_;
+ const size_t result_index_;
+
+ // Arrays of argument and temp buffers; entries in args_ may be overwritten by
+ // the user.
+ void** args_ = nullptr;
+ void** temps_ = nullptr;
+
+ // Backing memory for individual arg and temp buffers.
+ void* alloc_args_ = nullptr;
+ void* alloc_temps_ = nullptr;
+
+ // Options and context passed to the compiled function.
+ xla::ExecutableRunOptions run_options_;
+ tensorflow::XlaLocalRuntimeContext context_;
+
+ // Optional metadata.
+ const char** arg_names_ = nullptr;
+ const char** result_names_ = nullptr;
+ const xla::ProgramShape* program_shape_ = nullptr;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
new file mode 100644
index 0000000000..1dd454ea8d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -0,0 +1,217 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/tf2xla.h"
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Returns a vector of positional argument buffer sizes.
+xla::StatusOr<std::vector<intptr_t>> ComputeArgSizes(
+ const xla::ProgramShape& program_shape, bool requires_runtime_context) {
+ std::vector<intptr_t> arg_sizes;
+ const size_t num_args = program_shape.parameters_size();
+ arg_sizes.reserve(num_args);
+ for (int i = 0; i < num_args; ++i) {
+ const xla::Shape& arg_shape = program_shape.parameters(i);
+ if (i == num_args - 1 && requires_runtime_context) {
+ // If the compiled function needs an XlaLocalRuntimeContext* arg, it's
+ // always last, and must be represented as an opaque type.
+ const xla::PrimitiveType type = arg_shape.element_type();
+ if (type != xla::OPAQUE) {
+ return errors::InvalidArgument(
+ "expected final context arg to be opaque, but got type: ",
+ xla::PrimitiveType_Name(type), ", from program shape: ",
+ xla::ShapeUtil::HumanString(program_shape));
+ }
+ arg_sizes.push_back(-1);
+ } else {
+ constexpr size_t kPointerSize = sizeof(void*);
+ arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize));
+ }
+ }
+ return std::move(arg_sizes);
+}
+
+// Returns a vector of positional temporary buffer sizes.
+xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
+ const xla::BufferAssignment& buffer_assignment) {
+ const std::vector<xla::BufferAllocation>& allocations =
+ buffer_assignment.Allocations();
+ std::vector<intptr_t> temp_sizes;
+ temp_sizes.reserve(allocations.size());
+ for (const xla::BufferAllocation& allocation : allocations) {
+ // Callers don't allocate temporary buffers for parameters. Nor for
+ // thread-local buffers, which are lowered to alloca.
+ if (allocation.is_entry_computation_parameter() ||
+ allocation.is_thread_local()) {
+ temp_sizes.push_back(-1);
+ } else {
+ temp_sizes.push_back(allocation.size());
+ }
+ }
+ return std::move(temp_sizes);
+}
+
+// Returns the index of the result in the temp buffers.
+xla::StatusOr<size_t> ComputeResultIndex(
+ const xla::BufferAssignment& buffer_assignment) {
+ TF_ASSIGN_OR_RETURN(const xla::BufferAllocation::Slice result_slice,
+ buffer_assignment.GetUniqueTopLevelOutputSlice());
+ return result_slice.index();
+}
+
+// Adapt ComputeFunctionType, which includes a final profile_counters arg, to
+// RawFunction, which doesn't include that final arg.
+//
+// TODO(toddw): Change RawFunction and AOT to also pass the final
+// profile_counters arg, and remove this adapter.
+XlaCompiledCpuFunction::RawFunction RawFunctionAdapter(
+ xla::cpu::CpuExecutable::ComputeFunctionType compute_function) {
+ return [compute_function](void* result,
+ const xla::ExecutableRunOptions* run_options,
+ const void** args, void** temps) {
+ return compute_function(result, run_options, args, temps,
+ /*profile_counters=*/nullptr);
+ };
+}
+
+// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold
+// the actual strings in nonempty_names, and hold arrays of pointers in
+// name_ptrs, terminated by a nullptr entry.
+template <typename T>
+void CollectNames(const T& entries, std::vector<string>* nonempty_names,
+ std::vector<const char*>* name_ptrs) {
+ // First collect `nonempty_names`, to ensure the underlying strings won't
+ // change out from under us.
+ for (const auto& entry : entries) {
+ const string& name = entry.name();
+ if (!name.empty()) {
+ nonempty_names->push_back(name);
+ }
+ }
+ // Now set `name_ptrs` pointing to the strings in `nonempty_names`.
+ name_ptrs->reserve(entries.size() + 1); // +1 for nullptr array terminator
+ size_t nonempty_index = 0;
+ for (const auto& entry : entries) {
+ const string& name = entry.name();
+ if (!name.empty()) {
+ name_ptrs->push_back(nonempty_names->at(nonempty_index).c_str());
+ ++nonempty_index;
+ } else {
+ name_ptrs->push_back("");
+ }
+ }
+ name_ptrs->push_back(nullptr); // array terminator
+}
+
+} // namespace
+
+/*static*/ xla::StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>>
+XlaJitCompiledCpuFunction::Compile(
+ const GraphDef& graph_def, const tf2xla::Config& config,
+ const xla::ExecutableBuildOptions& build_options) {
+ // Convert the graph_def into an xla::Computation.
+ TF_ASSIGN_OR_RETURN(xla::LocalClient * client,
+ xla::ClientLibrary::GetOrCreateLocalClient());
+ xla::Computation computation;
+ bool requires_runtime_context;
+ TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(
+ graph_def, config, client, &computation, &requires_runtime_context));
+
+ // Get and verify the program shape.
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> program_shape,
+ client->GetComputationShape(computation));
+ if (program_shape->result().element_type() != xla::TUPLE) {
+ // The XlaCompiler we use to build the xla computation always generates a
+ // tuple result, and XlaCompiledCpuFunction relies on this for simpler
+ // calling semantics.
+ return errors::Internal(
+ "XlaJitCompiledCpuFunction requires the XLA result to be a tuple");
+ }
+ // The parameter names are currently meaningless, and redundant with the rest
+ // of our metadata, so clear them out to avoid confusion and save space.
+ program_shape->clear_parameter_names();
+
+ // Compute arg shapes, needed to compile the executable.
+ std::vector<const xla::Shape*> arg_shapes;
+ arg_shapes.reserve(program_shape->parameters_size());
+ for (int i = 0; i < program_shape->parameters_size(); ++i) {
+ arg_shapes.push_back(&program_shape->parameters(i));
+ }
+
+ // Compile the executable. The static_cast to the CpuExecutable subclass is
+ // necessary since the raw function and buffer assignments are only available
+ // there.
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
+ client->Compile(computation, arg_shapes, build_options));
+ const xla::cpu::CpuExecutable* cpu_executable =
+ static_cast<xla::cpu::CpuExecutable*>(executable->executable());
+ XlaCompiledCpuFunction::RawFunction raw_function =
+ RawFunctionAdapter(cpu_executable->compute_function());
+ const xla::BufferAssignment& buffer_assignment =
+ cpu_executable->buffer_assignment();
+
+ // Compute buffer sizes and the result index, needed to run the raw function.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<intptr_t> arg_sizes,
+ ComputeArgSizes(*program_shape, requires_runtime_context));
+ TF_ASSIGN_OR_RETURN(std::vector<intptr_t> temp_sizes,
+ ComputeTempSizes(buffer_assignment));
+ TF_ASSIGN_OR_RETURN(size_t result_index,
+ ComputeResultIndex(buffer_assignment));
+
+ std::unique_ptr<XlaJitCompiledCpuFunction> jit_unique_ptr(
+ new XlaJitCompiledCpuFunction);
+ XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get();
+ jit->executable_ = std::move(executable);
+ jit->arg_sizes_ = std::move(arg_sizes);
+ jit->temp_sizes_ = std::move(temp_sizes);
+ jit->program_shape_ = std::move(program_shape);
+ jit->static_data_.raw_function = std::move(raw_function);
+ jit->static_data_.arg_sizes = jit->arg_sizes_.data();
+ jit->static_data_.num_args = jit->arg_sizes_.size();
+ jit->static_data_.temp_sizes = jit->temp_sizes_.data();
+ jit->static_data_.num_temps = jit->temp_sizes_.size();
+ jit->static_data_.result_index = result_index;
+ jit->static_data_.requires_runtime_context = requires_runtime_context;
+ // Optional metadata is collected and set below.
+ CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
+ CollectNames(config.fetch(), &jit->nonempty_result_names_,
+ &jit->result_names_);
+ jit->static_data_.arg_names = jit->arg_names_.data();
+ jit->static_data_.result_names = jit->result_names_.data();
+ jit->static_data_.program_shape = jit->program_shape_.get();
+ return std::move(jit_unique_ptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
new file mode 100644
index 0000000000..af307ae4ef
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Represents the result of JIT compilation by XLA down to a function. This
+// class holds the state necessary to create XlaCompiledCpuFunction instances,
+// which are used to actually invoke the compiled computation.
+//
+// XlaJitCompiledCpuFunction must outlive the XlaCompiledCpuFunctions that are
+// created from it. It holds state shared by all of the functions, including the
+// JIT-compiled function itself, along with buffer sizes and other metadata
+// necessary for execution.
+class XlaJitCompiledCpuFunction {
+ public:
+  // Compile a tensorflow::GraphDef into an XlaJitCompiledCpuFunction. The given
+  // `config` specifies the portion of the graph to compile, via feeds and
+  // fetches. Each feed is a positional input argument for the compiled
+  // function, while each fetch is a positional output argument.
+  static xla::StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>> Compile(
+      const GraphDef& graph_def, const tf2xla::Config& config,
+      const xla::ExecutableBuildOptions& build_options);
+
+  // Non-copyable: static_data_ holds raw pointers into the size and name
+  // vectors below, so a shallow copy would leave dangling pointers.
+  XlaJitCompiledCpuFunction(const XlaJitCompiledCpuFunction&) = delete;
+  XlaJitCompiledCpuFunction& operator=(const XlaJitCompiledCpuFunction&) =
+      delete;
+
+  // Returns static data used to create an XlaCompiledCpuFunction instance,
+  // which represents the JIT-compiled function. The static data is unchanging
+  // across each instance.
+  const XlaCompiledCpuFunction::StaticData& StaticData() const {
+    return static_data_;
+  }
+
+ private:
+  XlaJitCompiledCpuFunction() {}  // Instances are only created via Compile.
+
+  // The executable holds the underlying function.
+  std::unique_ptr<xla::LocalExecutable> executable_;
+
+  // The static data is backed by the rest of the state in this class.
+  XlaCompiledCpuFunction::StaticData static_data_;
+
+  // The backing arrays of arg and temp buffer sizes.
+  std::vector<intptr_t> arg_sizes_;
+  std::vector<intptr_t> temp_sizes_;
+
+  // The backing arrays of arg and result names. We hold the actual strings in
+  // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
+  // data to refer to.
+  std::vector<string> nonempty_arg_names_;
+  std::vector<string> nonempty_result_names_;
+  std::vector<const char*> arg_names_;
+  std::vector<const char*> result_names_;
+
+  // The backing data for the program shape.
+  std::unique_ptr<const xla::ProgramShape> program_shape_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
new file mode 100644
index 0000000000..5bee68eefc
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h"
+
+#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// Wraps `dtype` in an AttrValue suitable for use as a NodeDef attribute.
+AttrValue TypeAttrValue(DataType dtype) {
+  AttrValue result;
+  SetAttrValue(dtype, &result);
+  return result;
+}
+
+// Builds a graph computing sum = x + y over two int32 placeholders.
+GraphDef SumGraph() {
+  GraphDef graph_def;
+  // Both inputs are identical int32 placeholders; build them via one helper.
+  auto add_placeholder = [&graph_def](const string& name) {
+    NodeDef* node = graph_def.add_node();
+    node->set_name(name);
+    node->set_op("Placeholder");
+    (*node->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32);
+  };
+  add_placeholder("x");
+  add_placeholder("y");
+  NodeDef* sum = graph_def.add_node();
+  sum->set_name("sum");
+  sum->set_op("Add");
+  sum->add_input("x");
+  sum->add_input("y");
+  (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32);
+  return graph_def;
+}
+
+// Builds a tf2xla config feeding x and y and fetching sum. The user-visible
+// names ("x_name" etc.) deliberately differ from the node names, so the test
+// can verify name-to-index lookups resolve only the user-visible names.
+tf2xla::Config SumConfig() {
+  tf2xla::Config config;
+  auto add_feed = [&config](const string& node_name, const string& name) {
+    tf2xla::Feed* feed = config.add_feed();
+    feed->mutable_id()->set_node_name(node_name);
+    feed->set_name(name);
+  };
+  add_feed("x", "x_name");
+  add_feed("y", "y_name");
+  tf2xla::Fetch* fetch = config.add_fetch();
+  fetch->mutable_id()->set_node_name("sum");
+  fetch->set_name("sum_name");
+  return config;
+}
+
+TEST(XlaJitCompiledCpuFunction, Sum) {
+  const GraphDef graph_def = SumGraph();
+  const tf2xla::Config config = SumConfig();
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<XlaJitCompiledCpuFunction> jit,
+      XlaJitCompiledCpuFunction::Compile(graph_def, config,
+                                         xla::ExecutableBuildOptions()));
+  XlaCompiledCpuFunction function(jit->StaticData());
+
+  // Writes the two args, runs the function, and checks the single result.
+  auto run_and_check = [&function](int32 x, int32 y, int32 expected_sum) {
+    *static_cast<int32*>(function.arg_data(0)) = x;
+    *static_cast<int32*>(function.arg_data(1)) = y;
+    EXPECT_TRUE(function.Run());
+    EXPECT_EQ(function.error_msg(), "");
+    EXPECT_EQ(*static_cast<int32*>(function.result_data(0)), expected_sum);
+  };
+  // Run twice to verify the buffers may be safely reused across calls.
+  run_and_check(10, 32, 42);
+  run_and_check(100, 320, 420);
+
+  // Check name to index lookups: only the user-visible "*_name" names resolve,
+  // and only within the correct category (arg vs result).
+  EXPECT_TRUE(function.HasNameIndices());
+
+  EXPECT_EQ(function.LookupArgIndex("x_name"), 0);
+  EXPECT_EQ(function.LookupArgIndex("y_name"), 1);
+  for (const char* bad_arg : {"", "x", "y", "sum", "sum_name"}) {
+    EXPECT_EQ(function.LookupArgIndex(bad_arg), -1);
+  }
+
+  EXPECT_EQ(function.LookupResultIndex("sum_name"), 0);
+  for (const char* bad_result : {"", "x", "y", "sum", "x_name", "y_name"}) {
+    EXPECT_EQ(function.LookupResultIndex(bad_result), -1);
+  }
+
+  // Check program shape: (s32, s32) -> (s32).
+  using xla::ShapeUtil;
+  const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
+  const xla::ProgramShape* program_shape = function.ProgramShape();
+  ASSERT_TRUE(program_shape != nullptr);
+  ASSERT_EQ(program_shape->parameters_size(), 2);
+  EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
+  EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32));
+
+  const xla::Shape& result = program_shape->result();
+  ASSERT_EQ(result.element_type(), xla::TUPLE);
+  ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1);
+  const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0);
+  EXPECT_TRUE(ShapeUtil::Compatible(result0, s32));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 0d68aa7399..238bc9b46a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -87,6 +87,17 @@ class CpuExecutable : public Executable {
std::unique_ptr<HloCostAnalysis> CreateCostAnalysis() const override;
+  // Type of the computation function we expect in the JIT.
+  using ComputeFunctionType = void (*)(
+      void* /*result*/, const ExecutableRunOptions* /*run_options*/,
+      const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/);
+
+  // Returns the JIT-compiled entry point, so callers (e.g. the tf2xla JIT
+  // path) can wrap and invoke the computation directly.
+  const ComputeFunctionType& compute_function() const {
+    return compute_function_;
+  }
+
+  // Returns the buffer assignment, used by callers to compute the arg/temp
+  // buffer sizes and the result index needed to run compute_function().
+  const BufferAssignment& buffer_assignment() const { return *assignment_; }
+
private:
// Allocate buffers required for execution and assign them to the elements of
// "buffers". "buffers" should be sized to the number of buffers in buffer
@@ -129,11 +140,6 @@ class CpuExecutable : public Executable {
// positives.
string ir_module_string_;
- // Type of the computation function we expect in the JIT.
- // void function(void* result, const void* run_options,
- // const void** args_array, void** temps_array)
- using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
- uint64*);
ComputeFunctionType compute_function_;
// Entry function name for the computation.