author     Shanqing Cai <cais@google.com>          2017-12-23 20:12:06 -0500
committer  GitHub <noreply@github.com>             2017-12-23 20:12:06 -0500
commit     26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca (patch)
tree       502bd704023204a68368f2433f6ef111cc76adef
parent     0242356e0438cf43f6ac6c7cddf00cb79888da91 (diff)
parent     3ce813dec7cb65fbe3b9cfa37972ed7344dd94b0 (diff)
Merge pull request #15594 from Intel-tensorflow/pr-mkl-dnn-lrn
MKL: Adding MKL-DNN support for LRN op
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc    2
-rw-r--r--  tensorflow/core/kernels/mkl_lrn_op.cc     659
-rw-r--r--  tensorflow/core/ops/nn_ops.cc               8
3 files changed, 666 insertions(+), 3 deletions(-)
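The new MklLRNOp added below follows the usual MKL-DNN 0.x primitive workflow: describe the source memory, build an lrn_forward descriptor and primitive descriptor, size the destination and workspace from that primitive descriptor, and submit the primitive on an eager stream. Below is a minimal standalone sketch of that flow, stripped of the TensorFlow plumbing; it assumes the MKL-DNN 0.x C++ API (mkldnn.hpp) used by this patch, and the buffer names and NHWC sizes are made up for illustration.

// Illustrative sketch (not part of this patch): the MKL-DNN 0.x LRN forward
// flow that MklLRNOp wraps. Buffers and sizes here are hypothetical.
#include <vector>
#include "mkldnn.hpp"

void RunLrnForwardSketch(float* src_buf, float* dst_buf) {
  using namespace mkldnn;
  auto cpu_engine = engine(engine::cpu, 0);

  // 4-D sizes are passed in NCHW logical order even when the data is NHWC.
  memory::dims dims = {1, 96, 55, 55};
  auto src_md = memory::desc(dims, memory::data_type::f32, memory::format::nhwc);
  auto src_mem = memory({src_md, cpu_engine}, src_buf);

  // TensorFlow's depth_radius maps to an odd MKL-DNN window (local_size),
  // and alpha is scaled by that window size (see mkl_lrn_op.cc below).
  int depth_radius = 5;
  int kernel_size = 2 * depth_radius + 1;
  float alpha = 1.0f, beta = 0.5f, bias = 1.0f;

  auto lrn_desc = lrn_forward::desc(prop_kind::forward,
                                    algorithm::lrn_across_channels, src_md,
                                    kernel_size, alpha * kernel_size, beta, bias);
  auto lrn_pd = lrn_forward::primitive_desc(lrn_desc, cpu_engine);

  // Destination and workspace memory are sized by the primitive descriptor.
  auto dst_mem = memory(lrn_pd.dst_primitive_desc(), dst_buf);
  auto wksp_mem = memory(lrn_pd.workspace_primitive_desc());

  std::vector<primitive> net;
  net.push_back(lrn_forward(lrn_pd, src_mem, wksp_mem, dst_mem));
  stream(stream::kind::eager).submit(net).wait();
}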
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 3beca1e5d2..0ffdc42852 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2495,14 +2495,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsDataType, AlwaysRewrite});
- /*
rinfo_.push_back({csinfo_.lrn,
mkl_op_registry::GetMklOpName(csinfo_.lrn),
CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, AlwaysRewrite});
- */
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
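Each rinfo_ entry in this pass pairs an original TensorFlow op with its MKL counterpart, an attribute-copying callback, and a predicate that decides whether a node is rewritten; un-commenting the lrn and lrn_grad entries re-enables the LRN to _MklLRN rewrite now that MKL-DNN kernels exist for it. A rough sketch of the shape of such an entry is shown below; the struct and type names are hypothetical stand-ins, not the actual declarations in mkl_layout_pass.cc.

// Hypothetical sketch of a rewrite-rule record (names are illustrative).
#include <functional>
#include <string>

class Node;         // stands in for tensorflow::Node
class NodeBuilder;  // stands in for tensorflow::NodeBuilder

struct MklRewriteRuleSketch {
  std::string name;      // original op name, e.g. "LRN"
  std::string new_name;  // MKL op name, e.g. GetMklOpName("LRN") -> "_MklLRN"
  std::function<void(const Node*, NodeBuilder*)> copy_attrs;  // e.g. CopyAttrsLRN
  std::function<bool(const Node*)> rewrite_rule;              // e.g. AlwaysRewrite
};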
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 227765e46d..a8f28202f4 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc. This OpKernel uses the MKL library, creates
// MKL layouts and primitives, and uses MKL-DNN primitives to compute local
// response normalization.
-#undef INTEL_MKL
+
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
@@ -38,6 +38,15 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
#endif
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+using mkldnn::lrn_forward;
+using mkldnn::lrn_backward;
+using mkldnn::prop_kind;
+using mkldnn::algorithm::lrn_across_channels;
+using mkldnn::stream;
+#endif
+
namespace tensorflow {
namespace {
@@ -58,6 +67,8 @@ void GetBandMatrix(int depth, int depth_radius,
} // namespace
+#ifndef INTEL_MKL_DNN
+
template <typename T>
class MklLRNOp : public OpKernel {
public:
@@ -328,6 +339,7 @@ class MklLRNOp : public OpKernel {
float beta_;
};
+
template <typename T>
class MklLRNGradOp : public OpKernel {
public:
@@ -648,6 +660,7 @@ class MklLRNGradOp : public OpKernel {
const auto nodes = cols * rows;
auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
+
auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
@@ -717,6 +730,649 @@ class MklLRNGradOp : public OpKernel {
float beta_;
};
+#else
+
+template <typename T>
+class MklLRNOp : public OpKernel {
+ public:
+ ~MklLRNOp() {}
+
+ explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+ int64 depth_radius64;
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
+ depth_radius_ = static_cast<int>(depth_radius64);
+
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ workspace_enabled_ = false;
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ SanityCheckInputs(context);
+ if (!context->status().ok()) return;
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ const Tensor& src_tensor = MklGetInput(context, kIdxInput);
+ MklDnnShape src_dnn_shape;
+ GetMklShape(context, kIdxInput, &src_dnn_shape);
+
+ // MKL-DNN has a notion of kernel_size and not depth_radius.
+ int kernel_size = 2 * depth_radius_ + 1;
+ float new_alpha = alpha_ * kernel_size;
+
+ // If the input tensor is not an MKL tensor, or if the last
+ // dimension is not the channel dimension, just use Eigen.
+ // MKL only supports normalization over the channel dimension.
+ if (!src_dnn_shape.IsMklTensor()) {
+ MklDefaultToEigen(context, src_tensor);
+ return;
+ } else if (!src_dnn_shape.IsMklChannelDim(
+ src_dnn_shape.GetDimension() - 1) ) {
+ Tensor converted_tensor =
+ ConvertMklToTF<T>(context, src_tensor, src_dnn_shape);
+ MklDefaultToEigen(context, converted_tensor);
+ return;
+ }
+ // At this point, we can assume that the src is an MklTensor
+ // and we can enable the workspace
+ workspace_enabled_ = true;
+
+ MklDnnData<T> src_dnn_data(&cpu_engine);
+ MklDnnData<T> dst_dnn_data(&cpu_engine);
+ MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+
+ TensorShape tf_output_shape = src_tensor.shape();
+
+ memory::desc src_md = src_dnn_shape.GetCurLayout();
+ memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();
+
+ // Create memory for user input.
+ // Since Tensorflow always performs normalization over last dimension,
+ // and MKL-DNN performs normalization over Channel, we tell MKL-DNN
+ // that input is in NHWC layout with Channel being the last dimension.
+ src_dnn_data.SetUsrMem(src_md, &src_tensor);
+ src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+
+ // output_dnn_data and workspace both have the same shape as input
+ dst_dnn_data.SetUsrMem(src_md);
+ dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+
+ // Create LRN primitive descriptor.
+ // Tensorflow's normalization semantics is across channels.
+ // MKL-DNN also supports normalization within channel.
+ auto lrn_desc = lrn_forward::desc(prop_kind::forward,
+ lrn_across_channels,
+ src_dnn_data.GetUsrMemDesc(),
+ kernel_size,
+ new_alpha, beta_, bias_);
+ auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
+
+ // Allocate output_dnn_data tensor.
+ Tensor* output_tensor = nullptr;
+ memory::format input_format = src_dnn_shape.GetTfDataFormat();
+ AllocateOutputTensor(context, lrn_prim_desc, input_dims,
+ input_format, &output_tensor);
+ OP_REQUIRES_OK(context, context->status());
+ CHECK_NOTNULL(output_tensor);
+ dst_dnn_data.SetUsrMemDataHandle(output_tensor);
+
+ // Handle workspace required for MKL-DNN.
+ AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
+ OP_REQUIRES_OK(context, context->status());
+
+ PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data,
+ &dst_dnn_data, &workspace_dnn_data);
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
+ }
+ }
+
+ private:
+ void PrepareAndExecuteNet(
+ const lrn_forward::primitive_desc& lrn_fwd_desc,
+ MklDnnData<T>* src_dnn_data,
+ MklDnnData<T>* dst_dnn_data,
+ MklDnnData<uint8>* wksp_dnn_data = nullptr) {
+ std::vector<primitive> net;
+
+ // Check for input reorder
+ src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
+
+ // Create LRN forward primitive and add it to net
+ if (wksp_dnn_data != nullptr) {
+ net.push_back(lrn_forward(lrn_fwd_desc,
+ src_dnn_data->GetOpMem(),
+ wksp_dnn_data->GetOpMem(),
+ dst_dnn_data->GetOpMem()));
+ } else {
+ net.push_back(lrn_forward(lrn_fwd_desc,
+ src_dnn_data->GetOpMem(),
+ dst_dnn_data->GetOpMem()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+ }
+
+ void AllocateOutputTensor(OpKernelContext* context,
+ const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
+ const memory::dims output_dims_mkl_order,
+ const memory::format& output_tf_format,
+ Tensor** output_tensor) {
+ CHECK_NOTNULL(output_tensor);
+ memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
+
+ MklDnnShape output_mkl_shape;
+ // We only handle the case when the inputs and output are in Mkl format
+ // Any other case is handled by Eigen
+ output_mkl_shape.SetMklTensor(true);
+ output_mkl_shape.SetMklLayout(&dst_pd);
+ output_mkl_shape.SetElemType(MklDnnType<T>());
+ output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
+ output_dims_mkl_order,
+ output_tf_format);
+ TensorShape output_tf_shape;
+ // only allocate enough space for the elements we need.
+ size_t num_bytes = dst_pd.get_size();
+ CHECK_EQ(num_bytes % sizeof(T), 0);
+ output_tf_shape.AddDim(num_bytes / sizeof(T));
+ AllocateOutputSetMklShape(context, kIdxOutput,
+ output_tensor,
+ output_tf_shape, output_mkl_shape);
+ }
+
+ // Fallback implementation - Taken from lrn_op.cc
+ // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
+ // copy.
+ void MklDefaultToEigen(OpKernelContext* context,
+ const Tensor& input) {
+ const int batch = static_cast<int>(input.dim_size(0));
+ const int rows = static_cast<int>(input.dim_size(1));
+ const int cols = static_cast<int>(input.dim_size(2));
+ const int depth = static_cast<int>(input.dim_size(3));
+ const int nodes = cols * rows;
+
+ auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
+ // Multiplying the input with the band matrix has the effect of reducing
+ // the correct patch along the depth.
+ Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
+ GetBandMatrix<T>(depth, depth_radius_, &multiplier);
+
+ Tensor *output_dnn_data, *workspace;
+ MklDnnShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ mkl_output_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
+ input.shape(), mkl_output_mkl_shape);
+
+ mkl_workspace_mkl_shape.SetMklTensor(false);
+ mkl_workspace_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace,
+ input.shape(), mkl_workspace_mkl_shape);
+
+ auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
+ Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+ auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
+ if (beta_ == T(1)) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.inverse();
+ } else if (beta_ == T(0.5)) {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * tmp.rsqrt();
+ } else {
+ out_shaped.device(context->eigen_cpu_device()) =
+ in_shaped * (tmp.log() * -beta_).exp();
+ }
+ }
+
+ void AllocateWorkspaceTensor(OpKernelContext* context,
+ const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
+ MklDnnData<uint8>* dnn_data_wksp) {
+ CHECK_NOTNULL(dnn_data_wksp);
+ Tensor* workspace_tensor = nullptr;
+ memory::primitive_desc workspace_pd
+ = lrn_fwd_prim_desc.workspace_primitive_desc();
+ size_t workspace_bytes = workspace_pd.get_size();
+ MklDnnShape workspace_mkl_shape;
+ // the workspace tensor is a uint8 tensor that has
+ // exactly the number of bytes necessary
+ workspace_mkl_shape.SetMklTensor(false);
+ TensorShape workspace_tf_shape;
+ workspace_tf_shape.AddDim(workspace_bytes);
+ AllocateOutputSetMklShape(context, kIdxWorkspace,
+ &workspace_tensor,
+ workspace_tf_shape, workspace_mkl_shape);
+ CHECK_NOTNULL(workspace_tensor);
+ dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
+ }
+
+ void SanityCheckInputs(OpKernelContext* context) {
+ const Tensor& src_tensor = MklGetInput(context, kIdxInput);
+ MklDnnShape src_dnn_shape;
+ GetMklShape(context, kIdxInput, &src_dnn_shape);
+ if (src_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("input must be 4-dimensional"));
+ OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
+ } else {
+ OP_REQUIRES(context, src_tensor.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional"));
+ OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("argument to LRN too large"));
+ }
+ }
+ const int kIdxInput = 0,
+ kIdxOutput = 0,
+ kIdxWorkspace = 1;
+
+ typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+ bool workspace_enabled_;
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+
+template <typename T>
+class MklLRNGradOp : public OpKernel {
+ public:
+ explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+ int64 depth_radius64;
+ OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+ OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("depth_radius = ", depth_radius64,
+ " larger than int max"));
+ depth_radius_ = static_cast<int>(depth_radius64);
+ OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+ workspace_enabled_ = false;
+ context->GetAttr("workspace_enabled", &workspace_enabled_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ try {
+ SanityCheckInputs(context);
+ if (!context->status().ok()) return;
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> input_grad_dnn_data(&cpu_engine);
+ MklDnnData<T> orig_input_dnn_data(&cpu_engine);
+ MklDnnData<T> orig_output_dnn_data(&cpu_engine);
+ MklDnnData<T> output_dnn_data(&cpu_engine);
+
+ MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
+ orig_output_dnn_shape;
+ GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
+ GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
+ GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
+
+ // We only use MKLDNN if all of the necessary inputs are present
+ // in mkldnn format, and Channel is the last dimension
+ bool can_use_mkldnn = workspace_enabled_ &&
+ input_grad_dnn_shape.IsMklTensor() &&
+ orig_input_dnn_shape.IsMklTensor() &&
+ orig_output_dnn_shape.IsMklTensor() &&
+ input_grad_dnn_shape.IsMklChannelDim(
+ input_grad_dnn_shape.GetDimension() - 1) &&
+ orig_input_dnn_shape.IsMklChannelDim(
+ orig_input_dnn_shape.GetDimension() - 1) &&
+ orig_output_dnn_shape.IsMklChannelDim(
+ orig_output_dnn_shape.GetDimension() - 1);
+
+ if (!can_use_mkldnn) {
+ // Fallback to eigen
+ MklDefaultToEigen(context);
+ return;
+ }
+ // At this point, we have the all clear to use MklDnn constructs
+ // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
+ const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
+ const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+ const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+
+ // Get input sizes in MKL-DNN required NCHW format.
+ // LRN does not have a data_format attribute, but by default it uses
+ // NHWC format.
+ memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
+ memory::desc target_diff_dst_md = ConfigureInputGradient(
+ input_grad_tensor,
+ input_grad_dnn_shape,
+ &input_grad_dnn_data);
+
+ memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
+ memory::dims orig_input_dims =
+ orig_input_dnn_shape.GetSizesAsMklDnnDims();
+ orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
+ orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+
+ // output_dnn_data has the same shape as original input
+ output_dnn_data.SetUsrMem(orig_input_md);
+ output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+
+ // MKL-DNN has a notion of kernel_size and not depth_radius.
+ int kernel_size = 2 * depth_radius_ + 1;
+ float new_alpha = alpha_ * kernel_size;
+
+ // Create LRN backward primitive descriptor. It requires LRN forward
+ // primitive descriptor also.
+ auto lrn_fwd_desc = lrn_forward::desc(prop_kind::forward,
+ lrn_across_channels,
+ orig_input_md,
+ kernel_size,
+ new_alpha, beta_, bias_);
+ auto lrn_fwd_prim_desc = lrn_forward::primitive_desc(lrn_fwd_desc,
+ cpu_engine);
+ auto lrn_bwd_desc = lrn_backward::desc(lrn_across_channels,
+ original_output_md,
+ target_diff_dst_md,
+ kernel_size,
+ new_alpha, beta_, bias_);
+ auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(lrn_bwd_desc,
+ cpu_engine,
+ lrn_fwd_prim_desc);
+
+ Tensor* output_tensor = nullptr;
+ memory::format orig_input_format
+ = orig_input_dnn_shape.GetTfDataFormat();
+ AllocateOutputTensor(context, lrn_bwd_prim_desc,
+ orig_input_dims, orig_input_format, &output_tensor);
+ OP_REQUIRES_OK(context, context->status());
+ CHECK_NOTNULL(output_tensor);
+ output_dnn_data.SetUsrMemDataHandle(output_tensor);
+
+ // Create LRN primitive and add it to the net
+ // At this point, workspace is enabled, so we don't need
+ // to check. Pass input workspace to LRN backward primitive.
+ const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
+ MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+ ConfigureWorkspace(workspace_tensor,
+ lrn_fwd_prim_desc.workspace_primitive_desc(),
+ &workspace_dnn_data);
+
+ PrepareAndExecuteNet(lrn_bwd_prim_desc,
+ lrn_fwd_prim_desc,
+ &orig_input_dnn_data,
+ &input_grad_dnn_data,
+ &output_dnn_data,
+ memory::primitive_desc(target_diff_dst_md, cpu_engine),
+ &workspace_dnn_data);
+ } catch (mkldnn::error &e) {
+ string error_msg = "Status: " + std::to_string(e.status) +
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
+ }
+ }
+
+ void AllocateOutputTensor(OpKernelContext* context,
+ const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
+ const memory::dims output_dims_mkl_order,
+ const memory::format& output_tf_format,
+ Tensor** output_tensor) {
+ CHECK_NOTNULL(output_tensor);
+ memory::primitive_desc dst_pd
+ = lrn_bkwd_prim_desc.diff_src_primitive_desc();
+ MklDnnShape output_mkl_shape;
+
+ // We assume that all outputs at this point are MKL Tensors
+ output_mkl_shape.SetMklTensor(true);
+ output_mkl_shape.SetMklLayout(&dst_pd);
+ output_mkl_shape.SetElemType(MklDnnType<T>());
+ output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
+ output_dims_mkl_order,
+ output_tf_format);
+
+ TensorShape output_tf_shape;
+ size_t num_bytes = dst_pd.get_size();
+ CHECK_EQ(num_bytes % sizeof(T), 0);
+ output_tf_shape.AddDim(num_bytes / sizeof(T));
+ AllocateOutputSetMklShape(context, kIdxOutput,
+ output_tensor,
+ output_tf_shape, output_mkl_shape);
+ }
+
+ memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
+ const MklDnnShape& input_grad_dnn_shape,
+ MklDnnData<T> *input_grad_dnn_data) {
+ CHECK_NOTNULL(input_grad_dnn_data);
+ // This shouldn't be necessary at this point, but just in case
+ CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
+
+ memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
+ memory::dims orig_input_dims =
+ input_grad_dnn_shape.GetSizesAsMklDnnDims();
+ input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
+ input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+ return input_grad_md;
+ }
+
+ void PrepareAndExecuteNet(
+ const lrn_backward::primitive_desc& lrn_bkwd_desc,
+ const lrn_forward::primitive_desc& lrn_fwd_desc,
+ MklDnnData<T>* src_dnn_data,
+ MklDnnData<T>* input_gradient_diff_dst,
+ MklDnnData<T>* output_diff_src,
+ const memory::primitive_desc& target_diff_dst_pd,
+ const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
+ std::vector<primitive> net;
+
+ // Check for input reordering on the diff dst input
+ input_gradient_diff_dst->CheckReorderToOpMem(
+ lrn_bkwd_desc.diff_dst_primitive_desc(), &net);
+
+ // Check for input reordering on the original input
+ src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(),
+ &net);
+ // Create LRN backward primitive and add it to net
+ if (nullptr == workspace_dnn_data) {
+ net.push_back(lrn_backward(lrn_bkwd_desc,
+ src_dnn_data->GetOpMem(),
+ input_gradient_diff_dst->GetOpMem(),
+ output_diff_src->GetOpMem()));
+ } else {
+ net.push_back(lrn_backward(lrn_bkwd_desc,
+ src_dnn_data->GetOpMem(),
+ input_gradient_diff_dst->GetOpMem(),
+ workspace_dnn_data->GetOpMem(),
+ output_diff_src->GetOpMem()));
+ }
+ stream(stream::kind::eager).submit(net).wait();
+ }
+
+ void ConfigureWorkspace(const Tensor& workspace_tensor,
+ memory::primitive_desc workspace_pd,
+ MklDnnData<uint8> *workspace_dnn_data) {
+ CHECK_NOTNULL(workspace_dnn_data);
+
+ workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
+ }
+
+ // Fallback implementation - Taken from lrn_op.cc
+ // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
+ // copy.
+ void MklDefaultToEigen(OpKernelContext* context) {
+ Tensor input_gradient_tensor;
+ Tensor orig_input_tensor;
+ Tensor orig_output_tensor;
+
+ MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
+ orig_output_dnn_shape;
+ GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
+ GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
+ GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
+
+ if (input_grad_dnn_shape.IsMklTensor()) {
+ input_gradient_tensor =
+ ConvertMklToTF<T>(context,
+ MklGetInput(context, kIdxGradient),
+ input_grad_dnn_shape);
+ } else {
+ input_gradient_tensor = MklGetInput(context, kIdxGradient);
+ }
+
+ if (orig_input_dnn_shape.IsMklTensor()) {
+ orig_input_tensor =
+ ConvertMklToTF<T>(context,
+ MklGetInput(context, kIdxOrigInput),
+ orig_input_dnn_shape);
+ } else {
+ orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+ }
+
+ if (orig_output_dnn_shape.IsMklTensor()) {
+ orig_output_tensor =
+ ConvertMklToTF<T>(context,
+ MklGetInput(context, kIdxOrigOutput),
+ orig_output_dnn_shape);
+ } else {
+ orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+ }
+
+ const int64 batch = static_cast<int64>(input_gradient_tensor.dim_size(0));
+ const int64 rows = static_cast<int64>(input_gradient_tensor.dim_size(1));
+ const int64 cols = static_cast<int64>(input_gradient_tensor.dim_size(2));
+ const int64 depth = static_cast<int64>(input_gradient_tensor.dim_size(3));
+ const auto nodes = cols * rows;
+
+ auto grads_shaped =
+ input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});
+
+ auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
+ auto activations =
+ orig_output_tensor.shaped<T, 2>({nodes * batch, depth});
+
+ Tensor* output_dnn_data;
+ MklDnnShape mkl_output_mkl_shape;
+ mkl_output_mkl_shape.SetMklTensor(false);
+ mkl_output_mkl_shape.SetDimensions(4);
+ AllocateOutputSetMklShape(context, kIdxOutput,
+ &output_dnn_data,
+ input_gradient_tensor.shape(),
+ mkl_output_mkl_shape);
+
+ auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
+ out_shaped.setZero();
+ auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
+ depth](int64 begin, int64 end) {
+ for (int64 i = begin; i < end; ++i) {
+ for (int64 j = 0; j < depth; ++j) {
+ int64 depth_begin = std::max<int64>(0, j - depth_radius_);
+ int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
+
+ T norm(0);
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ norm += in_shaped(i, k) * in_shaped(i, k);
+ }
+ norm = alpha_ * norm + bias_;
+ DCHECK_GT(norm, T(1e-6));
+ for (int64 k = depth_begin; k < depth_end; ++k) {
+ T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
+ activations(i, j) / norm;
+ if (k == j) {
+ dyi += Eigen::numext::pow(norm, -beta_);
+ }
+ dyi *= grads_shaped(i, j);
+ const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
+ dyi;
+ }
+ }
+ }
+ };
+ auto worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
+ Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
+ depth * depth, shard);
+ }
+
+ void SanityCheckInputs(OpKernelContext* context) {
+ const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
+ const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+ const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+ const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
+ MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
+ workspace_dnn_shape;
+ GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
+ GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
+ GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
+ GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
+ if (in_grads_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("Input gradient must be "
+ "4-dimensional"));
+ } else {
+ OP_REQUIRES(context, input_gradient_tensor.dims() == 4,
+ errors::InvalidArgument("input gradient must be 4-dimensional"));
+ }
+
+ if (in_image_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("input images must be "
+ "4-dimensional"));
+ } else {
+ OP_REQUIRES(context, orig_input_tensor.dims() == 4,
+ errors::InvalidArgument("input images must be "
+ "4-dimensional"));
+ }
+
+ if (out_image_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
+ errors::InvalidArgument("Output image must be "
+ "4-dimensional"));
+ } else {
+ OP_REQUIRES(context, orig_output_tensor.dims() == 4,
+ errors::InvalidArgument("Output image must be 4-dimensional"));
+ }
+
+ if (workspace_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false,
+ errors::InvalidArgument("Workspace should not be MKL Tensor."));
+ } else {
+ OP_REQUIRES(context, workspace_tensor.dims() == 1,
+ errors::InvalidArgument("Workspace must be 1-dimensional"));
+ }
+ }
+
+// Input("input_grads: T")
+// Input("input_image: T")
+// Input("output_image: T")
+// Input("workspace: uint8")
+ const int kIdxGradient = 0,
+ kIdxOrigInput = 1,
+ kIdxOrigOutput = 2,
+ kIdxWorkspace = 3,
+ kIdxOutput = 0;
+
+ typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+ bool workspace_enabled_;
+ int depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+#endif // INTEL_MKL_DNN
+
#define REGISTER_MKL_LRN_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklLRN") \
.Device(DEVICE_CPU) \
@@ -729,6 +1385,7 @@ class MklLRNGradOp : public OpKernel {
.Label(mkl_op_registry::kMklOpLabel), \
MklLRNGradOp<T>);
+
TF_CALL_float(REGISTER_MKL_LRN_CPU);
} // namespace tensorflow
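A central detail in both new kernels is the translation between TensorFlow's depth_radius/alpha convention and MKL-DNN's local_size/alpha convention: TensorFlow sums squares over a window of 2 * depth_radius + 1 channels, while MKL-DNN (as assumed here from its documented LRN formula) divides its alpha by local_size, which is why the kernels pass kernel_size = 2 * depth_radius_ + 1 and new_alpha = alpha_ * kernel_size. The scalar sketch below illustrates the two formulations for one output element; it is a reference sketch, not code from this patch.

// Scalar reference for LRN across channels at channel j of one pixel.
#include <algorithm>
#include <cmath>
#include <vector>

// TensorFlow convention: out = in / (bias + alpha * sum)^beta over radius r.
float TfLrnSketch(const std::vector<float>& in, int j, int depth_radius,
                  float bias, float alpha, float beta) {
  int n = static_cast<int>(in.size());
  int lo = std::max(0, j - depth_radius);
  int hi = std::min(n - 1, j + depth_radius);
  float sum = 0.f;
  for (int k = lo; k <= hi; ++k) sum += in[k] * in[k];
  return in[j] / std::pow(bias + alpha * sum, beta);
}

// Assumed MKL-DNN convention: out = in / (k + (alpha / local_size) * sum)^beta,
// with local_size = 2 * depth_radius + 1. Passing alpha * local_size as the
// MKL-DNN alpha therefore reproduces the TensorFlow result, matching the
// new_alpha = alpha_ * kernel_size lines above.
float MklDnnLrnSketch(const std::vector<float>& in, int j, int local_size,
                      float k, float alpha, float beta) {
  int n = static_cast<int>(in.size());
  int r = (local_size - 1) / 2;
  int lo = std::max(0, j - r);
  int hi = std::min(n - 1, j + r);
  float sum = 0.f;
  for (int c = lo; c <= hi; ++c) sum += in[c] * in[c];
  return in[j] / std::pow(k + (alpha / local_size) * sum, beta);
}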
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d2dfe23888..8ad2c06741 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -3361,7 +3361,11 @@ REGISTER_OP("_MklLRN")
.Input("input: T")
.Input("mkl_input: uint8")
.Output("output: T")
+#ifndef INTEL_MKL_DNN
.Output("workspace: T")
+#else
+ .Output("workspace: uint8")
+#endif
.Output("mkl_output: uint8")
.Output("mkl_workspace: uint8")
.Attr("depth_radius: int = 5")
@@ -3385,7 +3389,11 @@ REGISTER_OP("_MklLRNGrad")
.Input("input_grads: T")
.Input("input_image: T")
.Input("output_image: T")
+#ifndef INTEL_MKL_DNN
.Input("workspace: T")
+#else
+ .Input("workspace: uint8")
+#endif
.Input("mkl_input_grads: uint8")
.Input("mkl_input_image: uint8")
.Input("mkl_output_image: uint8")