about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/mkl_lrn_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_lrn_op.cc')
-rw-r--r-- tensorflow/core/kernels/mkl_lrn_op.cc | 112
1 file changed, 47 insertions, 65 deletions
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 070aeff49f..07a7e6b5da 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -22,9 +22,6 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <vector>
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -33,6 +30,9 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
@@ -66,11 +66,10 @@ class MklLRNOp : public OpKernel {
explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
int64 depth_radius64;
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
- OP_REQUIRES(
- context,
- FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
- errors::InvalidArgument("depth_radius = ", depth_radius64,
- " larger than int max"));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
depth_radius_ = static_cast<size_t>(depth_radius64);
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
@@ -93,10 +92,9 @@ class MklLRNOp : public OpKernel {
: input.dims();
OP_REQUIRES(context, mkl_context.in_dims == 4,
errors::InvalidArgument("input must be 4-dimensional"));
- OP_REQUIRES(
- context,
- FastBoundsCheck(input.NumElements(), std::numeric_limits<int>::max()),
- errors::InvalidArgument("argument to LRN too large"));
+ OP_REQUIRES(context, FastBoundsCheck(input.NumElements(),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
if (!input_in_mkl_format) {
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
@@ -104,15 +102,6 @@ class MklLRNOp : public OpKernel {
return;
}
- // TODO(inteltf) MKL will support depth radius not equal to 2 in the future
- if (depth_radius_ != 2) {
- Tensor converted_tensor =
- ConvertMklToTF<T>(context, input, mkl_context.input_shape);
- mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
- beta_, converted_tensor);
- return;
- }
-
if (input_in_mkl_format) {
// MKL supports normalization over channel dimension only
if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
@@ -345,11 +334,10 @@ class MklLRNGradOp : public OpKernel {
explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
int64 depth_radius64;
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
- OP_REQUIRES(
- context,
- FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
- errors::InvalidArgument("depth_radius = ", depth_radius64,
- " larger than int max"));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
depth_radius_ = static_cast<int>(depth_radius64);
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
@@ -553,6 +541,9 @@ class MklLRNGradOp : public OpKernel {
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_bdw_input, lrn_bwd,
dnnResourceDiffDst),
E_SUCCESS);
+ CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_input, lrn_bwd,
+ dnnResourceSrc),
+ E_SUCCESS);
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
if (ingrad_in_mkl_format) {
@@ -581,44 +572,37 @@ class MklLRNGradOp : public OpKernel {
}
}
-// Although MKL documentation for LRN does not specify setting/getting
-// of dnnResourceSrc and dnnResourceDst, Caffe code sets dnnResourceSrc.
-// So we set dnnResourceSrc here. But we do not know why we are setting
-// dnnResourceDst.
-#if 0
- // NOTE: The code below is kept just so that we know how we should handle
- // dnnResourceSrc if the primitive layout for dnnResourceSrc was supported.
-
- if (!dnnLayoutCompare_F32(lt_internal_input,
- static_cast<dnnLayout_t>inimage_shape.GetCurLayout())) {
- AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
- &res_lrn_bwd[dnnResourceSrc]);
- inimage_shape.GetConvertedFlatData(lt_internal_input,
- user_fwd_input,
- res_lrn_bwd[dnnResourceSrc]);
- } else {
- res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
- }
-#endif
-
- // Since we cannot get expected layout for dnnResourceSrc, we construct
- // buffer using
- // MKL format if input is in MKL format.
- if (inimage_shape.IsMklTensor()) {
- AllocTmpBuffer(context, mkl_tmp_image_buf_tensor,
- (dnnLayout_t)inimage_shape.GetCurLayout(),
- &res_lrn_bwd[dnnResourceSrc]);
+ bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
+ if (inimage_in_mkl_format) {
+ if (!dnnLayoutCompare_F32(
+ lt_internal_input,
+ static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()))) {
+ AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
+ &res_lrn_bwd[dnnResourceSrc]);
+ ingrad_shape.GetConvertedFlatData(lt_internal_input, user_fwd_input,
+ res_lrn_bwd[dnnResourceSrc]);
+ } else {
+ res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+ }
} else {
- res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
- }
+ if (!dnnLayoutCompare_F32(
+ lt_internal_input,
+ static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()))) {
+ CHECK_EQ(dnnConversionCreate_F32(
+ &convert_input,
+ static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()),
+ lt_internal_input),
+ E_SUCCESS);
- // Same comment as above.
- if (outimage_shape.IsMklTensor()) {
- AllocTmpBuffer(context, mkl_tmp_outimage_buf_tensor,
- (dnnLayout_t)outimage_shape.GetCurLayout(),
- &res_lrn_bwd[dnnResourceDst]);
- } else {
- res_lrn_bwd[dnnResourceDst] = user_fwd_output;
+ AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
+ &res_lrn_bwd[dnnResourceSrc]);
+ CHECK_EQ(dnnConversionExecute_F32(convert_input, user_fwd_input,
+ res_lrn_bwd[dnnResourceSrc]),
+ E_SUCCESS);
+ dnnDelete_F32(convert_input);
+ } else {
+ res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
+ }
}
res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer;
@@ -628,8 +612,6 @@ class MklLRNGradOp : public OpKernel {
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
// copy.
void MklDefaultToEigen(OpKernelContext* context) {
- // CHECK(false);
-
Tensor in_grads;
Tensor in_image;
Tensor out_image;
@@ -709,7 +691,7 @@ class MklLRNGradOp : public OpKernel {
Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
depth * depth, shard);
}
-
+
// release mkl resources
void Mklcleanup() {
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();