aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-10 21:00:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 21:04:38 -0700
commit84ee0e2a2554e9b9ccfbaf0db1e2db62dd52d8cc (patch)
treede487c8640470f5e7f4d7c9d95a93537465ee016
parent349f9e65fb1c50f57dad53920eb95999fec6b8c2 (diff)
Remove XlaCompiledCpuFunction::args()
This lets us remove XlaCompiledCpuFunction::args_ and some awkwardness from XlaCompiledCpuFunction::Run. PiperOrigin-RevId: 208309249
-rw-r--r--tensorflow/compiler/aot/test.cc10
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc66
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc13
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h23
4 files changed, 47 insertions, 65 deletions
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
index df966767b3..5deb47d123 100644
--- a/tensorflow/compiler/aot/test.cc
+++ b/tensorflow/compiler/aot/test.cc
@@ -51,9 +51,9 @@ namespace tensorflow {
namespace tfcompile {
namespace {
-void zero_buffers(void** bufs, const XlaCompiledCpuFunction& computation) {
- for (int i = 0; i < computation.num_args(); ++i) {
- memset(bufs[i], 0, computation.arg_size(i));
+void zero_buffers(XlaCompiledCpuFunction* computation) {
+ for (int i = 0; i < computation->num_args(); ++i) {
+ memset(computation->arg_data(i), 0, computation->arg_size(i));
}
}
@@ -64,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), computation);
+ zero_buffers(&computation);
EXPECT_TRUE(computation.Run());
}
@@ -78,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), computation);
+ zero_buffers(&computation);
testing::StartTiming();
while (--iters) {
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index fee46280e9..0c0c676ece 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -44,8 +44,8 @@ using ::testing::IsSupersetOf;
TEST(TFCompileTest, Add) {
AddComp add;
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
- EXPECT_EQ(add.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
+ EXPECT_EQ(add.arg1_data(), add.arg_data(1));
add.arg0() = 1;
add.arg1() = 2;
@@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 123);
EXPECT_EQ(add_const.arg0_data()[0], 123);
- EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add_const.arg0_data(), add.arg_data(0));
EXPECT_EQ(add_const.arg1(), 456);
EXPECT_EQ(add_const.arg1_data()[0], 456);
- EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add_const.arg1_data(), add.arg_data(1));
EXPECT_EQ(add_const.result0(), 579);
EXPECT_EQ(add_const.result0_data()[0], 579);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) {
int32 arg_y = 32;
add.set_arg0_data(&arg_x);
add.set_arg1_data(&arg_y);
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
- EXPECT_EQ(add.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
+ EXPECT_EQ(add.arg1_data(), add.arg_data(1));
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
@@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) {
TEST(TFCompileTest, AddWithCkpt) {
AddWithCkptComp add;
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1;
EXPECT_TRUE(add.Run());
@@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
- EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
+ EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) {
TEST(TFCompileTest, AddWithCkptSaver) {
AddWithCkptSaverComp add;
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1;
EXPECT_TRUE(add.Run());
@@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
- EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
+ EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) {
TEST(TFCompileTest, Cond) {
CondComp cond;
- EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
- EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
- EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
+ EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
+ EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
+ EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
cond.arg1() = 10;
cond.arg2() = 20;
{
@@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) {
TEST(TFCompileTest, Gather) {
GatherComp gather;
- EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
- EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
+ EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
+ EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));
// Successful gather.
{
@@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.arg0(i), params[i]);
EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
}
- EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
+ EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.arg1(i), indices[i]);
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
}
- EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
+ EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.result0(i), results[i]);
EXPECT_EQ(gather_const.result0_data()[i], results[i]);
@@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) {
foo::bar::MatMulComp matmul;
matmul.set_thread_pool(&device);
- EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
- EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
+ EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
+ EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
// Test using the argN() methods.
{
@@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) {
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
}
- EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
+ EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0));
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
}
- EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
+ EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
@@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) {
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
matmul.set_arg0_data(&arg0);
matmul.set_arg1_data(&arg1);
- EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
- EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
+ EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
+ EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), "");
@@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
MatMulAndAddComp muladd;
muladd.set_thread_pool(&device);
- EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
- EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
+ EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
+ EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));
// Test methods with positional args and results.
{
@@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
}
- EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
+ EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
}
- EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
+ EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
@@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
}
- EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
+ EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
}
- EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
+ EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
@@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
- EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
- EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
+ EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0));
+ EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1));
add_fn.arg0() = 1;
add_fn.arg1() = 2;
@@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
AssertComp assert;
- EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
- EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
+ EXPECT_EQ(assert.arg0_data(), assert.arg_data(0));
+ EXPECT_EQ(assert.arg1_data(), assert.arg_data(1));
assert.arg0() = 2;
assert.arg1() = 1;
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index 09c5d1dd19..1f0f240135 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -23,7 +23,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
AllocMode alloc_mode)
: raw_function_(static_data.raw_function_),
result_index_(static_data.result_index_),
- args_(new void*[static_data.num_args_]),
buffer_table_(new void*[static_data.num_buffers_]),
buffer_infos_(static_data.buffer_infos_),
arg_index_table_(static_data.arg_index_table_),
@@ -39,11 +38,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
static_data.buffer_infos_, static_data.num_buffers_,
/*allocate_entry_params=*/allocate_entry_params, buffer_table_,
/*annotate_initialized=*/true);
- if (allocate_entry_params) {
- for (int32 i = 0; i < num_args_; i++) {
- args_[i] = buffer_table_[arg_index_table_[i]];
- }
- }
// If Hlo profiling is enabled the generated code expects an appropriately
// sized buffer to be passed in as the last argument. If Hlo profiling is
// disabled the last function argument is still present in the function
@@ -55,12 +49,6 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
}
bool XlaCompiledCpuFunction::Run() {
- // Propagate pointers to the argument buffers into the buffer table. Code
- // generated by XLA discovers the incoming argument pointers from the buffer
- // table.
- for (int32 i = 0; i < num_args_; i++) {
- buffer_table_[arg_index_table_[i]] = args_[i];
- }
raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
buffer_table_, profile_counters_);
return true;
@@ -68,7 +56,6 @@ bool XlaCompiledCpuFunction::Run() {
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
- delete[] args_;
delete[] buffer_table_;
delete[] profile_counters_;
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index 7dd8c24eb7..425e769346 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -175,17 +175,13 @@ class XlaCompiledCpuFunction {
// ------------------------------
// Arg methods for managing input buffers. Buffers are in row-major order.
- // Returns the underlying array of argument buffers, where args()[I] is the
- // buffer for the positional argument at index I.
- //
- // TODO(sanjoy): We should retire this in favor of explicit accessors. That
- // would let us elide the args_ array.
- void** args() { return args_; }
- const void* const* args() const { return args_; }
-
// Returns the buffer for the positional argument at the given `index`.
- void* arg_data(size_t index) { return args_[index]; }
- const void* arg_data(size_t index) const { return args_[index]; }
+ void* arg_data(size_t index) {
+ return buffer_table_[arg_index_table_[index]];
+ }
+ const void* arg_data(size_t index) const {
+ return buffer_table_[arg_index_table_[index]];
+ }
int num_args() const { return num_args_; }
@@ -210,7 +206,9 @@ class XlaCompiledCpuFunction {
//
// Aliasing of argument and result buffers is not allowed, and results in
// undefined behavior.
- void set_arg_data(size_t index, void* data) { args_[index] = data; }
+ void set_arg_data(size_t index, void* data) {
+ buffer_table_[arg_index_table_[index]] = data;
+ }
// ------------------------------
// Result methods for managing output buffers. Buffers are in row-major order.
@@ -280,9 +278,6 @@ class XlaCompiledCpuFunction {
const RawFunction raw_function_;
const size_t result_index_;
- // Array of argument buffers; entries in args_ may be overwritten by the user.
- void** const args_;
-
// Array containing pointers to argument and temp buffers (slots corresponding
// to constant and on-stack buffers are null).
void** const buffer_table_;