diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_lrn_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_lrn_op.cc | 72 |
1 files changed, 52 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index edca8e2553..ac432e13ce 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -104,6 +104,15 @@ class MklLRNOp : public OpKernel { return; } + // TODO(inteltf) MKL will support depth radius not equal to 2 in the future + if (depth_radius_ != 2) { + Tensor converted_tensor = + ConvertMklToTF<T>(context, input, mkl_context.input_shape); + mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_, + beta_, converted_tensor); + return; + } + if (input_in_mkl_format) { // MKL supports normalization over channel dimension only if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) == @@ -112,8 +121,10 @@ class MklLRNOp : public OpKernel { static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout()); workspace_enabled_ = true; } else { + Tensor converted_tensor = + ConvertMklToTF<T>(context, input, mkl_context.input_shape); mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_, - beta_, input); + beta_, converted_tensor); return; } } @@ -160,9 +171,7 @@ class MklLRNOp : public OpKernel { MklShape input_shape; dnnPrimitive_t lrn_fwd = nullptr; dnnPrimitive_t convert_input = nullptr; - /* dnnPrimitive_t convert_output; */ dnnLayout_t lt_input = nullptr; - /* dnnLayout_t lt_output; */ dnnLayout_t lt_internal_input = nullptr; dnnLayout_t lt_internal_workspace = nullptr; dnnLayout_t lt_internal_output = nullptr; @@ -267,7 +276,7 @@ class MklLRNOp : public OpKernel { } // Fallback implementation - Taken from lrn_op.cc - // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a + // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a // copy. 
void MklDefaultToEigen(OpKernelContext* context, int depth_radius_, float bias_, float alpha_, float beta_, @@ -378,6 +387,7 @@ class MklLRNGradOp : public OpKernel { mkl_context.MklDefaultToEigen(context); return; } + if (ingrad_in_mkl_format || inimage_in_mkl_format) { const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format) ? &mkl_context.ingrad_shape @@ -459,11 +469,11 @@ class MklLRNGradOp : public OpKernel { const_cast<void*>(static_cast<const void*>(output->flat<T>().data())); Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor, - mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor; + mkl_tmp_outimage_buf_tensor; // Convert Inputs if needed - mkl_context.MklPrepareLRNGradInput( - context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor, - &mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor); + mkl_context.MklPrepareLRNGradInput(context, &mkl_tmp_input_buf_tensor, + &mkl_tmp_image_buf_tensor, + &mkl_tmp_outimage_buf_tensor); // We do not do any conversion for output. But we simply emit it // in MKL format. 
@@ -489,14 +499,11 @@ class MklLRNGradOp : public OpKernel { MklShape ingrad_shape, inimage_shape, outimage_shape; dnnPrimitive_t lrn_bwd = nullptr; dnnPrimitive_t convert_input = nullptr; - /* dnnPrimitive_t convert_output; */ dnnLayout_t lt_input = nullptr; dnnLayout_t lt_output = nullptr; dnnLayout_t lt_bdw_input = nullptr; dnnLayout_t lt_workspace = nullptr; dnnLayout_t lt_internal_input = nullptr; - /* dnnLayout_t lt_internal_workspace; - dnnLayout_t lt_internal_output; */ void* res_lrn_bwd[dnnResourceNumber]; // prepare mkl input @@ -523,11 +530,13 @@ class MklLRNGradOp : public OpKernel { void MklPrepareLRNGradInput(OpKernelContext* context, Tensor* mkl_tmp_input_buf_tensor, Tensor* mkl_tmp_image_buf_tensor, - Tensor* mkl_tmp_outimage_buf_tensor, - Tensor* mkl_tmp_workspace_buf_tensor) { + Tensor* mkl_tmp_outimage_buf_tensor) { const Tensor& in_grads = MklGetInput(context, 0); const Tensor& in_image = MklGetInput(context, 1); const Tensor& out_image = MklGetInput(context, 2); + const Tensor& workspace = MklGetInput( + context, + 3); /*Worskpsace is enabled, get the buffer to the workspace */ void* user_input = const_cast<void*>( static_cast<const void*>(in_grads.flat<T>().data())); @@ -535,6 +544,9 @@ class MklLRNGradOp : public OpKernel { static_cast<const void*>(in_image.flat<T>().data())); void* user_fwd_output = const_cast<void*>( static_cast<const void*>(out_image.flat<T>().data())); + void* workspace_buffer = const_cast<void*>( + static_cast<const void*>(workspace.flat<T>().data())); + CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, lrn_bwd, dnnResourceWorkspace), E_SUCCESS); @@ -609,9 +621,7 @@ class MklLRNGradOp : public OpKernel { res_lrn_bwd[dnnResourceDst] = user_fwd_output; } - // Allocate buffer for workspace. 
- AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace, - &res_lrn_bwd[dnnResourceWorkspace]); + res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer; } // Fallback implementation - Taken from lrn_op.cc @@ -619,14 +629,36 @@ class MklLRNGradOp : public OpKernel { // copy. void MklDefaultToEigen(OpKernelContext* context) { // CHECK(false); - Tensor in_grads = MklGetInput(context, 0); - Tensor in_image = MklGetInput(context, 1); - Tensor out_image = MklGetInput(context, 2); + + Tensor in_grads; + Tensor in_image; + Tensor out_image; GetMklShape(context, 0, &ingrad_shape); GetMklShape(context, 1, &inimage_shape); GetMklShape(context, 2, &outimage_shape); + if (ingrad_shape.IsMklTensor()) { + in_grads = + ConvertMklToTF<T>(context, MklGetInput(context, 0), ingrad_shape); + } else { + in_grads = MklGetInput(context, 0); + } + + if (inimage_shape.IsMklTensor()) { + in_image = + ConvertMklToTF<T>(context, MklGetInput(context, 1), inimage_shape); + } else { + in_image = MklGetInput(context, 1); + } + + if (outimage_shape.IsMklTensor()) { + out_image = + ConvertMklToTF<T>(context, MklGetInput(context, 2), outimage_shape); + } else { + out_image = MklGetInput(context, 2); + } + const int64 batch = static_cast<int64>(in_grads.dim_size(0)); const int64 rows = static_cast<int64>(in_grads.dim_size(1)); const int64 cols = static_cast<int64>(in_grads.dim_size(2)); @@ -677,7 +709,7 @@ class MklLRNGradOp : public OpKernel { Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch, depth * depth, shard); } - + // release mkl resources void Mklcleanup() { bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor(); |