diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_tfconv_op.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_tfconv_op.h | 80 |
1 file changed, 75 insertions(+), 5 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h index a240ee44fb..0a5be4fec9 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.h +++ b/tensorflow/core/kernels/mkl_tfconv_op.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL - #ifndef TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ #define TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ +#ifdef INTEL_MKL + #include <algorithm> #include <vector> #include "tensorflow/core/framework/numeric_op.h" @@ -35,6 +35,10 @@ limitations under the License. #include "mkl_dnn_types.h" #include "tensorflow/core/util/mkl_util.h" +#ifdef INTEL_MKL_DNN +using mkldnn::stream; +#endif + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -57,6 +61,71 @@ class MklToTfOp : public OpKernel { VLOG(1) << "MKLToTFConversion complete successfully."; } +#ifdef INTEL_MKL_DNN + static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context, + string data_format_str, DataType op_data_type, + bool has_avx512f, uint input_number) { + try { + // Check that input tensor is in MKL format. + const Tensor& input_tensor = MklGetInput(context, input_number); + MklDnnShape input_shape; + GetMklShape(context, input_number, &input_shape); + + // if input is already in Tf format, then copy input tensor to output. + if (!input_shape.IsMklTensor()) { + context->set_output(input_number, input_tensor); + VLOG(1) << "MKLToTFConversion: No conversion needed, " + << "copying input to output"; + return; + } + + // Check that input data type is same as operator data type and that it + // is same as output data type. 
+ DataType input_data_type = op_kernel->input_type(input_number); + DataType output_data_type = op_kernel->output_type(input_number); + CHECK_EQ(op_data_type, input_data_type); + CHECK_EQ(op_data_type, output_data_type); + + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData<T> input(&cpu_engine); + + // Get Mkl layout of input tensor. + auto input_mkl_md = input_shape.GetMklLayout(); + // Get TensorFlow layout of input tensor. Expected output of conversion + // has same layout as Tensorflow layout of input tensor. + auto output_tf_md = input_shape.GetTfLayout(); + auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine); + // Set input Mkl layout as the user layout. + input.SetUsrMem(input_mkl_md, &input_tensor); + + // Allocate output tensor. + TensorShape output_shape = input_shape.GetTfShape(); + Tensor* output_tensor = NULL; + OP_REQUIRES_OK(context, context->allocate_output(input_number, + output_shape, &output_tensor)); + CHECK_NOTNULL(output_tensor); + + // Do we need to reorder Mkl layout into TensorFlow layout? + if (input.IsReorderNeeded(output_tf_pd)) { + // Insert reorder between Mkl layout and TensorFlow layout. + std::vector<primitive> net; + CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, output_tensor, &net), + true); + stream(stream::kind::eager).submit(net).wait(); + } else { + // If not, just forward input tensor to output tensor. 
+ CHECK(output_tensor->CopyFrom(input_tensor, output_shape)); + } + } catch (mkldnn::error &e) { + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + std::string(e.message) + + ", in file " + std::string(__FILE__) + ":" + + std::to_string(__LINE__); + OP_REQUIRES_OK(context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } +#else static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context, string data_format_str, DataType op_data_type, bool has_avx512f, uint input_number) { @@ -91,8 +160,8 @@ class MklToTfOp : public OpKernel { // Allocate output tensor. Tensor* output_tensor = NULL; - OP_REQUIRES_OK(context, - context->allocate_output(input_number, output_shape, &output_tensor)); + OP_REQUIRES_OK(context, context->allocate_output(input_number, + output_shape, &output_tensor)); dnnLayout_t output_layout = static_cast<dnnLayout_t>(input_shape.GetTfLayout()); @@ -106,6 +175,7 @@ class MklToTfOp : public OpKernel { output_buffer); VLOG(1) << "MKLToTFConversion complete successfully."; } +#endif private: /// Data format of the operation @@ -132,5 +202,5 @@ class MklToTfOp : public OpKernel { TF_CALL_NUMBER_TYPES(REGISTER_CPU); #undef REGISTER_CPU } // namespace tensorflow -#endif // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ #endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_TFCONV_OP_H_ |