diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_tfconv_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_tfconv_op.cc | 54 |
1 files changed, 43 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc index 588d6874dd..b4aae67ca6 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.cc +++ b/tensorflow/core/kernels/mkl_tfconv_op.cc @@ -24,12 +24,13 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/mkl_util.h" #include "third_party/mkl/include/mkl_dnn.h" #include "third_party/mkl/include/mkl_dnn_types.h" -#include "tensorflow/core/util/mkl_util.h" namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -44,10 +45,11 @@ class MklToTfOp : public OpKernel { explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type)); + has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F); } void Compute(OpKernelContext* context) override { - // 1. Check that input tensor is in MKL format. + // Check that input tensor is in MKL format. const Tensor& input_tensor = MklGetInput(context, 0); MklShape input_shape; GetMklShape(context, 0, &input_shape); @@ -68,9 +70,12 @@ class MklToTfOp : public OpKernel { CHECK_EQ(op_data_type, output_data_type); TensorShape output_shape; - for (size_t i = 0; i < input_shape.GetDimension(); i++) { + size_t ndims = input_shape.GetDimension(); + size_t* in_sizes = new size_t[ndims]; + for (size_t i = 0; i < ndims; i++) { // Outermost to innermost dimension output_shape.AddDim(input_shape.GetSizes()[input_shape.tf_dim_idx(i)]); + in_sizes[i] = input_shape.GetSizes()[i]; } // Allocate output tensor. @@ -78,17 +83,41 @@ class MklToTfOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor)); - // 3. Get input and output layout pointers. - dnnLayout_t output_layout = - static_cast<dnnLayout_t>(input_shape.GetTfLayout()); + // If data format is NHWC, transform MKL tensor to NCHW format and then + // do NCHW -> NHWC. + dnnLayout_t lt_trans_input = nullptr; + Tensor mkl_tmp_trans_input_buf_tensor; + void* buf_trans_input = nullptr; + bool input_fmt_nhwc = input_shape.IsTensorInNHWCFormat(); + if (input_fmt_nhwc && ndims == 4 && has_avx512f_) { + size_t strides_nchw[4]; + GetStridesFromSizes(FORMAT_NCHW, strides_nchw, in_sizes); + CHECK_EQ( + dnnLayoutCreate_F32(<_trans_input, ndims, in_sizes, strides_nchw), + E_SUCCESS); + AllocTmpBuffer(context, &mkl_tmp_trans_input_buf_tensor, lt_trans_input, + &buf_trans_input); + } else { + lt_trans_input = static_cast<dnnLayout_t>(input_shape.GetTfLayout()); + buf_trans_input = + static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data())); + } - // 4. Execute DNNConversion. + // Execute DNNConversion. void* input_buffer = static_cast<void*>(const_cast<T*>(input_tensor.flat<T>().data())); - void* output_buffer = - static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data())); - input_shape.GetConvertedFlatData(output_layout, input_buffer, - output_buffer); + input_shape.GetConvertedFlatData(lt_trans_input, input_buffer, + buf_trans_input); + // NCHW -> NHWC, if data format is NHWC + if (input_fmt_nhwc && ndims == 4 && has_avx512f_) { + dnnLayoutDelete_F32(lt_trans_input); + TensorShape nhwc_shape = ShapeFromFormat( + FORMAT_NHWC, in_sizes[MklDims::N], in_sizes[MklDims::H], + in_sizes[MklDims::W], in_sizes[MklDims::C]); + MklNCHWToNHWC(mkl_tmp_trans_input_buf_tensor, &output_tensor); + } + + delete[] in_sizes; VLOG(1) << "MKLToTFConversion complete successfully."; } @@ -99,6 +128,9 @@ class MklToTfOp : public OpKernel { /// Data type of the operation DataType op_data_type; + + /// CPUIDInfo + bool has_avx512f_ = false; }; /////////////////////////////////////////////////////////// |