aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-10 16:02:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 16:06:59 -0700
commitf51fc5f89a3fa934e078b35d50b26625a9ff42cf (patch)
treea2a63482390f20302eead8835a80e0cb24d6a734 /tensorflow/compiler/aot
parent8d532ac40f4db7f5293610fd3c6e92a3f7409b76 (diff)
Introduce and use a BufferInfo class.
The BufferInfo represents information about buffer assignment in XlaCompiledCpuFunction. Arg sizes and temp sizes are now derived from BufferInfo instead of being discrete sources of information. Also made StaticData() private, tfcompile clients should not need to access it directly. PiperOrigin-RevId: 208283305
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--tensorflow/compiler/aot/BUILD1
-rw-r--r--tensorflow/compiler/aot/codegen.cc169
-rw-r--r--tensorflow/compiler/aot/codegen_test.cc12
-rw-r--r--tensorflow/compiler/aot/codegen_test_h.golden58
-rw-r--r--tensorflow/compiler/aot/test.cc12
5 files changed, 160 insertions, 92 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index d2f803bd18..1899a32e4d 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -48,6 +48,7 @@ cc_library(
"//tensorflow/compiler/xla/client:compile_only_client",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service/cpu:buffer_info_util",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework_internal",
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index 8dbe1e11b7..89fefdad54 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/str_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -36,6 +37,8 @@ namespace tfcompile {
namespace {
+using BufferInfo = cpu_function_runtime::BufferInfo;
+
bool IsAlpha(char c) {
return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}
@@ -85,27 +88,36 @@ Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
return Status::OK();
}
-// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
-// values. There are `n` entries in `sizes`.
-size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
- size_t total = 0;
- for (size_t i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
- total += sizes[i];
- }
- }
- return total;
+// Returns the sum of the size of each buffer in `buffer_infos`.
+size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
+ return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
+ [](size_t size, const BufferInfo& buffer_info) {
+ return size + buffer_info.size();
+ });
}
-// Fills in arg_sizes with the byte size of each positional arg.
-Status ComputeArgSizes(const CompileResult& compile_result,
- std::vector<int64>* arg_sizes) {
- const xla::ProgramShape& ps = compile_result.program_shape;
- for (int i = 0; i < ps.parameters_size(); ++i) {
- arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
- ps.parameters(i), compile_result.pointer_size));
- }
- return Status::OK();
+// Returns a vector of BufferInfo instances in `buffer_infos` that are entry
+// parameter buffers.
+std::vector<BufferInfo> ExtractEntryParamBufferInfos(
+ const std::vector<BufferInfo>& buffer_infos) {
+ std::vector<BufferInfo> result;
+ std::copy_if(buffer_infos.begin(), buffer_infos.end(),
+ std::back_inserter(result), [](const BufferInfo& buffer_info) {
+ return buffer_info.is_entry_parameter();
+ });
+ return result;
+}
+
+// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
+// buffers.
+std::vector<BufferInfo> ExtractTempBufferInfos(
+ const std::vector<BufferInfo>& buffer_infos) {
+ std::vector<BufferInfo> result;
+ std::copy_if(buffer_infos.begin(), buffer_infos.end(),
+ std::back_inserter(result), [](const BufferInfo& buffer_info) {
+ return buffer_info.is_temp_buffer();
+ });
+ return result;
}
// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
@@ -278,6 +290,25 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
return Status::OK();
}
+// Returns a list of C++ expressions that, when executed, will construct the
+// BufferInfo instances in `buffer_infos`.
+std::vector<string> BufferInfosToCppExpression(
+ const std::vector<BufferInfo>& buffer_infos) {
+ std::vector<string> buffer_infos_as_strings;
+ std::transform(buffer_infos.begin(), buffer_infos.end(),
+ std::back_inserter(buffer_infos_as_strings),
+ [](const BufferInfo& buffer_info) {
+ std::pair<uint64, uint64> encoded = buffer_info.Encode();
+ string encoded_second_as_str =
+ encoded.second == ~0ULL
+ ? "~0ULL"
+ : strings::StrCat(encoded.second, "ULL");
+ return strings::StrCat(
+ "::tensorflow::cpu_function_runtime::BufferInfo({",
+ encoded.first, "ULL, ", encoded_second_as_str, "})");
+ });
+ return buffer_infos_as_strings;
+}
} // namespace
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
@@ -286,29 +317,35 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
TF_RETURN_IF_ERROR(ValidateConfig(config));
TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
const int64 result_index = compile_result.aot->result_buffer_index();
- const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
- if (result_index < 0 || result_index >= temp_sizes.size()) {
+ const std::vector<BufferInfo>& buffer_infos =
+ compile_result.aot->buffer_infos();
+ const std::vector<int32> arg_index_table =
+ ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
+ std::vector<string> buffer_infos_as_strings =
+ BufferInfosToCppExpression(buffer_infos);
+ if (result_index < 0 || result_index >= buffer_infos.size()) {
return errors::InvalidArgument("result index: ", result_index,
" is outside the range of temp sizes: [0,",
- temp_sizes.size(), ")");
+ buffer_infos.size(), ")");
}
// Compute sizes and generate methods.
- std::vector<int64> arg_sizes;
- TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
+ std::vector<BufferInfo> buffer_infos_for_args =
+ ExtractEntryParamBufferInfos(buffer_infos);
+ std::vector<BufferInfo> buffer_infos_for_temps =
+ ExtractTempBufferInfos(buffer_infos);
const xla::ProgramShape& ps = compile_result.program_shape;
string methods_arg, methods_result;
TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
- const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
- const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
- const size_t arg_bytes_aligned =
- cpu_function_runtime::AlignedBufferBytes(iarg.data(), iarg.size());
- const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
- const size_t temp_bytes_aligned =
- cpu_function_runtime::AlignedBufferBytes(itemp.data(), itemp.size());
- const size_t temp_bytes_total =
- total_buffer_bytes(itemp.data(), itemp.size());
+ const size_t arg_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
+ buffer_infos_for_args.data(), buffer_infos_for_args.size(),
+ /*allocate_entry_params=*/true);
+ const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
+ const size_t temp_bytes_aligned = cpu_function_runtime::AlignedBufferBytes(
+ buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
+ /*allocate_entry_params=*/true);
+ const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);
// Create rewrite strings for namespace start and end.
string ns_start;
@@ -343,8 +380,8 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
// calling HloProfilePrinter::profile_counters_size.
const string assign_profile_counters_size =
opts.gen_hlo_profile_printer_data
- ? "data->profile_counters_size = "
- "data->hlo_profile_printer_data->profile_counters_size();"
+ ? "data->set_profile_counters_size("
+ "data->hlo_profile_printer_data()->profile_counters_size());"
: "";
// Use a poor-man's text templating mechanism; first populate the full header
@@ -414,9 +451,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = {{ARG_NUM}};
// Byte size of each argument buffer. There are kNumArgs entries.
- static const intptr_t* ArgSizes() {
- static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
- return kArgSizes;
+ static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
+ return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
}
// Returns static data used to create an XlaCompiledCpuFunction.
@@ -424,17 +460,17 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData;
- data->raw_function = {{ENTRY}};
- data->arg_sizes = ArgSizes();
- data->num_args = kNumArgs;
- data->temp_sizes = TempSizes();
- data->num_temps = kNumTemps;
- data->result_index = kResultIndex;
- data->arg_names = StaticArgNames();
- data->result_names = StaticResultNames();
- data->program_shape = StaticProgramShape();
- data->hlo_profile_printer_data = StaticHloProfilePrinterData();
- {{ASSIGN_PROFILE_COUNTERS_SIZE}}
+ data->set_raw_function({{ENTRY}});
+ data->set_buffer_infos(BufferInfos());
+ data->set_num_buffers(kNumBuffers);
+ data->set_arg_index_table(ArgIndexToBufferIndex());
+ data->set_num_args(kNumArgs);
+ data->set_result_index(kResultIndex);
+ data->set_arg_names(StaticArgNames());
+ data->set_result_names(StaticResultNames());
+ data->set_program_shape(StaticProgramShape());
+ data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
+{{ASSIGN_PROFILE_COUNTERS_SIZE}}
return data;
}();
return *kStaticData;
@@ -482,17 +518,27 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{{METHODS_RESULT}}
private:
- // Number of result and temporary buffers for the compiled computation.
- static constexpr size_t kNumTemps = {{TEMP_NUM}};
- // The 0-based index of the result tuple in the temporary buffers.
- static constexpr size_t kResultIndex = {{RESULT_INDEX}};
+ // Number of buffers for the compiled computation.
+ static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};
- // Byte size of each result / temporary buffer. There are kNumTemps entries.
- static const intptr_t* TempSizes() {
- static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
- return kTempSizes;
+ static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
+ static const ::tensorflow::cpu_function_runtime::BufferInfo
+ kBufferInfos[kNumBuffers] = {
+{{BUFFER_INFOS_AS_STRING}}
+ };
+ return kBufferInfos;
}
+ static const ::tensorflow::int32* ArgIndexToBufferIndex() {
+ static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
+{{ARG_INDEX_TABLE}}
+ };
+ return kArgIndexToBufferIndex;
+ }
+
+ // The 0-based index of the result tuple in the temporary buffers.
+ static constexpr size_t kResultIndex = {{RESULT_INDEX}};
+
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {{ARG_NAMES_CODE}}
@@ -523,8 +569,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
{"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
{"{{ARG_NAMES_CODE}}", arg_names_code},
- {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
- {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
+ {"{{ARG_NUM}}", strings::StrCat(arg_index_table.size())},
+ {"{{ARG_INDEX_TABLE}}", str_util::Join(arg_index_table, ", ")},
{"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
{"{{CLASS}}", opts.class_name},
{"{{DECLS_FROM_OBJ_FILE}}",
@@ -546,8 +592,9 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{RESULT_NAMES_CODE}}", result_names_code},
{"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
{"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
- {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
- {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}};
+ {"{{NUM_BUFFERS}}", strings::StrCat(buffer_infos.size())},
+ {"{{BUFFER_INFOS_AS_STRING}}",
+ str_util::Join(buffer_infos_as_strings, ",\n")}};
str_util::ReplaceAllPairs(header, rewrites);
return Status::OK();
}
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
index 29bc9c13b8..60d59ae996 100644
--- a/tensorflow/compiler/aot/codegen_test.cc
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -32,6 +32,8 @@ namespace tensorflow {
namespace tfcompile {
namespace {
+using ::tensorflow::cpu_function_runtime::BufferInfo;
+
void ExpectErrorContains(const Status& status, StringPiece str) {
EXPECT_NE(Status::OK(), status);
EXPECT_TRUE(str_util::StrContains(status.error_message(), str))
@@ -171,8 +173,14 @@ TEST(CodegenTest, Golden) {
fetch->mutable_id()->set_node_name("fetch0");
fetch->set_name("myfetch");
CompileResult compile_result;
- compile_result.aot.reset(
- new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5, {}));
+ compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult(
+ {},
+ {BufferInfo::MakeTempBuffer(1),
+ BufferInfo::MakeEntryParameter(/*size=*/8, /*param_number=*/0),
+ BufferInfo::MakeTempBuffer(2),
+ BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
+ BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
+ 5, {}));
compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
{
xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index 6641d45e83..e4d8a02877 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -65,9 +65,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static constexpr size_t kNumArgs = 2;
// Byte size of each argument buffer. There are kNumArgs entries.
- static const intptr_t* ArgSizes() {
- static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96};
- return kArgSizes;
+ static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
+ return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
}
// Returns static data used to create an XlaCompiledCpuFunction.
@@ -75,17 +74,17 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
XlaCompiledCpuFunction::StaticData* data =
new XlaCompiledCpuFunction::StaticData;
- data->raw_function = entry_point;
- data->arg_sizes = ArgSizes();
- data->num_args = kNumArgs;
- data->temp_sizes = TempSizes();
- data->num_temps = kNumTemps;
- data->result_index = kResultIndex;
- data->arg_names = StaticArgNames();
- data->result_names = StaticResultNames();
- data->program_shape = StaticProgramShape();
- data->hlo_profile_printer_data = StaticHloProfilePrinterData();
-
+ data->set_raw_function(entry_point);
+ data->set_buffer_infos(BufferInfos());
+ data->set_num_buffers(kNumBuffers);
+ data->set_arg_index_table(ArgIndexToBufferIndex());
+ data->set_num_args(kNumArgs);
+ data->set_result_index(kResultIndex);
+ data->set_arg_names(StaticArgNames());
+ data->set_result_names(StaticResultNames());
+ data->set_program_shape(StaticProgramShape());
+ data->set_hlo_profile_printer_data(StaticHloProfilePrinterData());
+
return data;
}();
return *kStaticData;
@@ -215,17 +214,32 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
}
private:
- // Number of result and temporary buffers for the compiled computation.
- static constexpr size_t kNumTemps = 6;
- // The 0-based index of the result tuple in the temporary buffers.
- static constexpr size_t kResultIndex = 5;
+ // Number of buffers for the compiled computation.
+ static constexpr size_t kNumBuffers = 6;
+
+ static const ::tensorflow::cpu_function_runtime::BufferInfo* BufferInfos() {
+ static const ::tensorflow::cpu_function_runtime::BufferInfo
+ kBufferInfos[kNumBuffers] = {
+::tensorflow::cpu_function_runtime::BufferInfo({5ULL, ~0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({34ULL, 0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({9ULL, ~0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({386ULL, 1ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({13ULL, ~0ULL}),
+::tensorflow::cpu_function_runtime::BufferInfo({481ULL, ~0ULL})
+ };
+ return kBufferInfos;
+ }
- // Byte size of each result / temporary buffer. There are kNumTemps entries.
- static const intptr_t* TempSizes() {
- static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
- return kTempSizes;
+ static const ::tensorflow::int32* ArgIndexToBufferIndex() {
+ static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
+1, 3
+ };
+ return kArgIndexToBufferIndex;
}
+ // The 0-based index of the result tuple in the temporary buffers.
+ static constexpr size_t kResultIndex = 5;
+
// Array of names of each positional argument, terminated by nullptr.
static const char** StaticArgNames() {
static const char* kNames[] = {"myfeed", nullptr};
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
index 6b098049cb..df966767b3 100644
--- a/tensorflow/compiler/aot/test.cc
+++ b/tensorflow/compiler/aot/test.cc
@@ -51,11 +51,9 @@ namespace tensorflow {
namespace tfcompile {
namespace {
-void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
- for (int i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
- memset(bufs[i], 0, sizes[i]);
- }
+void zero_buffers(void** bufs, const XlaCompiledCpuFunction& computation) {
+ for (int i = 0; i < computation.num_args(); ++i) {
+ memset(bufs[i], 0, computation.arg_size(i));
}
}
@@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+ zero_buffers(computation.args(), computation);
EXPECT_TRUE(computation.Run());
}
@@ -80,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+ zero_buffers(computation.args(), computation);
testing::StartTiming();
while (--iters) {