aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-08-10 21:00:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 21:04:38 -0700
commit84ee0e2a2554e9b9ccfbaf0db1e2db62dd52d8cc (patch)
treede487c8640470f5e7f4d7c9d95a93537465ee016 /tensorflow/compiler/aot
parent349f9e65fb1c50f57dad53920eb95999fec6b8c2 (diff)
Remove XlaCompiledCpuFunction::args()
This lets us remove XlaCompiledCpuFunction::args_ and some awkwardness from XlaCompiledCpuFunction::Run. PiperOrigin-RevId: 208309249
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--tensorflow/compiler/aot/test.cc10
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc66
2 files changed, 38 insertions, 38 deletions
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
index df966767b3..5deb47d123 100644
--- a/tensorflow/compiler/aot/test.cc
+++ b/tensorflow/compiler/aot/test.cc
@@ -51,9 +51,9 @@ namespace tensorflow {
namespace tfcompile {
namespace {
-void zero_buffers(void** bufs, const XlaCompiledCpuFunction& computation) {
- for (int i = 0; i < computation.num_args(); ++i) {
- memset(bufs[i], 0, computation.arg_size(i));
+void zero_buffers(XlaCompiledCpuFunction* computation) {
+ for (int i = 0; i < computation->num_args(); ++i) {
+ memset(computation->arg_data(i), 0, computation->arg_size(i));
}
}
@@ -64,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), computation);
+ zero_buffers(&computation);
EXPECT_TRUE(computation.Run());
}
@@ -78,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), computation);
+ zero_buffers(&computation);
testing::StartTiming();
while (--iters) {
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index fee46280e9..0c0c676ece 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -44,8 +44,8 @@ using ::testing::IsSupersetOf;
TEST(TFCompileTest, Add) {
AddComp add;
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
- EXPECT_EQ(add.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
+ EXPECT_EQ(add.arg1_data(), add.arg_data(1));
add.arg0() = 1;
add.arg1() = 2;
@@ -67,10 +67,10 @@ TEST(TFCompileTest, Add) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 123);
EXPECT_EQ(add_const.arg0_data()[0], 123);
- EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add_const.arg0_data(), add.arg_data(0));
EXPECT_EQ(add_const.arg1(), 456);
EXPECT_EQ(add_const.arg1_data()[0], 456);
- EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add_const.arg1_data(), add.arg_data(1));
EXPECT_EQ(add_const.result0(), 579);
EXPECT_EQ(add_const.result0_data()[0], 579);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@@ -85,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) {
int32 arg_y = 32;
add.set_arg0_data(&arg_x);
add.set_arg1_data(&arg_y);
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
- EXPECT_EQ(add.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
+ EXPECT_EQ(add.arg1_data(), add.arg_data(1));
EXPECT_TRUE(add.Run());
EXPECT_EQ(add.error_msg(), "");
@@ -97,7 +97,7 @@ TEST(TFCompileTest, Add_SetArg) {
TEST(TFCompileTest, AddWithCkpt) {
AddWithCkptComp add;
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1;
EXPECT_TRUE(add.Run());
@@ -117,7 +117,7 @@ TEST(TFCompileTest, AddWithCkpt) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
- EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
+ EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@@ -125,7 +125,7 @@ TEST(TFCompileTest, AddWithCkpt) {
TEST(TFCompileTest, AddWithCkptSaver) {
AddWithCkptSaverComp add;
- EXPECT_EQ(add.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add.arg0_data(), add.arg_data(0));
add.arg0() = 1;
EXPECT_TRUE(add.Run());
@@ -145,7 +145,7 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.error_msg(), "");
EXPECT_EQ(add_const.arg0(), 111);
EXPECT_EQ(add_const.arg0_data()[0], 111);
- EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
+ EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0));
EXPECT_EQ(add_const.result0(), 153);
EXPECT_EQ(add_const.result0_data()[0], 153);
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
@@ -153,9 +153,9 @@ TEST(TFCompileTest, AddWithCkptSaver) {
TEST(TFCompileTest, Cond) {
CondComp cond;
- EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
- EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
- EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
+ EXPECT_EQ(cond.arg0_data(), cond.arg_data(0));
+ EXPECT_EQ(cond.arg1_data(), cond.arg_data(1));
+ EXPECT_EQ(cond.arg2_data(), cond.arg_data(2));
cond.arg1() = 10;
cond.arg2() = 20;
{
@@ -178,8 +178,8 @@ TEST(TFCompileTest, Cond) {
TEST(TFCompileTest, Gather) {
GatherComp gather;
- EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
- EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
+ EXPECT_EQ(gather.arg0_data(), gather.arg_data(0));
+ EXPECT_EQ(gather.arg1_data(), gather.arg_data(1));
// Successful gather.
{
@@ -202,12 +202,12 @@ TEST(TFCompileTest, Gather) {
EXPECT_EQ(gather_const.arg0(i), params[i]);
EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
}
- EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
+ EXPECT_EQ(gather_const.arg0_data(), gather_const.arg_data(0));
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.arg1(i), indices[i]);
EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
}
- EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
+ EXPECT_EQ(gather_const.arg1_data(), gather_const.arg_data(1));
for (int i = 0; i < 2; ++i) {
EXPECT_EQ(gather_const.result0(i), results[i]);
EXPECT_EQ(gather_const.result0_data()[i], results[i]);
@@ -222,8 +222,8 @@ TEST(TFCompileTest, MatMul2) {
foo::bar::MatMulComp matmul;
matmul.set_thread_pool(&device);
- EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
- EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
+ EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
+ EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
// Test using the argN() methods.
{
@@ -271,12 +271,12 @@ TEST(TFCompileTest, MatMul2) {
EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
}
- EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
+ EXPECT_EQ(matmul_const.arg0_data(), matmul.arg_data(0));
for (int i = 0; i < 6; ++i) {
EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
}
- EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
+ EXPECT_EQ(matmul_const.arg1_data(), matmul.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
@@ -300,8 +300,8 @@ TEST(TFCompileTest, MatMul2_SetArg) {
float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
matmul.set_arg0_data(&arg0);
matmul.set_arg1_data(&arg1);
- EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
- EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
+ EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0));
+ EXPECT_EQ(matmul.arg1_data(), matmul.arg_data(1));
EXPECT_TRUE(matmul.Run());
EXPECT_EQ(matmul.error_msg(), "");
@@ -319,8 +319,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
MatMulAndAddComp muladd;
muladd.set_thread_pool(&device);
- EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
- EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
+ EXPECT_EQ(muladd.arg0_data(), muladd.arg_data(0));
+ EXPECT_EQ(muladd.arg1_data(), muladd.arg_data(1));
// Test methods with positional args and results.
{
@@ -346,12 +346,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
}
- EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
+ EXPECT_EQ(muladd_const.arg0_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
}
- EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
+ EXPECT_EQ(muladd_const.arg1_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
@@ -387,12 +387,12 @@ TEST(TFCompileTest, MatMulAndAdd1) {
EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
}
- EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
+ EXPECT_EQ(muladd_const.arg_x_data(), muladd.arg_data(0));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
}
- EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
+ EXPECT_EQ(muladd_const.arg_y_data(), muladd.arg_data(1));
for (int i = 0; i < 4; ++i) {
EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
@@ -407,8 +407,8 @@ TEST(TFCompileTest, MatMulAndAdd1) {
TEST(TFCompileTest, Function) {
// The function is equivalent to an addition
FunctionComp add_fn;
- EXPECT_EQ(add_fn.arg0_data(), add_fn.args()[0]);
- EXPECT_EQ(add_fn.arg1_data(), add_fn.args()[1]);
+ EXPECT_EQ(add_fn.arg0_data(), add_fn.arg_data(0));
+ EXPECT_EQ(add_fn.arg1_data(), add_fn.arg_data(1));
add_fn.arg0() = 1;
add_fn.arg1() = 2;
@@ -451,8 +451,8 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
AssertComp assert;
- EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
- EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
+ EXPECT_EQ(assert.arg0_data(), assert.arg_data(0));
+ EXPECT_EQ(assert.arg1_data(), assert.arg_data(1));
assert.arg0() = 2;
assert.arg1() = 1;