diff options
Diffstat (limited to 'tensorflow/compiler/aot/test.cc')
-rw-r--r-- | tensorflow/compiler/aot/test.cc | 12 |
1 files changed, 5 insertions, 7 deletions
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc index 6b098049cb..df966767b3 100644 --- a/tensorflow/compiler/aot/test.cc +++ b/tensorflow/compiler/aot/test.cc @@ -51,11 +51,9 @@ namespace tensorflow { namespace tfcompile { namespace { -void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) { - for (int i = 0; i < n; ++i) { - if (sizes[i] != -1) { - memset(bufs[i], 0, sizes[i]); - } +void zero_buffers(void** bufs, const XlaCompiledCpuFunction& computation) { + for (int i = 0; i < computation.num_args(); ++i) { + memset(bufs[i], 0, computation.arg_size(i)); } } @@ -66,7 +64,7 @@ TEST(TEST_NAME, NoCrash) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(computation.args(), computation); EXPECT_TRUE(computation.Run()); } @@ -80,7 +78,7 @@ void BM_NAME(int iters) { CPP_CLASS computation; computation.set_thread_pool(&device); - zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs); + zero_buffers(computation.args(), computation); testing::StartTiming(); while (--iters) { |