diff options
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 38 |
1 files changed, 37 insertions, 1 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 6a37256ea9..67468bdc3f 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -23,7 +23,7 @@ limitations under the License.
 #include "third_party/mkl/include/mkl_dnn.h"
 #include "third_party/mkl/include/mkl_dnn_types.h"
 #include "third_party/mkl/include/mkl_service.h"
-
+#include "third_party/mkl/include/mkl_trans.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -616,6 +616,42 @@ inline void ForwarMklTensorInToOut(OpKernelContext* context,
   }
 }
 
+// TODO(intel_tf): Drop this helper once a faster MKL layout conversion lands.
+// Copies float data from `input` (read as NHWC) into `*output` in NCHW order.
+inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
+  const float* src = input.flat<float>().data();
+  float* dst = (*output)->flat<float>().data();
+  const int64 batch = input.dim_size(0);
+  const int64 plane = input.dim_size(1) * input.dim_size(2);  // H * W
+  const int64 channels = input.dim_size(3);
+  const int64 image_elems = plane * channels;
+  // Each image is a row-major (H*W) x C matrix; its transpose is C x (H*W).
+#pragma omp parallel for num_threads(16)
+  for (int64 img = 0; img < batch; ++img) {
+    const int64 offset = img * image_elems;
+    mkl_somatcopy('R', 'T', plane, channels, 1, src + offset, channels,
+                  dst + offset, plane);
+  }
+}
+
+// TODO(intel_tf): Drop this helper once a faster MKL layout conversion lands.
+// Copies float data from `input` (read as NCHW) into `*output` in NHWC order.
+inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
+  const float* src = input.flat<float>().data();
+  float* dst = (*output)->flat<float>().data();
+  const int64 batch = (*output)->dim_size(0);
+  const int64 plane = (*output)->dim_size(1) * (*output)->dim_size(2);
+  const int64 channels = (*output)->dim_size(3);
+  const int64 image_elems = plane * channels;
+  // Sizes come from *output; each source image is row-major C x (H*W).
+#pragma omp parallel for num_threads(16)
+  for (int64 img = 0; img < batch; ++img) {
+    const int64 offset = img * image_elems;
+    mkl_somatcopy('R', 'T', channels, plane, 1, src + offset, plane,
+                  dst + offset, channels);
+  }
+}
+
 namespace mkl_op_registry {
 static const char* kMklOpLabel = "MklOp";
 static const char* kMklOpLabelPattern = "label='MklOp'";