diff options
Diffstat (limited to 'tensorflow/core/util/mkl_util_test.cc')
-rw-r--r-- | tensorflow/core/util/mkl_util_test.cc | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/tensorflow/core/util/mkl_util_test.cc b/tensorflow/core/util/mkl_util_test.cc new file mode 100644 index 0000000000..6aef3d86e9 --- /dev/null +++ b/tensorflow/core/util/mkl_util_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL + +#include "tensorflow/core/util/mkl_util.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +#ifdef INTEL_MKL_DNN + +TEST(MklUtilTest, MklDnnTfShape) { + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData<float> a(&cpu_engine); + + const int N = 1, C = 2, H = 3, W = 4; + memory::dims a_dims = {N, C, H, W}; + MklDnnShape a_mkldnn_shape; + a_mkldnn_shape.SetMklTensor(true); + // Create TF layout in NCHW. + a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, memory::format::nchw); + TensorShape a_tf_shape_nchw({N, C, H, W}); + TensorShape a_tf_shape_nhwc({N, H, W, C}); + TensorShape a_mkldnn_tf_shape = a_mkldnn_shape.GetTfShape(); + // Check that returned shape is in NCHW format. + EXPECT_EQ(a_tf_shape_nchw, a_mkldnn_tf_shape); + EXPECT_NE(a_tf_shape_nhwc, a_mkldnn_tf_shape); + + memory::dims b_dims = {N, C, H, W}; + MklDnnShape b_mkldnn_shape; + b_mkldnn_shape.SetMklTensor(true); + // Create TF layout in NHWC. + b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, memory::format::nhwc); + TensorShape b_tf_shape_nhwc({N, H, W, C}); + TensorShape b_tf_shape_nchw({N, C, H, W}); + TensorShape b_mkldnn_tf_shape = b_mkldnn_shape.GetTfShape(); + // Check that returned shape is in NHWC format. + EXPECT_EQ(b_tf_shape_nhwc, b_mkldnn_tf_shape); + EXPECT_NE(b_tf_shape_nchw, b_mkldnn_tf_shape); +} + + +TEST(MklUtilTest, MklDnnBlockedFormatTest) { + // Let's create 2D tensor of shape {3, 4} with 3 being innermost dimension + // first (case 1) and then it being outermost dimension (case 2). + auto cpu_engine = engine(engine::cpu, 0); + + // Setting for case 1 + MklDnnData<float> a(&cpu_engine); + memory::dims dim1 = {3, 4}; + memory::dims strides1 = {1, 3}; + a.SetUsrMem(dim1, strides1); + + memory::desc a_md1 = a.GetUsrMemDesc(); + EXPECT_EQ(a_md1.data.ndims, 2); + EXPECT_EQ(a_md1.data.dims[0], 3); + EXPECT_EQ(a_md1.data.dims[1], 4); + EXPECT_EQ(a_md1.data.format, mkldnn_blocked); + + // Setting for case 2 + MklDnnData<float> b(&cpu_engine); + memory::dims dim2 = {3, 4}; + memory::dims strides2 = {4, 1}; + b.SetUsrMem(dim2, strides2); + + memory::desc b_md2 = b.GetUsrMemDesc(); + EXPECT_EQ(b_md2.data.ndims, 2); + EXPECT_EQ(b_md2.data.dims[0], 3); + EXPECT_EQ(b_md2.data.dims[1], 4); + EXPECT_EQ(b_md2.data.format, mkldnn_blocked); +} + +#endif // INTEL_MKL_DNN +} // namespace +} // namespace tensorflow + +#endif // INTEL_MKL |