aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_tfconv_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_tfconv_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.cc54
1 files changed, 43 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.cc b/tensorflow/core/kernels/mkl_tfconv_op.cc
index 588d6874dd..b4aae67ca6 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.cc
+++ b/tensorflow/core/kernels/mkl_tfconv_op.cc
@@ -24,12 +24,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/mkl_util.h"
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -44,10 +45,11 @@ class MklToTfOp : public OpKernel {
explicit MklToTfOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
+ has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
}
void Compute(OpKernelContext* context) override {
- // 1. Check that input tensor is in MKL format.
+ // Check that input tensor is in MKL format.
const Tensor& input_tensor = MklGetInput(context, 0);
MklShape input_shape;
GetMklShape(context, 0, &input_shape);
@@ -68,9 +70,12 @@ class MklToTfOp : public OpKernel {
CHECK_EQ(op_data_type, output_data_type);
TensorShape output_shape;
- for (size_t i = 0; i < input_shape.GetDimension(); i++) {
+ size_t ndims = input_shape.GetDimension();
+ size_t* in_sizes = new size_t[ndims];
+ for (size_t i = 0; i < ndims; i++) {
// Outermost to innermost dimension
output_shape.AddDim(input_shape.GetSizes()[input_shape.tf_dim_idx(i)]);
+ in_sizes[i] = input_shape.GetSizes()[i];
}
// Allocate output tensor.
@@ -78,17 +83,41 @@ class MklToTfOp : public OpKernel {
OP_REQUIRES_OK(context,
context->allocate_output(0, output_shape, &output_tensor));
- // 3. Get input and output layout pointers.
- dnnLayout_t output_layout =
- static_cast<dnnLayout_t>(input_shape.GetTfLayout());
+ // If data format is NHWC, transform MKL tensor to NCHW format and then
+ // do NCHW -> NHWC.
+ dnnLayout_t lt_trans_input = nullptr;
+ Tensor mkl_tmp_trans_input_buf_tensor;
+ void* buf_trans_input = nullptr;
+ bool input_fmt_nhwc = input_shape.IsTensorInNHWCFormat();
+ if (input_fmt_nhwc && ndims == 4 && has_avx512f_) {
+ size_t strides_nchw[4];
+ GetStridesFromSizes(FORMAT_NCHW, strides_nchw, in_sizes);
+ CHECK_EQ(
+ dnnLayoutCreate_F32(&lt_trans_input, ndims, in_sizes, strides_nchw),
+ E_SUCCESS);
+ AllocTmpBuffer(context, &mkl_tmp_trans_input_buf_tensor, lt_trans_input,
+ &buf_trans_input);
+ } else {
+ lt_trans_input = static_cast<dnnLayout_t>(input_shape.GetTfLayout());
+ buf_trans_input =
+ static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
+ }
- // 4. Execute DNNConversion.
+ // Execute DNNConversion.
void* input_buffer =
static_cast<void*>(const_cast<T*>(input_tensor.flat<T>().data()));
- void* output_buffer =
- static_cast<void*>(const_cast<T*>(output_tensor->flat<T>().data()));
- input_shape.GetConvertedFlatData(output_layout, input_buffer,
- output_buffer);
+ input_shape.GetConvertedFlatData(lt_trans_input, input_buffer,
+ buf_trans_input);
+ // NCHW -> NHWC, if data format is NHWC
+ if (input_fmt_nhwc && ndims == 4 && has_avx512f_) {
+ dnnLayoutDelete_F32(lt_trans_input);
+ TensorShape nhwc_shape = ShapeFromFormat(
+ FORMAT_NHWC, in_sizes[MklDims::N], in_sizes[MklDims::H],
+ in_sizes[MklDims::W], in_sizes[MklDims::C]);
+ MklNCHWToNHWC(mkl_tmp_trans_input_buf_tensor, &output_tensor);
+ }
+
+ delete[] in_sizes;
VLOG(1) << "MKLToTFConversion complete successfully.";
}
@@ -99,6 +128,9 @@ class MklToTfOp : public OpKernel {
/// Data type of the operation
DataType op_data_type;
+
+ /// CPUIDInfo
+ bool has_avx512f_ = false;
};
///////////////////////////////////////////////////////////