Diffstat (limited to 'tensorflow/core/kernels/mkl_lrn_op.cc')
-rw-r--r--  tensorflow/core/kernels/mkl_lrn_op.cc  |  72
1 file changed, 52 insertions(+), 20 deletions(-)
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index edca8e2553..ac432e13ce 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -104,6 +104,15 @@ class MklLRNOp : public OpKernel {
return;
}
+ // TODO(inteltf) MKL will support depth radius not equal to 2 in the future
+ if (depth_radius_ != 2) {
+ Tensor converted_tensor =
+ ConvertMklToTF<T>(context, input, mkl_context.input_shape);
+ mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
+ beta_, converted_tensor);
+ return;
+ }
+
if (input_in_mkl_format) {
// MKL supports normalization over channel dimension only
if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
@@ -112,8 +121,10 @@ class MklLRNOp : public OpKernel {
static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
workspace_enabled_ = true;
} else {
+ Tensor converted_tensor =
+ ConvertMklToTF<T>(context, input, mkl_context.input_shape);
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
- beta_, input);
+ beta_, converted_tensor);
return;
}
}
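For context on what the Eigen fallback above computes: local response normalization divides each activation by a power of a windowed sum of squares over the channel dimension. The standalone sketch below is illustrative only; the function name, flat std::vector layout, and NHWC indexing are assumptions made for this example, not TensorFlow code. The real fallback is MklDefaultToEigen, taken from lrn_op.cc.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference LRN over the channel (last) dimension of NHWC data:
//   out[d] = in[d] / (bias + alpha * sum_{k in window(d)} in[k]^2)^beta
std::vector<float> LrnReference(const std::vector<float>& in, int64_t batch,
                                int64_t rows, int64_t cols, int64_t depth,
                                int depth_radius, float bias, float alpha,
                                float beta) {
  std::vector<float> out(in.size());
  const int64_t spatial = batch * rows * cols;
  for (int64_t p = 0; p < spatial; ++p) {
    const float* x = &in[p * depth];
    float* y = &out[p * depth];
    for (int64_t d = 0; d < depth; ++d) {
      // Window [d - depth_radius, d + depth_radius], clipped to the tensor.
      const int64_t lo = std::max<int64_t>(0, d - depth_radius);
      const int64_t hi = std::min<int64_t>(depth - 1, d + depth_radius);
      float sqr_sum = 0.0f;
      for (int64_t k = lo; k <= hi; ++k) sqr_sum += x[k] * x[k];
      y[d] = x[d] * std::pow(bias + alpha * sqr_sum, -beta);
    }
  }
  return out;
}

Per the TODO in the hunk above, MKL currently handles only depth_radius == 2, so any other radius (and any normalization that is not over the channel dimension) takes this scalar path via MklDefaultToEigen.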
@@ -160,9 +171,7 @@ class MklLRNOp : public OpKernel {
MklShape input_shape;
dnnPrimitive_t lrn_fwd = nullptr;
dnnPrimitive_t convert_input = nullptr;
- /* dnnPrimitive_t convert_output; */
dnnLayout_t lt_input = nullptr;
- /* dnnLayout_t lt_output; */
dnnLayout_t lt_internal_input = nullptr;
dnnLayout_t lt_internal_workspace = nullptr;
dnnLayout_t lt_internal_output = nullptr;
@@ -267,7 +276,7 @@ class MklLRNOp : public OpKernel {
}
// Fallback implementation - Taken from lrn_op.cc
- // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+ // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
// copy.
void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
float bias_, float alpha_, float beta_,
@@ -378,6 +387,7 @@ class MklLRNGradOp : public OpKernel {
mkl_context.MklDefaultToEigen(context);
return;
}
+
if (ingrad_in_mkl_format || inimage_in_mkl_format) {
const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
? &mkl_context.ingrad_shape
@@ -459,11 +469,11 @@ class MklLRNGradOp : public OpKernel {
const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor,
- mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor;
+ mkl_tmp_outimage_buf_tensor;
// Convert Inputs if needed
- mkl_context.MklPrepareLRNGradInput(
- context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor,
- &mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor);
+ mkl_context.MklPrepareLRNGradInput(context, &mkl_tmp_input_buf_tensor,
+ &mkl_tmp_image_buf_tensor,
+ &mkl_tmp_outimage_buf_tensor);
// We do not do any conversion for output. But we simply emit it
// in MKL format.
@@ -489,14 +499,11 @@ class MklLRNGradOp : public OpKernel {
MklShape ingrad_shape, inimage_shape, outimage_shape;
dnnPrimitive_t lrn_bwd = nullptr;
dnnPrimitive_t convert_input = nullptr;
- /* dnnPrimitive_t convert_output; */
dnnLayout_t lt_input = nullptr;
dnnLayout_t lt_output = nullptr;
dnnLayout_t lt_bdw_input = nullptr;
dnnLayout_t lt_workspace = nullptr;
dnnLayout_t lt_internal_input = nullptr;
- /* dnnLayout_t lt_internal_workspace;
- dnnLayout_t lt_internal_output; */
void* res_lrn_bwd[dnnResourceNumber];
// prepare mkl input
@@ -523,11 +530,13 @@ class MklLRNGradOp : public OpKernel {
void MklPrepareLRNGradInput(OpKernelContext* context,
Tensor* mkl_tmp_input_buf_tensor,
Tensor* mkl_tmp_image_buf_tensor,
- Tensor* mkl_tmp_outimage_buf_tensor,
- Tensor* mkl_tmp_workspace_buf_tensor) {
+ Tensor* mkl_tmp_outimage_buf_tensor) {
const Tensor& in_grads = MklGetInput(context, 0);
const Tensor& in_image = MklGetInput(context, 1);
const Tensor& out_image = MklGetInput(context, 2);
+ const Tensor& workspace = MklGetInput(
+     context,
+     3); /* Workspace is enabled; get the workspace buffer produced by the forward op. */
void* user_input = const_cast<void*>(
static_cast<const void*>(in_grads.flat<T>().data()));
@@ -535,6 +544,9 @@ class MklLRNGradOp : public OpKernel {
static_cast<const void*>(in_image.flat<T>().data()));
void* user_fwd_output = const_cast<void*>(
static_cast<const void*>(out_image.flat<T>().data()));
+ void* workspace_buffer = const_cast<void*>(
+ static_cast<const void*>(workspace.flat<T>().data()));
+
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, lrn_bwd,
dnnResourceWorkspace),
E_SUCCESS);
@@ -609,9 +621,7 @@ class MklLRNGradOp : public OpKernel {
res_lrn_bwd[dnnResourceDst] = user_fwd_output;
}
- // Allocate buffer for workspace.
- AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace,
- &res_lrn_bwd[dnnResourceWorkspace]);
+ res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer;
}
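The change just above also moves where the backward pass gets its LRN workspace: instead of allocating a scratch buffer (the removed AllocTmpBuffer call), MklPrepareLRNGradInput now binds the workspace tensor the forward op emitted, fetched as kernel input 3. A minimal sketch of that wiring, using placeholder types and enum values rather than the real TensorFlow/MKL declarations:

#include <array>
#include <cstddef>

// Illustrative stand-ins for the MKL resource slots addressed through
// res_lrn_bwd[]; these are not the real dnnResource* constants.
enum Resource : std::size_t { kDiffDst, kSrc, kDst, kWorkspace, kNumResources };

struct LrnBackwardResources {
  std::array<void*, kNumResources> res{};  // mirrors res_lrn_bwd[]
};

// Before this change: a temporary buffer was allocated and handed to MKL as
// the workspace. After: the opaque workspace buffer produced by the forward
// LRN op (kernel input 3) is bound directly.
void BindForwardWorkspace(LrnBackwardResources* bwd, void* fwd_workspace) {
  bwd->res[kWorkspace] = fwd_workspace;
}

The apparent benefit is one fewer temporary allocation per backward call, with the backward primitive consuming the same intermediate values the forward primitive actually computed.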
// Fallback implementation - Taken from lrn_op.cc
@@ -619,14 +629,36 @@ class MklLRNGradOp : public OpKernel {
// copy.
void MklDefaultToEigen(OpKernelContext* context) {
// CHECK(false);
- Tensor in_grads = MklGetInput(context, 0);
- Tensor in_image = MklGetInput(context, 1);
- Tensor out_image = MklGetInput(context, 2);
+
+ Tensor in_grads;
+ Tensor in_image;
+ Tensor out_image;
GetMklShape(context, 0, &ingrad_shape);
GetMklShape(context, 1, &inimage_shape);
GetMklShape(context, 2, &outimage_shape);
+ if (ingrad_shape.IsMklTensor()) {
+ in_grads =
+ ConvertMklToTF<T>(context, MklGetInput(context, 0), ingrad_shape);
+ } else {
+ in_grads = MklGetInput(context, 0);
+ }
+
+ if (inimage_shape.IsMklTensor()) {
+ in_image =
+ ConvertMklToTF<T>(context, MklGetInput(context, 1), inimage_shape);
+ } else {
+ in_image = MklGetInput(context, 1);
+ }
+
+ if (outimage_shape.IsMklTensor()) {
+ out_image =
+ ConvertMklToTF<T>(context, MklGetInput(context, 2), outimage_shape);
+ } else {
+ out_image = MklGetInput(context, 2);
+ }
+
const int64 batch = static_cast<int64>(in_grads.dim_size(0));
const int64 rows = static_cast<int64>(in_grads.dim_size(1));
const int64 cols = static_cast<int64>(in_grads.dim_size(2));
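The three branches added above all follow the same pattern: check IsMklTensor(), convert with ConvertMklToTF when the input carries an MKL layout, otherwise use the input as-is. The helper below is a hypothetical refactoring sketch, not something this patch introduces; it reuses only calls already visible in this diff (MklGetInput, MklShape::IsMklTensor, ConvertMklToTF) and assumes the headers this file already includes.

// Hypothetical helper: return kernel input `index` in plain TensorFlow
// layout, converting only when the tensor carries an MKL layout.
template <typename T>
Tensor GetInputInTfFormat(OpKernelContext* context, int index,
                          const MklShape& shape) {
  const Tensor& raw = MklGetInput(context, index);
  return shape.IsMklTensor() ? ConvertMklToTF<T>(context, raw, shape) : raw;
}

With such a helper, each branch would collapse to a single line, e.g. in_grads = GetInputInTfFormat<T>(context, 0, ingrad_shape);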
@@ -677,7 +709,7 @@ class MklLRNGradOp : public OpKernel {
Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
depth * depth, shard);
}
-
+
// release mkl resources
void Mklcleanup() {
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();