author    Sanjoy Das <sanjoy@google.com>  2018-08-10 16:02:52 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-10 16:06:59 -0700
commit    f51fc5f89a3fa934e078b35d50b26625a9ff42cf (patch)
tree      a2a63482390f20302eead8835a80e0cb24d6a734
parent    8d532ac40f4db7f5293610fd3c6e92a3f7409b76 (diff)
Introduce and use a BufferInfo class.
BufferInfo represents information about buffer assignment in XlaCompiledCpuFunction. Arg sizes and temp sizes are now derived from BufferInfo instead of being carried as separate pieces of information. Also made StaticData() private; tfcompile clients should not need to access it directly.

PiperOrigin-RevId: 208283305
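A minimal round-trip sketch (mine, not part of the commit) of how the new class is meant to be used, assuming the cpu_function_runtime.h introduced below:

#include <cassert>
#include <utility>

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"

using tensorflow::cpu_function_runtime::BufferInfo;

void BufferInfoRoundTripSketch() {
  // An 8-byte buffer backing entry parameter 0 of the computation.
  BufferInfo param =
      BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0);

  // Encode() packs (kind, size) into the first word and the parameter number
  // into the second; tfcompile emits these two literals into the generated
  // header, and the pair constructor reverses the encoding at load time.
  std::pair<tensorflow::uint64, tensorflow::uint64> encoded = param.Encode();
  BufferInfo round_trip(encoded);

  assert(round_trip == param);
  assert(round_trip.is_entry_parameter());
  assert(round_trip.size() == 8);
  assert(round_trip.entry_parameter_number() == 0);
}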
-rw-r--r-- tensorflow/compiler/aot/BUILD                                  |   1
-rw-r--r-- tensorflow/compiler/aot/codegen.cc                             | 169
-rw-r--r-- tensorflow/compiler/aot/codegen_test.cc                        |  12
-rw-r--r-- tensorflow/compiler/aot/codegen_test_h.golden                  |  58
-rw-r--r-- tensorflow/compiler/aot/test.cc                                |  12
-rw-r--r-- tensorflow/compiler/tf2xla/BUILD                               |   5
-rw-r--r-- tensorflow/compiler/tf2xla/cpu_function_runtime.cc             |  30
-rw-r--r-- tensorflow/compiler/tf2xla/cpu_function_runtime.h              | 133
-rw-r--r-- tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc        |  72
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc        |  63
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h         | 137
-rw-r--r-- tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc    |  80
-rw-r--r-- tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h     |   8
-rw-r--r-- tensorflow/compiler/xla/service/compiler.h                     |   5
-rw-r--r-- tensorflow/compiler/xla/service/cpu/BUILD                      |  13
-rw-r--r-- tensorflow/compiler/xla/service/cpu/buffer_info_util.cc        |  57
-rw-r--r-- tensorflow/compiler/xla/service/cpu/buffer_info_util.h         |  42
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_compiler.cc            |  37
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_compiler.h             |  17
-rw-r--r-- tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc  |   8
20 files changed, 655 insertions, 304 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index d2f803bd18..1899a32e4d 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -48,6 +48,7 @@ cc_library(
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 8dbe1e11b7..89fefdad54 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -36,6 +37,8 @@ namespace tfcompile {
namespace {
+using BufferInfo = cpu_function_runtime::BufferInfo;
+
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}
@@ -85,27 +88,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
return Status::OK();
}
-// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
-// values. There are `n` entries in `sizes`.
-size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
- size_t total = 0;
- for (size_t i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
- total += sizes[i];
- }
- }
- return total;
+// Returns the sum of the size of each buffer in `buffer_infos`.
+size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
+ return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
+ [](size_t size, const BufferInfo& buffer_info) {
+ return size + buffer_info.size();
+ });
}
-// Fills in arg_sizes with the byte size of each positional arg.
-Status ComputeArgSizes(const CompileResult& compile_result,
- std::vector<int64>* arg_sizes) {
- const xla::ProgramShape& ps = compile_result.program_shape;
- for (int i = 0; i < ps.parameters_size(); ++i) {
- arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
- ps.parameters(i), compile_result.pointer_size));
- }
- return Status::OK();
+// Returns a vector of BufferInfo instances in `buffer_infos` that are entry
+// parameter buffers.
+std::vector<BufferInfo> ExtractEntryParamBufferInfos(
+ const std::vector<BufferInfo>& buffer_infos) {
+ std::vector<BufferInfo> result;
+ std::copy_if(buffer_infos.begin(), buffer_infos.end(),
+ std::back_inserter(result), [](const BufferInfo& buffer_info) {
+ return buffer_info.is_entry_parameter();
+ });
+ return result;
+}
+
+// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
+// buffers.
+std::vector<BufferInfo> ExtractTempBufferInfos(
+ const std::vector<BufferInfo>& buffer_infos) {
+ std::vector<BufferInfo> result;
+ std::copy_if(buffer_infos.begin(), buffer_infos.end(),
+ std::back_inserter(result), [](const BufferInfo& buffer_info) {
+ return buffer_info.is_temp_buffer();
+ });
+ return result;
}
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
@@ -278,6 +290,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
return Status::OK();
}
+// Returns a list of C++ expressions that, when executed, will construct the
+// BufferInfo instances in `buffer_infos`.
+std::vector<string> BufferInfosToCppExpression(
+ const std::vector<BufferInfo>& buffer_infos) {
+ std::vector<string> buffer_infos_as_strings;
+ std::transform(buffer_infos.begin(), buffer_infos.end(),
+ std::back_inserter(buffer_infos_as_strings),
+ [](const BufferInfo& buffer_info) {
+ std::pair<uint64, uint64> encoded = buffer_info.Encode();
+ string encoded_second_as_str =
+ encoded.second == ~0ULL
+ ? "~0ULL"
+ : strings::StrCat(encoded.second, "ULL");
+ return strings::StrCat(
+ "::tensorflow::cpu_function_runtime::BufferInfo({",
+ encoded.first, "ULL, ", encoded_second_as_str, "})");
+ });
+ return buffer_infos_as_strings;
+}
} // namespace
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
@@ -286,29 +317,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ValidateConfig(config));
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
- const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
- if (result_index < 0 || result_index >= temp_sizes.size()) {
+ const std::vector<BufferInfo>& buffer_infos =
+ compile_result.aot->buffer_infos();
+ const std::vector<int32> arg_index_table =
+ ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
+ std::vector<string> buffer_infos_as_strings =
+ BufferInfosToCppExpression(buffer_infos);
+ if (result_index < 0 || result_index >= buffer_infos.size()) {
return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,",
- temp_sizes.size(), ")");
+ buffer_infos.size(), ")");
}
// Compute sizes and generate methods.
- std::vector<int64> arg_sizes;
- TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
+ std::vector<BufferInfo> buffer_infos_for_args =
+ ExtractEntryParamBufferInfos(buffer_infos);
+ std::vector<BufferInfo> buffer_infos_for_temps =
+ ExtractTempBufferInfos(buffer_infos);
const xla::ProgramShape& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
- const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
- const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
- const size_t arg_bytes_aligned =
- cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size());
- const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
- const size_t temp_bytes_aligned =
- cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size());
- const size_t temp_bytes_total =
- total_buffer_bytes(itemp.data(), itemp.size());
+ const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
+ buffer_infos_for_args.data(), buffer_infos_for_args.size(),
+ /*allocate_entry_params=*/true);
+ const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
+ const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
+ buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
+ /*allocate_entry_params=*/true);
+ const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);
// Create rewrite strings for namespace start and end.
string ns_start;
@@ -343,8 +380,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// calling HloProfilePrinter::profile_counters_size.
const string assign_profile_counters_size =
opts.gen_hlo_profile_printer_data
- ? "data->profile_counters_size = "
- "data->hlo_profile_printer_data->profile_counters_size();"
+ ? "data->set_profile_counters_size("
+ "data->hlo_profile_printer_data()->profile_counters_size());"
: "";
// Use a poor-man's text templating mechanism; first populate the full header
@@ -414,9 +451,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
- static const intptr_t* ArgSizes() {
- static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
- return kArgSizes;
+ static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
+ return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
}
// Returns static data used to create an XlaCompiledCpuFunction.
@@ -424,17 +460,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData;
- data->raw_function = {{ENTRY}};
- data->arg_sizes = ArgSizes();
- data->num_args = kNumArgs;
- data->temp_sizes = TempSizes();
- data->num_temps = kNumTemps;
- data->result_index = kResultIndex;
- data->arg_names = StaticArgNames();
- data->result_names = StaticResultNames();
- data->program_shape = StaticProgramShape();
- data->hlo_profile_printer_data = StaticHloProfilePrinterData();
- {{ASSIGN_PROFILE_COUNTERS_SIZE}}
+ data->set_raw_function({{ENTRY}});
+ data->set_buffer_infos(BufferInfos());
+ data->set_num_buffers(kNumBuffers);
+ data->set_arg_index_table(ArgIndexToBufferIndex());
+ data->set_num_args(kNumArgs);
+ data->set_result_index(kResultIndex);
+ data->set_arg_names(StaticArgNames());
+ data->set_result_names(StaticResultNames());
+ data->set_program_shape(StaticProgramShape());
+ data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
+{{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
return *kStaticData;
@@ -482,17 +518,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{{METHODS_RESULT}}
private:
- // Number of result and temporary buffers for the compiled computation.
- static constexpr size_t kNumTemps = {{TEMP_NUM}};
- // The 0-based index of the result tuple in the temporary buffers.
- static constexpr size_t kResultIndex = {{RESULT_INDEX}};
+ // Number of buffers for the compiled computation.
+ static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};
- // Byte size of each result / temporary buffer. There are kNumTemps entries.
- static const intptr_t* TempSizes() {
- static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
- return kTempSizes;
+ static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
+ static const ::tensorflow::cpu_function_runtime::BufferInfo
+ kBufferInfos[kNumBuffers] = {
+{{BUFFER_INFOS_AS_STRING}}
+ };
+ return kBufferInfos;
}
+ static const ::tensorflow::int32* ArgIndexToBufferIndex() {
+ static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
+{{ARG_INDEX_TABLE}}
+ };
+ return kArgIndexToBufferIndex;
+ }
+
+ // The 0-based index of the result tuple in the temporary buffers.
+ static constexpr size_t kResultIndex = {{RESULT_INDEX}};
+
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
@@ -523,8 +569,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
- {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
- {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
+ {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
+ {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
@@ -546,8 +592,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
- {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
- {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
+ {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
+ {"{{BUFFER_INFOS_AS_STRING}}",
+ str_util::Join(buffer_infos_as_strings, ",\n")}};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 29bc9c13b8..60d59ae996 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -32,6 +32,8 @@ namespace tensorflow {
namespace tfcompile {
namespace {
+using ::tensorflow::cpu_function_runtime::BufferInfo;
+
void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
@@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) {
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;
- compile_result.aot.reset(
- new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
+ compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
+ {},
+ {BufferInfo::MakeTempBuffer(1),
+ BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
+ BufferInfo::MakeTempBuffer(2),
+ BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
+ BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
+ 5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index 6641d45e83..e4d8a02877 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = 2;
// Byte size of each argument buffer. There are kNumArgs entries.
- static const intptr_t* ArgSizes() {
- static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96};
- return kArgSizes;
+ static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
+ return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
}
// Returns static data used to create an XlaCompiledCpuFunction.
@@ -75,17 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData;
- data->raw_function = entry_point;
- data->arg_sizes = ArgSizes();
- data->num_args = kNumArgs;
- data->temp_sizes = TempSizes();
- data->num_temps = kNumTemps;
- data->result_index = kResultIndex;
- data->arg_names = StaticArgNames();
- data->result_names = StaticResultNames();
- data->program_shape = StaticProgramShape();
- data->hlo_profile_printer_data = StaticHloProfilePrinterData();
-
+ data->set_raw_function(entry_point);
+ data->set_buffer_infos(BufferInfos());
+ data->set_num_buffers(kNumBuffers);
+ data->set_arg_index_table(ArgIndexToBufferIndex());
+ data->set_num_args(kNumArgs);
+ data->set_result_index(kResultIndex);
+ data->set_arg_names(StaticArgNames());
+ data->set_result_names(StaticResultNames());
+ data->set_program_shape(StaticProgramShape());
+ data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
+
return data;
}();
return *kStaticData;
@@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}
private:
- // Number of result and temporary buffers for the compiled computation.
- static constexpr size_t kNumTemps = 6;
- // The 0-based index of the result tuple in the temporary buffers.
- static constexpr size_t kResultIndex = 5;
+ // Number of buffers for the compiled computation.
+ static constexpr size_t kNumBuffers = 6;
+
+ static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
+ static const ::tensorflow::cpu_function_runtime::BufferInfo
+ kBufferInfos[kNumBuffers] = {
+::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
+ };
+ return kBufferInfos;
+ }
- // Byte size of each result / temporary buffer. There are kNumTemps entries.
- static const intptr_t* TempSizes() {
- static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
- return kTempSizes;
+ static const ::tensorflow::int32* ArgIndexToBufferIndex() {
+ static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
+1, 3
+ };
+ return kArgIndexToBufferIndex;
}
+ // The 0-based index of the result tuple in the temporary buffers.
+ static constexpr size_t kResultIndex = 5;
+
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {
static const char* kNames[] = {"myfeed", nullptr};
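For illustration, a hypothetical client snippet (not part of the commit) showing how code that previously read MyClass::ArgSizes() would query argument sizes against the golden class above; the 8-byte value is the size of the "myfeed" argument:

#include <cassert>

// Assumes the generated header for MyClass shown above has been included.
void QueryArgSizeSketch() {
  MyClass computation;
  // Old API: MyClass::ArgSizes()[0]. New static accessor on the generated class:
  ::tensorflow::int64 static_size = MyClass::ArgSize(0);
  // New instance accessor on XlaCompiledCpuFunction; also works for the JIT path:
  int instance_size = computation.arg_size(0);
  assert(static_size == 8 && instance_size == 8);
}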
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
index 6b098049cb..df966767b3 100644
--- a/tensorflow/compiler/aot/test.cc
+++ b/tensorflow/compiler/aot/test.cc
@@ -51,11 +51,9 @@ namespace tensorflow {
namespace tfcompile {
namespace {
-void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
- for (int i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
- memset(bufs[i], 0, sizes[i]);
- }
+void zero_buffers(void** bufs, const XlaCompiledCpuFunction& computation) {
+ for (int i = 0; i < computation.num_args(); ++i) {
+ memset(bufs[i], 0, computation.arg_size(i));
}
}
@@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+ zero_buffers(computation.args(), computation);
EXPECT_TRUE(computation.Run());
}
@@ -80,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+ zero_buffers(computation.args(), computation);
testing::StartTiming();
while (--iters) {
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 61759fd276..fda32c8a1c 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -95,6 +95,10 @@ cc_library(
name = "cpu_function_runtime",
srcs = ["cpu_function_runtime.cc"],
hdrs = ["cpu_function_runtime.h"],
+ visibility = [
+ "//tensorflow/compiler/aot:__pkg__",
+ "//tensorflow/compiler/xla/service/cpu:__pkg__",
+ ],
deps = [
# Keep dependencies to a minimum here; this library is used in every AOT
# binary produced by tfcompile.
@@ -144,6 +148,7 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_executable",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc
index 2ffad2af8c..fcc4095e39 100644
--- a/tensorflow/compiler/tf2xla/cpu_function_runtime.cc
+++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.cc
@@ -55,19 +55,26 @@ size_t align_to(size_t n, size_t align) {
} // namespace
namespace cpu_function_runtime {
-size_t AlignedBufferBytes(const intptr_t* sizes, size_t n) {
+size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
+ bool allocate_entry_params) {
size_t total = 0;
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] > 0) {
- total += align_to(sizes[i], kAlign);
+ bool should_allocate =
+ buffer_infos[i].is_temp_buffer() ||
+ (buffer_infos[i].is_entry_parameter() && allocate_entry_params);
+
+ if (should_allocate) {
+ total += align_to(buffer_infos[i].size(), kAlign);
}
}
return total;
}
-void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
+void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
+ bool allocate_entry_params, void** bufs,
bool annotate_initialized) {
- const size_t total = AlignedBufferBytes(sizes, n);
+ const size_t total =
+ AlignedBufferBytes(buffer_infos, n, allocate_entry_params);
void* contiguous = nullptr;
if (total > 0) {
contiguous = aligned_malloc(total, kAlign);
@@ -79,13 +86,14 @@ void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
}
uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
for (size_t i = 0; i < n; ++i) {
- if (sizes[i] < 0) {
- // bufs[i] is either a constant, an entry parameter or a thread local
- // allocation.
- bufs[i] = nullptr;
- } else {
+ bool should_allocate =
+ buffer_infos[i].is_temp_buffer() ||
+ (buffer_infos[i].is_entry_parameter() && allocate_entry_params);
+ if (should_allocate) {
bufs[i] = reinterpret_cast<void*>(pos);
- pos += align_to(sizes[i], kAlign);
+ pos += align_to(buffer_infos[i].size(), kAlign);
+ } else {
+ bufs[i] = nullptr;
}
}
return contiguous;
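To make the new allocation contract concrete, a small usage sketch (mine; the sizes are made up) of the updated MallocContiguousBuffers/FreeContiguous API:

#include <vector>

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"

namespace cfr = tensorflow::cpu_function_runtime;

void MallocContiguousSketch() {
  std::vector<cfr::BufferInfo> infos = {
      cfr::BufferInfo::MakeTempBuffer(32),  // always allocated
      cfr::BufferInfo::MakeEntryParameter(/*size=*/8,
                                          /*param_number=*/0),  // allocated only on request
      cfr::BufferInfo::MakeConstant(16),    // never allocated here
  };
  void* bufs[3];
  void* block = cfr::MallocContiguousBuffers(
      infos.data(), infos.size(), /*allocate_entry_params=*/true, bufs,
      /*annotate_initialized=*/false);
  // bufs[0] and bufs[1] point into `block` at kAlign boundaries; bufs[2] is
  // nullptr because constants are lowered to globals and never allocated here.
  cfr::FreeContiguous(block);
}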
diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime.h b/tensorflow/compiler/tf2xla/cpu_function_runtime.h
index c7b4559c65..dfc1e8b8ae 100644
--- a/tensorflow/compiler/tf2xla/cpu_function_runtime.h
+++ b/tensorflow/compiler/tf2xla/cpu_function_runtime.h
@@ -18,29 +18,142 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
+#include <cassert>
+
namespace tensorflow {
namespace cpu_function_runtime {
+// Stores information about one buffer used by an XLA:CPU compiled function.
+// These buffers are used for holding inputs to the computation, outputs from
+// the computation and as temporary scratch space.
+class BufferInfo {
+ public:
+ // Creates a BufferInfo from a serialized encoding generated by `Encode`.
+ explicit BufferInfo(std::pair<uint64, uint64> encoding)
+ : entry_param_number_(encoding.second) {
+ Kind kind;
+ uint64 size;
+ Unpack(encoding.first, &kind, &size);
+ kind_ = kind;
+ size_ = size;
+ }
+
+ // Returns true if this buffer stores a constant. These never need to be
+ // allocated by the runtime.
+ bool is_constant() const { return kind() == Kind::kConstant; }
+
+ // Returns true if this buffer stores an entry parameter. These may or may
+ // not need to be allocated by the runtime, depending on
+ // XlaCompiledCpuFunction::AllocMode.
+ bool is_entry_parameter() const { return kind() == Kind::kEntryParameter; }
+
+ // Returns the entry parameter number of this buffer.
+ uint64 entry_parameter_number() const {
+ assert(is_entry_parameter());
+ return entry_param_number_;
+ }
+
+ // Returns true if this buffer is temporary scratch space required by the XLA
+ // computations. These are always allocated by the runtime.
+ bool is_temp_buffer() const { return kind() == Kind::kTempBuffer; }
+
+ // Returns true if this buffer is allocated on the C stack or into registers.
+ // These buffers are never allocated by the runtime.
+ bool is_on_stack_buffer() const { return kind() == Kind::kOnStackBuffer; }
+
+ // Returns the size for this buffer.
+ uint64 size() const { return size_; }
+
+ // Encodes this BufferInfo into two 64 bit integers that can be used to
+ // reconstruct the BufferInfo later using the constructor. We need this
+ // because we use BufferInfo in places where using protocol buffers would
+ // negatively impact binary size.
+ std::pair<uint64, uint64> Encode() const {
+ static_assert(sizeof(*this) == 16, "");
+ uint64 upper = Pack(kind(), size_);
+ uint64 lower = entry_param_number_;
+ return {upper, lower};
+ }
+
+ bool operator==(const BufferInfo& buffer_info) const {
+ if (kind() != buffer_info.kind() || size() != buffer_info.size()) {
+ return false;
+ }
+ return !is_entry_parameter() ||
+ entry_parameter_number() == buffer_info.entry_parameter_number();
+ }
+
+ // Factory methods:
+
+ static BufferInfo MakeTempBuffer(uint64 size) {
+ return BufferInfo(Kind::kTempBuffer, /*size=*/size,
+ /*entry_param_number=*/-1);
+ }
+ static BufferInfo MakeConstant(uint64 size) {
+ return BufferInfo(Kind::kConstant, /*size=*/size,
+ /*entry_param_number=*/-1);
+ }
+ static BufferInfo MakeEntryParameter(uint64 size, uint64 param_number) {
+ return BufferInfo(Kind::kEntryParameter, /*size=*/size,
+ /*entry_param_number=*/param_number);
+ }
+ static BufferInfo MakeOnStackBuffer(uint64 size) {
+ return BufferInfo(Kind::kOnStackBuffer, /*size=*/size,
+ /*entry_param_number=*/-1);
+ }
+
+ private:
+ BufferInfo() = default;
+
+ enum class Kind : unsigned {
+ kConstant,
+ kTempBuffer,
+ kEntryParameter,
+ kOnStackBuffer
+ };
+
+ Kind kind() const { return static_cast<Kind>(kind_); }
+
+ explicit BufferInfo(Kind kind, uint64 size, uint64 entry_param_number)
+ : kind_(kind), size_(size), entry_param_number_(entry_param_number) {}
+
+ static uint64 Pack(Kind kind, uint64 size) {
+ return (static_cast<uint64>(size) << 2) | static_cast<uint64>(kind);
+ }
+
+ static void Unpack(uint64 packed, Kind* kind, uint64* size) {
+ *size = packed >> 2;
+ *kind = static_cast<Kind>((packed << 62) >> 62);
+ }
+
+ Kind kind_ : 2;
+ uint64 size_ : 62;
+ int64 entry_param_number_;
+};
// Align to 64-bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
constexpr size_t kAlign = 64;
-// AlignedBufferBytes returns the sum of each size in `sizes`, skipping -1
-// values. There are `n` entries in `sizes`. Each buffer is aligned to
-// kAlign byte boundaries.
-size_t AlignedBufferBytes(const intptr_t* sizes, size_t n);
+// AlignedBufferBytes returns the sum of the size of each buffer in
+// `buffer_infos`, skipping constants, on-stack buffers and, if
+// allocate_entry_params is false, entry parameters. There are `n` entries in
+// `buffer_infos`. Each buffer is aligned to kAlign byte boundaries.
+size_t AlignedBufferBytes(const BufferInfo* buffer_infos, size_t n,
+ bool allocate_entry_params);
// MallocContiguousBuffers allocates buffers for use by the entry point
-// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
-// where -1 causes the buffer pointer to be nullptr. There are `n` entries in
-// `sizes`. If `annotate_initialized` is set, the allocated memory will be
-// annotated as having been initialized - this is useful when allocating
-// temporary buffers.
+// generated by tfcompile. There are `n` entries in `buffer_infos`. If
+// `annotate_initialized` is set, the allocated memory will be annotated as
+// having been initialized - this is useful when allocating temporary buffers.
+// If allocate_entry_params is true then temp buffers and entry parameters are
+// allocated; otherwise only temp buffers are allocated. Slots in `bufs`
+// corresponding to unallocated buffers are set to nullptr.
//
// A single contiguous block of memory is allocated, and portions of it are
// parceled out into `bufs`, which must have space for `n` entries. Returns
// the head of the allocated contiguous block, which should be passed to
// FreeContiguous when the buffers are no longer in use.
-void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
+void* MallocContiguousBuffers(const BufferInfo* buffer_infos, size_t n,
+ bool allocate_entry_params, void** bufs,
bool annotate_initialized);
// FreeContiguous frees the contiguous block of memory allocated by
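As a worked check of the encoding (my arithmetic, not in the commit): the low two bits of the first word hold the kind (kConstant=0, kTempBuffer=1, kEntryParameter=2, kOnStackBuffer=3), the upper 62 bits hold the size, and the second word is the parameter number, or int64(-1) printed as ~0ULL for non-parameters. This reproduces the literals in codegen_test_h.golden above:

#include <cassert>

#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"

using tensorflow::cpu_function_runtime::BufferInfo;

void CheckGoldenEncodings() {
  // Pack(kind, size) == (size << 2) | kind.
  assert(BufferInfo::MakeTempBuffer(1).Encode().first == 5);            // (1 << 2) | 1
  assert(BufferInfo::MakeEntryParameter(8, 0).Encode().first == 34);    // (8 << 2) | 2
  assert(BufferInfo::MakeEntryParameter(96, 1).Encode().first == 386);  // (96 << 2) | 2
  assert(BufferInfo::MakeTempBuffer(120).Encode().first == 481);        // (120 << 2) | 1
  // Non-parameter buffers carry -1 in the second word, i.e. ~0ULL.
  assert(BufferInfo::MakeTempBuffer(1).Encode().second == ~0ULL);
  assert(BufferInfo::MakeEntryParameter(96, 1).Encode().second == 1);
}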
diff --git a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc
index f4f27a1562..8ca628c4eb 100644
--- a/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc
+++ b/tensorflow/compiler/tf2xla/cpu_function_runtime_test.cc
@@ -21,6 +21,8 @@ limitations under the License.
namespace tensorflow {
namespace {
+using cpu_function_runtime::BufferInfo;
+
TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
// We've chosen 64 byte alignment for the tfcompile runtime to mimic the
// regular tensorflow allocator, which was chosen to play nicely with Eigen.
@@ -30,20 +32,51 @@ TEST(XlaCompiledCpuFunctionTest, AlignmentValue) {
EXPECT_EQ(cpu_function_runtime::kAlign, Allocator::kAllocatorAlignment);
}
+std::vector<BufferInfo> SizesToBufferInfos(const intptr_t* sizes, size_t n) {
+ std::vector<BufferInfo> buffer_infos;
+ std::transform(sizes, sizes + n, std::back_inserter(buffer_infos),
+ [&](intptr_t size) {
+ if (size == -1) {
+                     // Use a dummy on-stack buffer allocation to indicate that
+ // the current slot does not need an allocation.
+ int64 on_stack_buffer_size = 4;
+ return BufferInfo::MakeOnStackBuffer(on_stack_buffer_size);
+ }
+ return BufferInfo::MakeTempBuffer(size);
+ });
+ return buffer_infos;
+}
+
+// Simple wrappers to make writing tests more ergonomic.
+
+size_t AlignedBufferBytesFromSizes(const intptr_t* sizes, size_t n) {
+ std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
+ return AlignedBufferBytes(buffer_infos.data(), n,
+ /*allocate_entry_params=*/false);
+}
+
+void* MallocContiguousBuffersFromSizes(const intptr_t* sizes, size_t n,
+ void** bufs, bool annotate_initialized) {
+ std::vector<BufferInfo> buffer_infos = SizesToBufferInfos(sizes, n);
+ return MallocContiguousBuffers(buffer_infos.data(), n,
+ /*allocate_entry_params=*/false, bufs,
+ annotate_initialized);
+}
+
TEST(XlaCompiledCpuFunctionTest, AlignedBufferBytes) {
- EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(nullptr, 0), 0);
+ EXPECT_EQ(AlignedBufferBytesFromSizes(nullptr, 0), 0);
static constexpr intptr_t sizesA[1] = {-1};
- EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesA, 1), 0);
+ EXPECT_EQ(AlignedBufferBytesFromSizes(sizesA, 1), 0);
static constexpr intptr_t sizesB[1] = {3};
- EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesB, 1), 64);
+ EXPECT_EQ(AlignedBufferBytesFromSizes(sizesB, 1), 64);
static constexpr intptr_t sizesC[1] = {32};
- EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesC, 1), 64);
+ EXPECT_EQ(AlignedBufferBytesFromSizes(sizesC, 1), 64);
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
- EXPECT_EQ(cpu_function_runtime::AlignedBufferBytes(sizesD, 7), 320);
+ EXPECT_EQ(AlignedBufferBytesFromSizes(sizesD, 7), 320);
}
void* add_ptr(void* base, uintptr_t delta) {
@@ -56,15 +89,14 @@ void* add_ptr(void* base, uintptr_t delta) {
// free. We also check the contiguous property.
TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test empty sizes.
- void* base =
- cpu_function_runtime::MallocContiguousBuffers(nullptr, 0, nullptr, false);
+ void* base = MallocContiguousBuffersFromSizes(nullptr, 0, nullptr, false);
EXPECT_EQ(base, nullptr);
cpu_function_runtime::FreeContiguous(base);
// Test non-empty sizes with 0 sum.
static constexpr intptr_t sizesA[1] = {-1};
void* bufA[1];
- base = cpu_function_runtime::MallocContiguousBuffers(sizesA, 1, bufA, false);
+ base = MallocContiguousBuffersFromSizes(sizesA, 1, bufA, false);
EXPECT_EQ(base, nullptr);
EXPECT_EQ(bufA[0], nullptr);
cpu_function_runtime::FreeContiguous(base);
@@ -72,7 +104,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test non-empty sizes with non-0 sum.
static constexpr intptr_t sizesB[1] = {3};
void* bufB[1];
- base = cpu_function_runtime::MallocContiguousBuffers(sizesB, 1, bufB, false);
+ base = MallocContiguousBuffersFromSizes(sizesB, 1, bufB, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufB[0], add_ptr(base, 0));
char* bufB0_bytes = static_cast<char*>(bufB[0]);
@@ -84,7 +116,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test non-empty sizes with non-0 sum, and annotate_initialized.
static constexpr intptr_t sizesC[1] = {3};
void* bufC[1];
- base = cpu_function_runtime::MallocContiguousBuffers(sizesC, 1, bufC, true);
+ base = MallocContiguousBuffersFromSizes(sizesC, 1, bufC, true);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufC[0], add_ptr(base, 0));
char* bufC0_bytes = static_cast<char*>(bufC[0]);
@@ -96,7 +128,7 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
// Test mixed sizes.
static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
void* bufD[7];
- base = cpu_function_runtime::MallocContiguousBuffers(sizesD, 7, bufD, false);
+ base = MallocContiguousBuffersFromSizes(sizesD, 7, bufD, false);
EXPECT_NE(base, nullptr);
EXPECT_EQ(bufD[0], add_ptr(base, 0));
EXPECT_EQ(bufD[1], nullptr);
@@ -117,5 +149,23 @@ TEST(XlaCompiledCpuFunctionTest, MallocFreeContiguousBuffers) {
cpu_function_runtime::FreeContiguous(base);
}
+void CheckRoundTripIsOk(const BufferInfo& buffer_info) {
+ BufferInfo round_trip(buffer_info.Encode());
+ ASSERT_EQ(round_trip, buffer_info);
+}
+
+TEST(XlaCompiledCpuFunctionTest, BufferInfoTest) {
+ CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(0));
+ CheckRoundTripIsOk(BufferInfo::MakeTempBuffer(4));
+ CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(0));
+ CheckRoundTripIsOk(BufferInfo::MakeOnStackBuffer(4));
+ CheckRoundTripIsOk(BufferInfo::MakeConstant(0));
+ CheckRoundTripIsOk(BufferInfo::MakeConstant(4));
+ CheckRoundTripIsOk(
+ BufferInfo::MakeEntryParameter(/*size=*/0, /*param_number=*/4));
+ CheckRoundTripIsOk(
+ BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 334459138b..09c5d1dd19 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
-#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include <cassert>
@@ -22,61 +21,55 @@ namespace tensorflow {
XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
AllocMode alloc_mode)
- : raw_function_(static_data.raw_function),
- result_index_(static_data.result_index),
- args_(new void*[static_data.num_args]),
- temps_(new void*[static_data.num_temps]),
- arg_index_to_temp_index_(new int32[static_data.num_args]),
- num_args_(static_data.num_args),
- arg_names_(static_data.arg_names),
- result_names_(static_data.result_names),
- program_shape_(static_data.program_shape),
- hlo_profile_printer_data_(static_data.hlo_profile_printer_data) {
+ : raw_function_(static_data.raw_function_),
+ result_index_(static_data.result_index_),
+ args_(new void*[static_data.num_args_]),
+ buffer_table_(new void*[static_data.num_buffers_]),
+ buffer_infos_(static_data.buffer_infos_),
+ arg_index_table_(static_data.arg_index_table_),
+ num_args_(static_data.num_args_),
+ arg_names_(static_data.arg_names_),
+ result_names_(static_data.result_names_),
+ program_shape_(static_data.program_shape_),
+ hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
+ bool allocate_entry_params =
+ alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS;
// Allocate arg and temp buffers.
- if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
- alloc_args_ = cpu_function_runtime::MallocContiguousBuffers(
- static_data.arg_sizes, static_data.num_args, args_,
- /*annotate_initialized=*/false);
- }
- alloc_temps_ = cpu_function_runtime::MallocContiguousBuffers(
- static_data.temp_sizes, static_data.num_temps, temps_,
+ alloc_buffer_table_ = cpu_function_runtime::MallocContiguousBuffers(
+ static_data.buffer_infos_, static_data.num_buffers_,
+ /*allocate_entry_params=*/allocate_entry_params, buffer_table_,
/*annotate_initialized=*/true);
-
- for (int i = 0; i < static_data.num_temps; i++) {
- if (static_data.temp_sizes[i] < -1) {
- int32 param_number = -(static_data.temp_sizes[i] + 2);
- arg_index_to_temp_index_[param_number] = i;
+ if (allocate_entry_params) {
+ for (int32 i = 0; i < num_args_; i++) {
+ args_[i] = buffer_table_[arg_index_table_[i]];
}
}
-
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
// signature, but it is ignored by the generated code and we pass in null for
// it.
if (hlo_profiling_enabled()) {
- profile_counters_ = new int64[static_data.profile_counters_size]();
+ profile_counters_ = new int64[static_data.profile_counters_size_]();
}
}
bool XlaCompiledCpuFunction::Run() {
- // Propagate pointers to the argument buffers into the temps array. Code
- // generated by XLA discovers the incoming argument pointers from the temps
- // array.
+ // Propagate pointers to the argument buffers into the buffer table. Code
+ // generated by XLA discovers the incoming argument pointers from the buffer
+ // table.
for (int32 i = 0; i < num_args_; i++) {
- temps_[arg_index_to_temp_index_[i]] = args_[i];
+ buffer_table_[arg_index_table_[i]] = args_[i];
}
- raw_function_(temps_[result_index_], &run_options_, nullptr, temps_,
- profile_counters_);
+ raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
+ buffer_table_, profile_counters_);
return true;
}
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
- cpu_function_runtime::FreeContiguous(alloc_args_);
- cpu_function_runtime::FreeContiguous(alloc_temps_);
+ cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
delete[] args_;
- delete[] temps_;
- delete[] arg_index_to_temp_index_;
+ delete[] buffer_table_;
delete[] profile_counters_;
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 27cfb354bf..7dd8c24eb7 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <cassert>
#include <string>
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/core/platform/types.h"
@@ -56,46 +57,85 @@ class XlaCompiledCpuFunction {
// StaticData represents the state necessary to run an XLA-compiled
// function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
// AOT this is backed by data compiled into the object file.
- struct StaticData {
+ //
+ // The contents of StaticData are XLA-internal implementation details and
+ // should not be relied on by clients.
+ //
+  // TODO(sanjoy): Come up with a cleaner way to express the constraint we want
+ // here: generated XlaCompiledCpuFunction subclasses should be able to create
+ // instances of StaticData but only XlaCompiledCpuFunction should be able to
+ // read from StaticData instances.
+ class StaticData {
+ public:
+ void set_raw_function(RawFunction raw_function) {
+ raw_function_ = raw_function;
+ }
+ void set_buffer_infos(
+ const cpu_function_runtime::BufferInfo* buffer_infos) {
+ buffer_infos_ = buffer_infos;
+ }
+ void set_num_buffers(size_t num_buffers) { num_buffers_ = num_buffers; }
+ void set_arg_index_table(const int32* arg_index_table) {
+ arg_index_table_ = arg_index_table;
+ }
+ void set_num_args(int64 num_args) { num_args_ = num_args; }
+ void set_result_index(size_t result_index) { result_index_ = result_index; }
+ void set_arg_names(const char** arg_names) { arg_names_ = arg_names; }
+ void set_result_names(const char** result_names) {
+ result_names_ = result_names;
+ }
+ void set_program_shape(const xla::ProgramShape* program_shape) {
+ program_shape_ = program_shape;
+ }
+ const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
+ return hlo_profile_printer_data_;
+ }
+ void set_hlo_profile_printer_data(
+ const xla::HloProfilePrinterData* hlo_profile_printer_data) {
+ hlo_profile_printer_data_ = hlo_profile_printer_data;
+ }
+ void set_profile_counters_size(int64 profile_counters_size) {
+ profile_counters_size_ = profile_counters_size;
+ }
+
+ private:
// The raw function to call.
- RawFunction raw_function;
-
- // Cardinality and size of arg buffers.
- const intptr_t* arg_sizes = nullptr;
- size_t num_args = 0;
-
- // Cardinality and size of temp buffers.
- //
- // If temp_sizes[i] >= 0 then the i'th temp is a regular temporary buffer.
- //
- // If temp_sizes[i] == -1 then the i'th temp is a constant buffer. The
- // corresponding entry in the temp buffer array needs to be set to null.
- //
- // If temp_sizes[i] < -1 then the i'th temp is the entry parameter
- // -(temp_sizes[i] + 2).
- const intptr_t* temp_sizes = nullptr;
- size_t num_temps = 0;
+ RawFunction raw_function_;
+
+ // Contains information about the buffers used by the XLA computation.
+ const cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
+ size_t num_buffers_ = 0;
+
+ // Entry parameter i is described by
+ // buffer_infos[arg_index_table[i]].
+ const int32* arg_index_table_ = nullptr;
+
+ // There are num_args entry parameters.
+ int64 num_args_ = 0;
// The 0-based index of the result tuple, in the temp buffers.
- size_t result_index = 0;
+ size_t result_index_ = 0;
// [Optional] Arrays of arg and result names. These are arrays of C-style
// strings, where the array is terminated by nullptr.
- const char** arg_names = nullptr;
- const char** result_names = nullptr;
+ const char** arg_names_ = nullptr;
+ const char** result_names_ = nullptr;
// [Optional] Arg and result shapes.
- const xla::ProgramShape* program_shape = nullptr;
+ const xla::ProgramShape* program_shape_ = nullptr;
// [Optional] Profile printer data. Null if profiling is disabled.
- const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr;
+ const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
// [Optional] The number of profile counters expected in the profile counter
// buffer by the generated code and hlo_profile_printer. 0 if profiling is
// disabled. This information is already present in
// hlo_profile_printer_data but xla::HloProfilePrinterData is forward
// declared so we don't have access to that information here.
- int64 profile_counters_size = 0;
+ int64 profile_counters_size_ = 0;
+
+ // Only XlaCompiledCpuFunction is allowed to read the above fields.
+ friend class XlaCompiledCpuFunction;
};
// AllocMode controls the buffer allocation mode.
@@ -137,6 +177,9 @@ class XlaCompiledCpuFunction {
// Returns the underlying array of argument buffers, where args()[I] is the
// buffer for the positional argument at index I.
+ //
+ // TODO(sanjoy): We should retire this in favor of explicit accessors. That
+ // would let us elide the args_ array.
void** args() { return args_; }
const void* const* args() const { return args_; }
@@ -144,6 +187,18 @@ class XlaCompiledCpuFunction {
void* arg_data(size_t index) { return args_[index]; }
const void* arg_data(size_t index) const { return args_[index]; }
+ int num_args() const { return num_args_; }
+
+ // Returns the size of entry parameter `idx`.
+ //
+ // There is a static version of this method on tfcompile generated subclasses
+ // of XlaCompiledCpuFunction, but try to prefer this when possible since it
+ // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses.
+ int arg_size(int idx) const {
+ assert(idx < num_args());
+ return buffer_infos_[arg_index_table_[idx]].size();
+ }
+
// Sets the buffer for the positional argument at the given `index` to `data`.
// Must be called before Run to have an effect. May be called under any
// AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be
@@ -165,9 +220,9 @@ class XlaCompiledCpuFunction {
// Returns the underlying array of result buffers, where results()[I] is the
// buffer for the positional result at index I.
- void** results() { return static_cast<void**>(temps_[result_index_]); }
+ void** results() { return static_cast<void**>(buffer_table_[result_index_]); }
const void* const* results() const {
- return static_cast<const void* const*>(temps_[result_index_]);
+ return static_cast<const void* const*>(buffer_table_[result_index_]);
}
// Profile counters for this XLA computation.
@@ -225,25 +280,31 @@ class XlaCompiledCpuFunction {
const RawFunction raw_function_;
const size_t result_index_;
- // Arrays of argument and temp buffers; entries in args_ may be overwritten by
- // the user.
- void** args_ = nullptr;
- void** temps_ = nullptr;
+ // Array of argument buffers; entries in args_ may be overwritten by the user.
+ void** const args_;
+
+ // Array containing pointers to argument and temp buffers (slots corresponding
+ // to constant and on-stack buffers are null).
+ void** const buffer_table_;
+
+ // Describes the buffers used by the XLA computation.
+ const cpu_function_runtime::BufferInfo* const buffer_infos_;
- // Argument i needs to be placed in temps_[arg_index_to_temp_index_[i]] for
- // XLA generated code to be able to find it.
+  // Argument i needs to be placed in buffer_table_[arg_index_table_[i]]
+ // for XLA generated code to be able to find it.
//
// For now we need to keep around the args_ array because there is code that
// depends on args() returning a void**. However, in the future we may remove
- // args_ in favor of using temps_ as the sole storage for the arguments.
- int32* arg_index_to_temp_index_;
+ // args_ in favor of using buffer_table_ as the sole storage for the
+ // arguments.
+ const int32* const arg_index_table_;
// The number of incoming arguments.
- int32 num_args_;
+ const int32 num_args_;
- // Backing memory for individual arg and temp buffers.
- void* alloc_args_ = nullptr;
- void* alloc_temps_ = nullptr;
+ // Backing memory for buffer_table_ and args_, the latter depending on
+ // AllocMode.
+ void* alloc_buffer_table_ = nullptr;
// Backing memory for profiling counters.
int64* profile_counters_ = nullptr;
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 114a9241bd..86a78ee429 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -35,45 +36,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-
-// Returns a vector of positional argument buffer sizes.
-xla::StatusOr<std::vector<intptr_t>> ComputeArgSizes(
- const xla::ProgramShape& program_shape) {
- std::vector<intptr_t> arg_sizes;
- const size_t num_args = program_shape.parameters_size();
- arg_sizes.reserve(num_args);
- for (int i = 0; i < num_args; ++i) {
- const xla::Shape& arg_shape = program_shape.parameters(i);
- constexpr size_t kPointerSize = sizeof(void*);
- arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize));
- }
- return std::move(arg_sizes);
-}
-
-// Returns a vector of positional temporary buffer sizes.
-xla::StatusOr<std::vector<intptr_t>> ComputeTempSizes(
- const xla::BufferAssignment& buffer_assignment) {
- const std::vector<xla::BufferAllocation>& allocations =
- buffer_assignment.Allocations();
- std::vector<intptr_t> temp_sizes;
- temp_sizes.reserve(allocations.size());
- for (const xla::BufferAllocation& allocation : allocations) {
- if (allocation.is_constant() || allocation.is_thread_local()) {
- // Constants are lowered to globals. Thread locals are lowered to
- // allocas.
- temp_sizes.push_back(-1);
- } else if (allocation.is_entry_computation_parameter()) {
- // Entry computation parameters need some preprocessing in
- // XlaCompiledCpuFunction::Run. See the comment on
- // XlaCompiledCpuFunction::StaticData::temp_sizes.
- temp_sizes.push_back(-allocation.parameter_number() - 2);
- } else {
- temp_sizes.push_back(allocation.size());
- }
- }
- return std::move(temp_sizes);
-}
-
// Returns the index of the result in the temp buffers.
xla::StatusOr<size_t> ComputeResultIndex(
const xla::BufferAssignment& buffer_assignment) {
@@ -157,11 +119,11 @@ XlaJitCompiledCpuFunction::Compile(
const xla::BufferAssignment& buffer_assignment =
cpu_executable->buffer_assignment();
- // Compute buffer sizes and the result index, needed to run the raw function.
- TF_ASSIGN_OR_RETURN(std::vector<intptr_t> arg_sizes,
- ComputeArgSizes(*program_shape));
- TF_ASSIGN_OR_RETURN(std::vector<intptr_t> temp_sizes,
- ComputeTempSizes(buffer_assignment));
+ // Compute buffer infos and the result index, needed to run the raw function.
+ std::vector<cpu_function_runtime::BufferInfo> buffer_infos =
+ xla::cpu::CreateBufferInfosFromBufferAssignment(buffer_assignment);
+ std::vector<int32> arg_index_table =
+ xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
TF_ASSIGN_OR_RETURN(size_t result_index,
ComputeResultIndex(buffer_assignment));
@@ -169,28 +131,28 @@ XlaJitCompiledCpuFunction::Compile(
new XlaJitCompiledCpuFunction);
XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get();
jit->executable_ = std::move(executable);
- jit->arg_sizes_ = std::move(arg_sizes);
- jit->temp_sizes_ = std::move(temp_sizes);
+ jit->buffer_infos_ = std::move(buffer_infos);
+ jit->arg_index_table_ = std::move(arg_index_table);
jit->program_shape_ = std::move(program_shape);
- jit->static_data_.raw_function = std::move(raw_function);
- jit->static_data_.arg_sizes = jit->arg_sizes_.data();
- jit->static_data_.num_args = jit->arg_sizes_.size();
- jit->static_data_.temp_sizes = jit->temp_sizes_.data();
- jit->static_data_.num_temps = jit->temp_sizes_.size();
- jit->static_data_.result_index = result_index;
+ jit->static_data_.set_raw_function(raw_function);
+ jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
+ jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
+ jit->static_data_.set_arg_index_table(jit->arg_index_table_.data());
+ jit->static_data_.set_num_args(jit->arg_index_table_.size());
+ jit->static_data_.set_result_index(result_index);
// Optional metadata is collected and set below.
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_);
CollectNames(config.fetch(), &jit->nonempty_result_names_,
&jit->result_names_);
- jit->static_data_.arg_names = jit->arg_names_.data();
- jit->static_data_.result_names = jit->result_names_.data();
- jit->static_data_.program_shape = jit->program_shape_.get();
+ jit->static_data_.set_arg_names(jit->arg_names_.data());
+ jit->static_data_.set_result_names(jit->result_names_.data());
+ jit->static_data_.set_program_shape(jit->program_shape_.get());
if (cpu_executable->hlo_profiling_enabled()) {
- jit->static_data_.hlo_profile_printer_data =
- &cpu_executable->hlo_profile_printer_data();
- jit->static_data_.profile_counters_size =
- cpu_executable->hlo_profile_printer_data().profile_counters_size();
+ jit->static_data_.set_hlo_profile_printer_data(
+ &cpu_executable->hlo_profile_printer_data());
+ jit->static_data_.set_profile_counters_size(
+ cpu_executable->hlo_profile_printer_data().profile_counters_size());
}
return std::move(jit_unique_ptr);
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
index af307ae4ef..d3c8f22a80 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h
@@ -66,9 +66,11 @@ class XlaJitCompiledCpuFunction {
// The static data is backed by the rest of the state in this class.
XlaCompiledCpuFunction::StaticData static_data_;
- // The backing arrays of arg and temp buffer sizes.
- std::vector<intptr_t> arg_sizes_;
- std::vector<intptr_t> temp_sizes_;
+ // The backing array for buffer infos.
+ std::vector<cpu_function_runtime::BufferInfo> buffer_infos_;
+
+ // The backing array for the arg index table.
+ std::vector<int32> arg_index_table_;
// The backing arrays of arg and result names. We hold the actual strings in
// nonempty_*_names_, and hold arrays of pointers in *_names_ for the static
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
index 99abb9bae3..34f7fe12ca 100644
--- a/tensorflow/compiler/xla/service/compiler.h
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -48,11 +48,6 @@ namespace xla {
// computation.
using ObjectFileData = std::vector<char>;
-// Contains the buffer sizes information needed to allocate buffers to execute
-// an ahead-of-time computation. Entries which contain -1 designate a parameter
-// which should be skipped over during allocation.
-using BufferSizes = std::vector<int64>;
-
// Abstract superclass describing the result of an ahead-of-time compilation.
class AotCompilationResult {
public:
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 504b61d134..3efe3e2f93 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -55,11 +55,23 @@ cc_library(
)
cc_library(
+ name = "buffer_info_util",
+ srcs = ["buffer_info_util.cc"],
+ hdrs = ["buffer_info_util.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:cpu_function_runtime",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "cpu_compiler",
srcs = ["cpu_compiler.cc"],
hdrs = ["cpu_compiler.h"],
deps = [
":compiler_functor",
+ ":buffer_info_util",
":conv_canonicalization",
":cpu_copy_insertion",
":cpu_executable",
@@ -73,6 +85,7 @@ cc_library(
":ir_emitter",
":parallel_task_assignment",
":simple_orc_jit",
+ "//tensorflow/compiler/tf2xla:cpu_function_runtime",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
new file mode 100644
index 0000000000..408fe0f5bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
+
+namespace xla {
+namespace cpu {
+
+using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;
+
+std::vector<BufferInfo> CreateBufferInfosFromBufferAssignment(
+ const BufferAssignment& buffer_assignment) {
+ std::vector<BufferInfo> buffer_infos;
+ for (const BufferAllocation& allocation : buffer_assignment.Allocations()) {
+ if (allocation.is_thread_local()) {
+ buffer_infos.push_back(BufferInfo::MakeOnStackBuffer(allocation.size()));
+ } else if (allocation.is_constant()) {
+ buffer_infos.push_back(BufferInfo::MakeConstant(allocation.size()));
+ } else if (allocation.is_entry_computation_parameter()) {
+ buffer_infos.push_back(BufferInfo::MakeEntryParameter(
+ /*size=*/allocation.size(),
+ /*param_number=*/allocation.parameter_number()));
+ } else {
+ buffer_infos.push_back(BufferInfo::MakeTempBuffer(allocation.size()));
+ }
+ }
+ return buffer_infos;
+}
+
+std::vector<int32> CreateArgIndexTableFromBufferInfos(
+ tensorflow::gtl::ArraySlice<BufferInfo> buffer_infos) {
+ std::vector<int32> result;
+ for (int64 i = 0; i < buffer_infos.size(); i++) {
+ if (buffer_infos[i].is_entry_parameter()) {
+ if (buffer_infos[i].entry_parameter_number() >= result.size()) {
+ result.resize(buffer_infos[i].entry_parameter_number() + 1);
+ }
+ result[buffer_infos[i].entry_parameter_number()] = i;
+ }
+ }
+ return result;
+}
+
+} // namespace cpu
+} // namespace xla
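
As a worked example of the arg index table computed above (a hypothetical snippet with arbitrary sizes, not part of the commit):

  #include <vector>

  #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"

  using tensorflow::cpu_function_runtime::BufferInfo;

  std::vector<tensorflow::int32> ExampleArgIndexTable() {
    // Allocation order: temp, entry parameter 1, constant, entry parameter 0.
    std::vector<BufferInfo> infos = {
        BufferInfo::MakeTempBuffer(/*size=*/8),
        BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/1),
        BufferInfo::MakeConstant(/*size=*/16),
        BufferInfo::MakeEntryParameter(/*size=*/4, /*param_number=*/0),
    };
    // Returns {3, 1}: entry parameter 0 is backed by allocation 3, and
    // entry parameter 1 by allocation 1.
    return xla::cpu::CreateArgIndexTableFromBufferInfos(infos);
  }
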
diff --git a/tensorflow/compiler/xla/service/cpu/buffer_info_util.h b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
new file mode 100644
index 0000000000..05de70c726
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/buffer_info_util.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
+
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+namespace cpu {
+// Creates and returns a list of BufferInfo instances containing relevant
+// information from `buffer_assignment`.
+std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
+CreateBufferInfosFromBufferAssignment(
+ const BufferAssignment& buffer_assignment);
+
+// Creates and returns a table containing the mapping from entry computation
+// parameters to buffer allocation indices.
+//
+// If this function returns V then entry parameter i has buffer allocation index
+// V[i].
+std::vector<int32> CreateArgIndexTableFromBufferInfos(
+ tensorflow::gtl::ArraySlice<::tensorflow::cpu_function_runtime::BufferInfo>
+ buffer_infos);
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_BUFFER_INFO_UTIL_H_
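
A hedged sketch of how the two declarations above combine; the real consumers are elsewhere in this commit (the CpuCompiler change below builds its BufferInfos this way), the function name here is made up, and `assignment` is assumed to be an already-computed buffer assignment.

  #include <vector>

  #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"

  void SummarizeAssignment(const xla::BufferAssignment& assignment) {
    using tensorflow::cpu_function_runtime::BufferInfo;
    std::vector<BufferInfo> infos =
        xla::cpu::CreateBufferInfosFromBufferAssignment(assignment);
    std::vector<tensorflow::int32> arg_index_table =
        xla::cpu::CreateArgIndexTableFromBufferInfos(infos);
    // arg_index_table[i] is the index into `infos` of the allocation backing
    // entry computation parameter i.
    (void)arg_index_table;
  }
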
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 8cbe9a1b0d..2df959c4dc 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
+#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h"
@@ -103,6 +104,7 @@ limitations under the License.
namespace xla {
namespace cpu {
+using BufferInfo = ::tensorflow::cpu_function_runtime::BufferInfo;
CpuAotCompilationOptions::CpuAotCompilationOptions(
string triple, string cpu_name, string features, string entry_point_name,
@@ -120,11 +122,11 @@ se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
}
CpuAotCompilationResult::CpuAotCompilationResult(
- ObjectFileData object_file_data, BufferSizes buffer_sizes,
+ ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
int64 result_buffer_index,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
: object_file_data_(std::move(object_file_data)),
- buffer_sizes_(std::move(buffer_sizes)),
+ buffer_infos_(std::move(buffer_infos)),
result_buffer_index_(result_buffer_index),
hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}
@@ -838,39 +840,14 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
ObjectFileData object_file_data(object_file->getBufferStart(),
object_file->getBufferEnd());
- BufferSizes buffer_sizes;
- for (const BufferAllocation& allocation : assignment->Allocations()) {
- // Callers don't need to allocate anything for thread-local temporary
- // buffers. They are lowered to allocas.
- if (allocation.is_thread_local()) {
- buffer_sizes.push_back(-1);
- continue;
- }
-
- // Callers don't need to allocate anything for constant buffers. They are
- // lowered to globals.
- if (allocation.is_constant()) {
- buffer_sizes.push_back(-1);
- continue;
- }
-
- // Callers don't need to allocate anything for entry computation buffers,
- // but they do need to stash the pointer to the entry computation buffer
- // in the temp buffer table. See the comment on
- // XlaCompiledCpuFunction::StaticData::temp_sizes.
- if (allocation.is_entry_computation_parameter()) {
- buffer_sizes.push_back(-allocation.parameter_number() - 2);
- continue;
- }
-
- buffer_sizes.push_back(allocation.size());
- }
+ std::vector<BufferInfo> buffer_infos =
+ CreateBufferInfosFromBufferAssignment(*assignment);
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment->GetUniqueTopLevelOutputSlice());
results.emplace_back(MakeUnique<CpuAotCompilationResult>(
- std::move(object_file_data), std::move(buffer_sizes),
+ std::move(object_file_data), std::move(buffer_infos),
result_slice.index(), std::move(hlo_profile_printer_data)));
}
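
The loop deleted in this hunk encoded each allocation as a single integer. A tiny, purely illustrative decoder for that retired encoding (this helper does not exist in the tree) makes the mapping to the new BufferInfo factories explicit:

  #include <cstdint>
  #include <string>

  std::string DescribeLegacyBufferSize(int64_t encoded) {
    if (encoded == -1) {
      // Thread-local temporaries and constants: lowered to allocas/globals,
      // so callers allocated nothing.  BufferInfo now tells these two apart.
      return "skip (thread-local or constant)";
    }
    if (encoded < -1) {
      // Entry computation parameters were encoded as -(param_number + 2).
      return "entry parameter #" + std::to_string(-encoded - 2);
    }
    return "temp/result buffer of " + std::to_string(encoded) + " bytes";
  }

For instance, the AOT test further down in this diff updates its expectations from {-2, sizeof(float), -1} to checks on is_entry_parameter(), size(), and is_constant().
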
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
index e56f9f0113..04e1c48872 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include "llvm/Target/TargetMachine.h"
+#include "tensorflow/compiler/tf2xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
@@ -78,7 +79,8 @@ class CpuAotCompilationOptions : public AotCompilationOptions {
class CpuAotCompilationResult : public AotCompilationResult {
public:
CpuAotCompilationResult(
- ObjectFileData object_file_data, BufferSizes buffer_sizes,
+ ObjectFileData object_file_data,
+ std::vector<::tensorflow::cpu_function_runtime::BufferInfo> buffer_infos,
int64 result_buffer_index,
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
~CpuAotCompilationResult();
@@ -88,17 +90,20 @@ class CpuAotCompilationResult : public AotCompilationResult {
}
const ObjectFileData& object_file_data() const { return object_file_data_; }
- const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
+ const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>&
+ buffer_infos() const {
+ return buffer_infos_;
+ }
int64 result_buffer_index() const { return result_buffer_index_; }
private:
// Contains the compiled computation: an object file.
const ObjectFileData object_file_data_;
- // The list of buffer sizes which should be allocated in order to execute the
- // compiled computation. These buffers are used for temporary buffers used
- // ephemerally during computation as well as the output result.
- const BufferSizes buffer_sizes_;
+ // A list of BufferInfo objects describing the buffers used by the XLA
+ // computation.
+ const std::vector<::tensorflow::cpu_function_runtime::BufferInfo>
+ buffer_infos_;
// The index into |buffer_infos| designated to hold the
// result of the computation. This buffer should be passed into the output
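
A rough sketch of consuming the new buffer_infos() accessor from an AOT client; it is illustrative only, the helper below is hypothetical, and it deliberately uses just the members visible in this diff (size(), is_entry_parameter(), is_constant()).

  #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

  tensorflow::int64 CallerAllocatedBytes(
      const xla::cpu::CpuAotCompilationResult& result) {
    tensorflow::int64 total = 0;
    for (const auto& info : result.buffer_infos()) {
      // Entry parameters and constants need no caller-side allocation; the
      // remaining kinds (temps, the result, on-stack buffers) are lumped
      // together here for brevity.
      if (!info.is_entry_parameter() && !info.is_constant()) {
        total += info.size();
      }
    }
    return total;
  }
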
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index e310966d8b..60eb21aafd 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -92,10 +92,10 @@ int main(int argc, char** argv) {
// It's lame to hard-code the buffer assignments, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
CHECK_EQ(result->result_buffer_index(), 1);
- CHECK_EQ(result->buffer_sizes().size(), 3);
- CHECK_EQ(result->buffer_sizes()[0], -2); // param buffer
- CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer
- CHECK_EQ(result->buffer_sizes()[2], -1); // const buffer
+ CHECK_EQ(result->buffer_infos().size(), 3);
+ CHECK(result->buffer_infos()[0].is_entry_parameter()); // param buffer
+ CHECK_EQ(result->buffer_infos()[1].size(), sizeof(float)); // result buffer
+ CHECK(result->buffer_infos()[2].is_constant()); // const buffer
if (triple.isOSBinFormatELF()) {
// Check the ELF magic.
CHECK_EQ(result->object_file_data()[0], 0x7F);