about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/xsmm_conv2d_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xsmm_conv2d_test.cc')
-rw-r--r-- tensorflow/core/kernels/xsmm_conv2d_test.cc | 15
1 file changed, 8 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/xsmm_conv2d_test.cc b/tensorflow/core/kernels/xsmm_conv2d_test.cc
index f4ab6896ae..381ea39b77 100644
--- a/tensorflow/core/kernels/xsmm_conv2d_test.cc
+++ b/tensorflow/core/kernels/xsmm_conv2d_test.cc
@@ -188,6 +188,8 @@ class XsmmConv2DTest : public OpsTestBase {
TEST_F(XsmmConv2DTest, Basic) {
MakeOp(1);
+ // setup scoped allocator, which uses cpu_allocator() for this scope
+ const libxsmm_tf_allocator<libxsmm_scratch_allocator> tf_allocator;
int ifw = 14; /* input width, "W" */
int ifh = 14; /* input height, "H" */
@@ -223,9 +225,9 @@ TEST_F(XsmmConv2DTest, Basic) {
//Initialization of Filter and Image
/* allocate data */
- float *naive_input = (float*)libxsmm_aligned_malloc( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
- float *naive_output = (float*)libxsmm_aligned_malloc( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
- float *naive_filter = (float*)libxsmm_aligned_malloc( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
+ float *naive_input = (float*)libxsmm_aligned_scratch( nImg*nIfm*ifhp*ifwp*sizeof(float), 2097152);
+ float *naive_output = (float*)libxsmm_aligned_scratch( nImg*nOfm*ofhp*ofwp*sizeof(float), 2097152);
+ float *naive_filter = (float*)libxsmm_aligned_scratch( nOfm*nIfm*kh*kw* sizeof(float), 2097152);
/* initialize data */
init_buf(naive_input, nImg*nIfm*ifhp*ifwp, 0, 0);
zero_buf(naive_output, nImg*nOfm*ofhp*ofwp);
@@ -322,12 +324,11 @@ TEST(XsmmConv2DTest, Basic) {
desc.pad_w_out = 0;
desc.threads = num_threads;
desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
- desc.buffer_format = LIBXSMM_DNN_CONV_FORMAT_NHWC;
- desc.filter_format = LIBXSMM_DNN_CONV_FORMAT_LIBXSMM;//LIBXSMM_DNN_CONV_FORMAT_RSCK;
+ desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
+ desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;//LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
- desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
- desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
+ desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
if (!CanUseXsmmConv2D(desc, data_format)) {
return false;