aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot/test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/aot/test.cc')
-rw-r--r--tensorflow/compiler/aot/test.cc12
1 files changed, 5 insertions, 7 deletions
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
index 6b098049cb..df966767b3 100644
--- a/tensorflow/compiler/aot/test.cc
+++ b/tensorflow/compiler/aot/test.cc
@@ -51,11 +51,9 @@ namespace tensorflow {
namespace tfcompile {
namespace {
-void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
- for (int i = 0; i < n; ++i) {
- if (sizes[i] != -1) {
- memset(bufs[i], 0, sizes[i]);
- }
+void zero_buffers(void** bufs, const XlaCompiledCpuFunction& computation) {
+ for (int i = 0; i < computation.num_args(); ++i) {
+ memset(bufs[i], 0, computation.arg_size(i));
}
}
@@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+ zero_buffers(computation.args(), computation);
EXPECT_TRUE(computation.Run());
}
@@ -80,7 +78,7 @@ void BM_NAME(int iters) {
CPP_CLASS computation;
computation.set_thread_pool(&device);
- zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+ zero_buffers(computation.args(), computation);
testing::StartTiming();
while (--iters) {