author     Shanqing Cai <cais@google.com>    2017-12-23 20:12:06 -0500
committer  GitHub <noreply@github.com>       2017-12-23 20:12:06 -0500
commit     26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca
tree       502bd704023204a68368f2433f6ef111cc76adef
parent     0242356e0438cf43f6ac6c7cddf00cb79888da91
parent     3ce813dec7cb65fbe3b9cfa37972ed7344dd94b0
Merge pull request #15594 from Intel-tensorflow/pr-mkl-dnn-lrn
MKL: Adding MKL-DNN support for LRN op
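For reference, TensorFlow's LRN normalizes each element over a window of the
last (channel) dimension:

  y_i = \frac{x_i}{\left(\mathrm{bias} + \alpha \sum_{j=\max(0,\,i-r)}^{\min(N-1,\,i+r)} x_j^2\right)^{\beta}},
  \qquad r = \mathrm{depth\_radius}

MKL-DNN expresses the same operation in terms of a window size rather than a
radius, and (per its LRN definition) scales alpha by the reciprocal of the
window size internally, which is why the new kernels below convert the op
attributes before building the primitive descriptor:

  int kernel_size = 2 * depth_radius_ + 1;
  float new_alpha = alpha_ * kernel_size;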
-rw-r--r--  tensorflow/core/graph/mkl_layout_pass.cc  |   2
-rw-r--r--  tensorflow/core/kernels/mkl_lrn_op.cc     | 659
-rw-r--r--  tensorflow/core/ops/nn_ops.cc             |   8
3 files changed, 666 insertions, 3 deletions
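The patch keeps the existing Eigen fallback (MklDefaultToEigen) for inputs
that are not MKL tensors or not channel-last. As a minimal sketch of that
fallback's math, assuming Eigen's unsupported Tensor module and with
band_matrix as a hypothetical stand-in for the kernel's GetBandMatrix helper
(alpha/beta/bias/depth_radius mirror the op attributes):

  // Standalone sketch of the Eigen fallback path used by MklLRNOp when the
  // input cannot go through MKL-DNN. band_matrix is a hypothetical helper,
  // not the one in mkl_lrn_op.cc.
  #include <algorithm>
  #include <unsupported/Eigen/CXX11/Tensor>

  // Builds a depth x depth band matrix with ones within depth_radius of the
  // diagonal; contracting the squared input with it sums each depth window.
  static Eigen::Tensor<float, 2, Eigen::RowMajor> band_matrix(
      int depth, int depth_radius) {
    Eigen::Tensor<float, 2, Eigen::RowMajor> m(depth, depth);
    m.setZero();
    for (int i = 0; i < depth; ++i)
      for (int j = std::max(0, i - depth_radius);
           j < std::min(depth, i + depth_radius + 1); ++j)
        m(i, j) = 1.0f;
    return m;
  }

  // out = in * (bias + alpha * windowed_sum(in^2))^(-beta), with the same
  // beta special cases (1 -> inverse, 0.5 -> rsqrt) as the kernel below.
  void eigen_lrn(const Eigen::Tensor<float, 2, Eigen::RowMajor>& in,
                 Eigen::Tensor<float, 2, Eigen::RowMajor>& out,
                 int depth_radius, float bias, float alpha, float beta) {
    const int depth = static_cast<int>(in.dimension(1));
    auto multiplier = band_matrix(depth, depth_radius);
    Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
    Eigen::Tensor<float, 2, Eigen::RowMajor> tmp =
        in.square().contract(multiplier, dims) * alpha + bias;
    if (beta == 1.0f) {
      out = in * tmp.inverse();
    } else if (beta == 0.5f) {
      out = in * tmp.rsqrt();
    } else {
      out = in * (tmp.log() * -beta).exp();
    }
  }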
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 3beca1e5d2..0ffdc42852 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -2495,14 +2495,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.identity,
                       mkl_op_registry::GetMklOpName(csinfo_.identity),
                       CopyAttrsDataType, AlwaysRewrite});
-    /*
     rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
                       CopyAttrsLRN, AlwaysRewrite});
     rinfo_.push_back({csinfo_.lrn_grad,
                       mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
                       CopyAttrsLRN, AlwaysRewrite});
-    */
     rinfo_.push_back({csinfo_.max_pool,
                       mkl_op_registry::GetMklOpName(csinfo_.max_pool),
                       CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 227765e46d..a8f28202f4 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
 // See docs in ../ops/nn_ops.cc. This opkernel uses MKL library, create MKL
 // layout and primitives, use MKL dnn primitives to compute local
 // response normalization
-#undef INTEL_MKL
+
 #ifdef INTEL_MKL
 
 #define EIGEN_USE_THREADS
@@ -38,6 +38,15 @@ limitations under the License.
 #include "tensorflow/core/util/work_sharder.h"
 #endif
 
+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+using mkldnn::lrn_forward;
+using mkldnn::lrn_backward;
+using mkldnn::prop_kind;
+using mkldnn::algorithm::lrn_across_channels;
+using mkldnn::stream;
+#endif
+
 namespace tensorflow {
 
 namespace {
@@ -58,6 +67,8 @@ void GetBandMatrix(int depth, int depth_radius,
 
 }  // namespace
 
+#ifndef INTEL_MKL_DNN
+
 template <typename T>
 class MklLRNOp : public OpKernel {
  public:
@@ -328,6 +339,7 @@ class MklLRNOp : public OpKernel {
   float beta_;
 };
 
+
 template <typename T>
 class MklLRNGradOp : public OpKernel {
  public:
@@ -648,6 +660,7 @@ class MklLRNGradOp : public OpKernel {
     const auto nodes = cols * rows;
     auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
+
     auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
     auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
@@ -717,6 +730,649 @@ class MklLRNGradOp : public OpKernel {
   float beta_;
 };
 
+#else
+
+template <typename T>
+class MklLRNOp : public OpKernel {
+ public:
+  ~MklLRNOp() {}
+
+  explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+    int64 depth_radius64;
+    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+    OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                        " larger than int max"));
+    depth_radius_ = static_cast<size_t>(depth_radius64);
+
+    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+    workspace_enabled_ = false;
+    context->GetAttr("workspace_enabled", &workspace_enabled_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    try {
+      SanityCheckInputs(context);
+      if (!context->status().ok()) return;
+
+      auto cpu_engine = engine(engine::cpu, 0);
+      const Tensor& src_tensor = MklGetInput(context, kIdxInput);
+      MklDnnShape src_dnn_shape;
+      GetMklShape(context, kIdxInput, &src_dnn_shape);
+
+      // MKL-DNN has a notion of kernel_size and not depth_radius.
+      int kernel_size = 2 * depth_radius_ + 1;
+      float new_alpha = alpha_ * kernel_size;
+
+      // if the input tensor is not an MKL Tensor, or if the last
+      // dimension is not channel, then just use Eigen.
+      // MKL only support normalization over the channel dimension.
+      if (!src_dnn_shape.IsMklTensor()) {
+        MklDefaultToEigen(context, src_tensor);
+        return;
+      } else if (!src_dnn_shape.IsMklChannelDim(
+                     src_dnn_shape.GetDimension() - 1)) {
+        Tensor converted_tensor =
+            ConvertMklToTF<T>(context, src_tensor, src_dnn_shape);
+        MklDefaultToEigen(context, converted_tensor);
+        return;
+      }
+      // At this point, we can assume that the src is an MklTensor
+      // and we can enable the workspace
+      workspace_enabled_ = true;
+
+      MklDnnData<T> src_dnn_data(&cpu_engine);
+      MklDnnData<T> dst_dnn_data(&cpu_engine);
+      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+
+      TensorShape tf_output_shape = src_tensor.shape();
+
+      memory::desc src_md = src_dnn_shape.GetCurLayout();
+      memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();
+
+      // Create memory for user input.
+      // Since Tensorflow always performs normalization over last dimension,
+      // and MKL-DNN performs normalization over Channel, we tell MKL-DNN
+      // that input is in NHWC layout with Channel being the last dimension.
+      src_dnn_data.SetUsrMem(src_md, &src_tensor);
+      src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+
+      // output_dnn_data and workspace both have the same shape as input
+      dst_dnn_data.SetUsrMem(src_md);
+      dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+
+      // Create LRN primitive descriptor.
+      // Tensorflow's normalization semantics is across channels.
+      // MKL-DNN also supports normalization within channel.
+      auto lrn_desc = lrn_forward::desc(prop_kind::forward,
+                                        lrn_across_channels,
+                                        src_dnn_data.GetUsrMemDesc(),
+                                        kernel_size,
+                                        new_alpha, beta_, bias_);
+      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
+
+      // Allocate output_dnn_data tensor.
+      Tensor* output_tensor = nullptr;
+      memory::format input_format = src_dnn_shape.GetTfDataFormat();
+      AllocateOutputTensor(context, lrn_prim_desc, input_dims,
+                           input_format, &output_tensor);
+      OP_REQUIRES_OK(context, context->status());
+      CHECK_NOTNULL(output_tensor);
+      dst_dnn_data.SetUsrMemDataHandle(output_tensor);
+
+      // Handle workspace required for MKL-DNN.
+      AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
+      OP_REQUIRES_OK(context, context->status());
+
+      PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data,
+                           &dst_dnn_data, &workspace_dnn_data);
+    } catch (mkldnn::error& e) {
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) +
+                         ", in file " + string(__FILE__) + ":" +
+                         std::to_string(__LINE__);
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Operation received an exception:",
+                                     error_msg));
+    }
+  }
+
+ private:
+  void PrepareAndExecuteNet(
+      const lrn_forward::primitive_desc& lrn_fwd_desc,
+      MklDnnData<T>* src_dnn_data,
+      MklDnnData<T>* dst_dnn_data,
+      MklDnnData<uint8>* wksp_dnn_data = nullptr) {
+    std::vector<primitive> net;
+
+    // Check for input reorder
+    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
+
+    // Create pooling primitive and add it to net
+    if (wksp_dnn_data != nullptr) {
+      net.push_back(lrn_forward(lrn_fwd_desc,
+                                src_dnn_data->GetOpMem(),
+                                wksp_dnn_data->GetOpMem(),
+                                dst_dnn_data->GetOpMem()));
+    } else {
+      net.push_back(lrn_forward(lrn_fwd_desc,
+                                src_dnn_data->GetOpMem(),
+                                dst_dnn_data->GetOpMem()));
+    }
+    stream(stream::kind::eager).submit(net).wait();
+  }
+
+  void AllocateOutputTensor(
+      OpKernelContext* context,
+      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
+      const memory::dims output_dims_mkl_order,
+      const memory::format& output_tf_format,
+      Tensor** output_tensor) {
+    CHECK_NOTNULL(output_tensor);
+    memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
+
+    MklDnnShape output_mkl_shape;
+    // We only handle the case when the inputs and output are in Mkl format
+    // Any other case is handled by Eigen
+    output_mkl_shape.SetMklTensor(true);
+    output_mkl_shape.SetMklLayout(&dst_pd);
+    output_mkl_shape.SetElemType(MklDnnType<T>());
+    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
+                                 output_dims_mkl_order,
+                                 output_tf_format);
+    TensorShape output_tf_shape;
+    // only allocate enough space for the elements we need.
+    size_t num_bytes = dst_pd.get_size();
+    CHECK_EQ(num_bytes % sizeof(T), 0);
+    output_tf_shape.AddDim(num_bytes / sizeof(T));
+    AllocateOutputSetMklShape(context, kIdxOutput,
+                              output_tensor,
+                              output_tf_shape, output_mkl_shape);
+  }
+
+  // Fallback implementation - Taken from lrn_op.cc
+  // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
+  // copy.
+  void MklDefaultToEigen(OpKernelContext* context,
+                         const Tensor& input) {
+    const int batch = static_cast<int>(input.dim_size(0));
+    const int rows = static_cast<int>(input.dim_size(1));
+    const int cols = static_cast<int>(input.dim_size(2));
+    const int depth = static_cast<int>(input.dim_size(3));
+    const int nodes = cols * rows;
+
+    auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
+    // Multiplying the input with the band matrix has the effect of reducing
+    // the
+    // correct patch along the depth.
+    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
+    GetBandMatrix<T>(depth, depth_radius_, &multiplier);
+
+    Tensor *output_dnn_data, *workspace;
+    MklDnnShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
+    mkl_output_mkl_shape.SetMklTensor(false);
+    mkl_output_mkl_shape.SetDimensions(4);
+    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
+                              input.shape(), mkl_output_mkl_shape);
+
+    mkl_workspace_mkl_shape.SetMklTensor(false);
+    mkl_workspace_mkl_shape.SetDimensions(4);
+    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace,
+                              input.shape(), mkl_workspace_mkl_shape);
+
+    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
+    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
+    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
+    if (beta_ == T(1)) {
+      out_shaped.device(context->eigen_cpu_device()) =
+          in_shaped * tmp.inverse();
+    } else if (beta_ == T(0.5)) {
+      out_shaped.device(context->eigen_cpu_device()) =
+          in_shaped * tmp.rsqrt();
+    } else {
+      out_shaped.device(context->eigen_cpu_device()) =
+          in_shaped * (tmp.log() * -beta_).exp();
+    }
+  }
+
+  void AllocateWorkspaceTensor(
+      OpKernelContext* context,
+      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
+      MklDnnData<uint8>* dnn_data_wksp) {
+    CHECK_NOTNULL(dnn_data_wksp);
+    Tensor* workspace_tensor = nullptr;
+    memory::primitive_desc workspace_pd
+        = lrn_fwd_prim_desc.workspace_primitive_desc();
+    size_t workspace_bytes = workspace_pd.get_size();
+    MklDnnShape workspace_mkl_shape;
+    // the workspace tensor is a uint8 tensor that has
+    // exactly the number of bytes necessary
+    workspace_mkl_shape.SetMklTensor(false);
+    TensorShape workspace_tf_shape;
+    workspace_tf_shape.AddDim(workspace_bytes);
+    AllocateOutputSetMklShape(context, kIdxWorkspace,
+                              &workspace_tensor,
+                              workspace_tf_shape, workspace_mkl_shape);
+    CHECK_NOTNULL(workspace_tensor);
+    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
+  }
+
+  void SanityCheckInputs(OpKernelContext* context) {
+    const Tensor& src_tensor = MklGetInput(context, kIdxInput);
+    MklDnnShape src_dnn_shape;
+    GetMklShape(context, kIdxInput, &src_dnn_shape);
+    if (src_dnn_shape.IsMklTensor()) {
+      OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
+                  errors::InvalidArgument("input must be 4-dimensional"));
+      OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+                                           std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("argument to LRN too large"));
+    } else {
+      OP_REQUIRES(context, src_tensor.dims() == 4,
+                  errors::InvalidArgument("input must be 4-dimensional"));
+      OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+                                           std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("argument to LRN too large"));
+    }
+  }
+  const int kIdxInput = 0,
+            kIdxOutput = 0,
+            kIdxWorkspace = 1;
+
+  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+  bool workspace_enabled_;
+  int depth_radius_;
+  float bias_;
+  float alpha_;
+  float beta_;
+};
+
+
+template <typename T>
+class MklLRNGradOp : public OpKernel {
+ public:
+  explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+    int64 depth_radius64;
+    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
+    OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                        " larger than int max"));
+    depth_radius_ = static_cast<int>(depth_radius64);
+    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
+    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
+    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
+    workspace_enabled_ = false;
+    context->GetAttr("workspace_enabled", &workspace_enabled_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    try {
+      SanityCheckInputs(context);
+      if (!context->status().ok()) return;
+
+      auto cpu_engine = engine(engine::cpu, 0);
+      MklDnnData<T> input_grad_dnn_data(&cpu_engine);
+      MklDnnData<T> orig_input_dnn_data(&cpu_engine);
+      MklDnnData<T> orig_output_dnn_data(&cpu_engine);
+      MklDnnData<T> output_dnn_data(&cpu_engine);
+
+      MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
+          orig_output_dnn_shape;
+      GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
+      GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
+      GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
+
+      // We only use MKLDNN if all of the necessary inputs are present
+      // in mkldnn format, and Channel is the last dimension
+      bool can_use_mkldnn = workspace_enabled_ &&
+                            input_grad_dnn_shape.IsMklTensor() &&
+                            orig_input_dnn_shape.IsMklTensor() &&
+                            orig_output_dnn_shape.IsMklTensor() &&
+                            input_grad_dnn_shape.IsMklChannelDim(
+                                input_grad_dnn_shape.GetDimension() - 1) &&
+                            orig_input_dnn_shape.IsMklChannelDim(
+                                orig_input_dnn_shape.GetDimension() - 1) &&
+                            orig_output_dnn_shape.IsMklChannelDim(
+                                orig_output_dnn_shape.GetDimension() - 1);
+
+      if (!can_use_mkldnn) {
+        // Fallback to eigen
+        MklDefaultToEigen(context);
+        return;
+      }
+      // At this point, we have the all clear to use MklDnn constructs
+      // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
+      const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
+      const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+      const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+
+      // Get input sizes in MKL-DNN required NCHW format.
+      // LRN does not have data_format attribute. But by default it has
+      // NHWC format.
+      memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
+      memory::desc target_diff_dst_md = ConfigureInputGradient(
+          input_grad_tensor,
+          input_grad_dnn_shape,
+          &input_grad_dnn_data);
+
+      memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
+      memory::dims orig_input_dims =
+          orig_input_dnn_shape.GetSizesAsMklDnnDims();
+      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
+      orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+
+      // output_dnn_data has the same shape as original input
+      output_dnn_data.SetUsrMem(orig_input_md);
+      output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+
+      // MKL-DNN has a notion of kernel_size and not depth_radius.
+      int kernel_size = 2 * depth_radius_ + 1;
+      float new_alpha = alpha_ * kernel_size;
+
+      // Create LRN backward primitive descriptor. It requires LRN forward
+      // primitive descriptor also.
+      auto lrn_fwd_desc = lrn_forward::desc(prop_kind::forward,
+                                            lrn_across_channels,
+                                            orig_input_md,
+                                            kernel_size,
+                                            new_alpha, beta_, bias_);
+      auto lrn_fwd_prim_desc = lrn_forward::primitive_desc(lrn_fwd_desc,
+                                                           cpu_engine);
+      auto lrn_bwd_desc = lrn_backward::desc(lrn_across_channels,
+                                             original_output_md,
+                                             target_diff_dst_md,
+                                             kernel_size,
+                                             new_alpha, beta_, bias_);
+      auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(lrn_bwd_desc,
+                                                            cpu_engine,
+                                                            lrn_fwd_prim_desc);
+
+      Tensor* output_tensor = nullptr;
+      memory::format orig_input_format
+          = orig_input_dnn_shape.GetTfDataFormat();
+      AllocateOutputTensor(context, lrn_bwd_prim_desc,
+                           orig_input_dims, orig_input_format, &output_tensor);
+      OP_REQUIRES_OK(context, context->status());
+      CHECK_NOTNULL(output_tensor);
+      output_dnn_data.SetUsrMemDataHandle(output_tensor);
+
+      // Create LRN primitive and add it to the net
+      // At this point, workspace is enabled, so we don't need
+      // to check. Pass input workspace to LRN backward primitive.
+      const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
+      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+      ConfigureWorkspace(workspace_tensor,
+                         lrn_fwd_prim_desc.workspace_primitive_desc(),
+                         &workspace_dnn_data);
+
+      PrepareAndExecuteNet(lrn_bwd_prim_desc,
+                           lrn_fwd_prim_desc,
+                           &orig_input_dnn_data,
+                           &input_grad_dnn_data,
+                           &output_dnn_data,
+                           memory::primitive_desc(target_diff_dst_md,
+                                                  cpu_engine),
+                           &workspace_dnn_data);
+    } catch (mkldnn::error& e) {
+      string error_msg = "Status: " + std::to_string(e.status) +
+                         ", message: " + string(e.message) +
+                         ", in file " + string(__FILE__) + ":" +
+                         std::to_string(__LINE__);
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Operation received an exception:",
+                                     error_msg));
+    }
+  }
+
+  void AllocateOutputTensor(
+      OpKernelContext* context,
+      const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
+      const memory::dims output_dims_mkl_order,
+      const memory::format& output_tf_format,
+      Tensor** output_tensor) {
+    CHECK_NOTNULL(output_tensor);
+    memory::primitive_desc dst_pd
+        = lrn_bkwd_prim_desc.diff_src_primitive_desc();
+    MklDnnShape output_mkl_shape;
+
+    // We assume that all outputs at this point are MKL Tensors
+    output_mkl_shape.SetMklTensor(true);
+    output_mkl_shape.SetMklLayout(&dst_pd);
+    output_mkl_shape.SetElemType(MklDnnType<T>());
+    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
+                                 output_dims_mkl_order,
+                                 output_tf_format);
+
+    TensorShape output_tf_shape;
+    size_t num_bytes = dst_pd.get_size();
+    CHECK_EQ(num_bytes % sizeof(T), 0);
+    output_tf_shape.AddDim(num_bytes / sizeof(T));
+    AllocateOutputSetMklShape(context, kIdxOutput,
+                              output_tensor,
+                              output_tf_shape, output_mkl_shape);
+  }
+
+  memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
+                                      const MklDnnShape& input_grad_dnn_shape,
+                                      MklDnnData<T>* input_grad_dnn_data) {
+    CHECK_NOTNULL(input_grad_dnn_data);
+    // This shouldn't be necessary at this point, but just in case
+    CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
+
+    memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
+    memory::dims orig_input_dims =
+        input_grad_dnn_shape.GetSizesAsMklDnnDims();
+    input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
+    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+    return input_grad_md;
+  }
+
+  void PrepareAndExecuteNet(
+      const lrn_backward::primitive_desc& lrn_bkwd_desc,
+      const lrn_forward::primitive_desc& lrn_fwd_desc,
+      MklDnnData<T>* src_dnn_data,
+      MklDnnData<T>* input_gradient_diff_dst,
+      MklDnnData<T>* output_diff_src,
+      const memory::primitive_desc& target_diff_dst_pd,
+      const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
+    std::vector<primitive> net;
+
+    // Check for input reordering on the diff dst input
+    input_gradient_diff_dst->CheckReorderToOpMem(
+        lrn_bkwd_desc.diff_dst_primitive_desc(), &net);
+
+    // Check for input reordering on the original input
+    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(),
+                                      &net);
+    // Create pooling primitive and add it to net
+    if (nullptr == workspace_dnn_data) {
+      net.push_back(lrn_backward(lrn_bkwd_desc,
+                                 src_dnn_data->GetOpMem(),
+                                 input_gradient_diff_dst->GetOpMem(),
+                                 output_diff_src->GetOpMem()));
+    } else {
+      net.push_back(lrn_backward(lrn_bkwd_desc,
+                                 src_dnn_data->GetOpMem(),
+                                 input_gradient_diff_dst->GetOpMem(),
+                                 workspace_dnn_data->GetOpMem(),
+                                 output_diff_src->GetOpMem()));
+    }
+    stream(stream::kind::eager).submit(net).wait();
+  }
+
+  void ConfigureWorkspace(const Tensor& workspace_tensor,
+                          memory::primitive_desc workspace_pd,
+                          MklDnnData<uint8>* workspace_dnn_data) {
+    CHECK_NOTNULL(workspace_dnn_data);
+
+    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
+  }
+
+  // Fallback implementation - Taken from lrn_op.cc
+  // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
+  // copy.
+  void MklDefaultToEigen(OpKernelContext* context) {
+    Tensor input_gradient_tensor;
+    Tensor orig_input_tensor;
+    Tensor orig_output_tensor;
+
+    MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
+        orig_output_dnn_shape;
+    GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
+    GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
+    GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
+
+    if (input_grad_dnn_shape.IsMklTensor()) {
+      input_gradient_tensor =
+          ConvertMklToTF<T>(context,
+                            MklGetInput(context, kIdxGradient),
+                            input_grad_dnn_shape);
+    } else {
+      input_gradient_tensor = MklGetInput(context, kIdxGradient);
+    }
+
+    if (orig_input_dnn_shape.IsMklTensor()) {
+      orig_input_tensor =
+          ConvertMklToTF<T>(context,
+                            MklGetInput(context, kIdxOrigInput),
+                            orig_input_dnn_shape);
+    } else {
+      orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+    }
+
+    if (orig_output_dnn_shape.IsMklTensor()) {
+      orig_output_tensor =
+          ConvertMklToTF<T>(context,
+                            MklGetInput(context, kIdxOrigOutput),
+                            orig_output_dnn_shape);
+    } else {
+      orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+    }
+
+    const int64 batch = static_cast<int64>(input_gradient_tensor.dim_size(0));
+    const int64 rows = static_cast<int64>(input_gradient_tensor.dim_size(1));
+    const int64 cols = static_cast<int64>(input_gradient_tensor.dim_size(2));
+    const int64 depth = static_cast<int64>(input_gradient_tensor.dim_size(3));
+    const auto nodes = cols * rows;
+
+    auto grads_shaped =
+        input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});
+
+    auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
+    auto activations =
+        orig_output_tensor.shaped<T, 2>({nodes * batch, depth});
+
+    Tensor* output_dnn_data;
+    MklShape mkl_output_mkl_shape;
+    mkl_output_mkl_shape.SetMklTensor(false);
+    mkl_output_mkl_shape.SetDimensions(4);
+    AllocateOutputSetMklShape(context, kIdxOutput,
+                              &output_dnn_data,
+                              input_gradient_tensor.shape(),
+                              mkl_output_mkl_shape);
+
+    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
+    out_shaped.setZero();
+    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
+                  depth](int64 begin, int64 end) {
+      for (int64 i = begin; i < end; ++i) {
+        for (int64 j = 0; j < depth; ++j) {
+          int64 depth_begin = std::max<int64>(0, j - depth_radius_);
+          int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
+
+          T norm(0);
+          for (int64 k = depth_begin; k < depth_end; ++k) {
+            norm += in_shaped(i, k) * in_shaped(i, k);
+          }
+          norm = alpha_ * norm + bias_;
+          DCHECK_GT(norm, T(1e-6));
+          for (int64 k = depth_begin; k < depth_end; ++k) {
+            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
+                    activations(i, j) / norm;
+            if (k == j) {
+              dyi += Eigen::numext::pow(norm, -beta_);
+            }
+            dyi *= grads_shaped(i, j);
+            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
+                dyi;
+          }
+        }
+      }
+    };
+    auto worker_threads =
+        *(context->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
+          depth * depth, shard);
+  }
+
+  void SanityCheckInputs(OpKernelContext* context) {
+    const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
+    const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
+    const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
+    const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
+    MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
+        workspace_dnn_shape;
+    GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
+    GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
+    GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
+    GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
+    if (in_grads_dnn_shape.IsMklTensor()) {
+      OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
+                  errors::InvalidArgument("Input gradient must be "
+                                          "4-dimensional"));
+    } else {
+      OP_REQUIRES(context, input_gradient_tensor.dims() == 4,
+                  errors::InvalidArgument("input gradient must be "
+                                          "4-dimensional"));
+    }
+
+    if (in_image_dnn_shape.IsMklTensor()) {
+      OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
+                  errors::InvalidArgument("input images must be "
+                                          "4-dimensional"));
+    } else {
+      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
+                  errors::InvalidArgument("input images must be "
+                                          "4-dimensional"));
+    }
+
+    if (out_image_dnn_shape.IsMklTensor()) {
+      OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
+                  errors::InvalidArgument("Output image must be "
+                                          "4-dimensional"));
+    } else {
+      OP_REQUIRES(context, orig_output_tensor.dims() == 4,
+                  errors::InvalidArgument("Output image must be "
+                                          "4-dimensional"));
+    }
+
+    if (workspace_dnn_shape.IsMklTensor()) {
+      OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false,
+                  errors::InvalidArgument("Workspace should not be MKL "
+                                          "Tensor."));
+    } else {
+      OP_REQUIRES(context, workspace_tensor.dims() == 1,
+                  errors::InvalidArgument("Workspace must be "
+                                          "1-dimensional"));
+    }
+  }
+
+  // Input("input_grads: T")
+  // Input("input_image: T")
+  // Input("output_image: T")
+  // Input("workspace: uint8")
+  const int kIdxGradient = 0,
+            kIdxOrigInput = 1,
+            kIdxOrigOutput = 2,
+            kIdxWorkspace = 3,
+            kIdxOutput = 0;
+
+  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
+  bool workspace_enabled_;
+  int depth_radius_;
+  float bias_;
+  float alpha_;
+  float beta_;
+};
+
+#endif  // INTEL_MKL_DNN
+
 #define REGISTER_MKL_LRN_CPU(T)           \
   REGISTER_KERNEL_BUILDER(Name("_MklLRN") \
                               .Device(DEVICE_CPU) \
@@ -729,6 +1385,7 @@ class MklLRNGradOp : public OpKernel {
                               .Label(mkl_op_registry::kMklOpLabel), \
                           MklLRNGradOp<T>);
 
+
 TF_CALL_float(REGISTER_MKL_LRN_CPU);
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d2dfe23888..8ad2c06741 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -3361,7 +3361,11 @@ REGISTER_OP("_MklLRN")
     .Input("input: T")
     .Input("mkl_input: uint8")
     .Output("output: T")
+#ifndef INTEL_MKL_DNN
     .Output("workspace: T")
+#else
+    .Output("workspace: uint8")
+#endif
     .Output("mkl_output: uint8")
     .Output("mkl_workspace: uint8")
     .Attr("depth_radius: int = 5")
@@ -3385,7 +3389,11 @@ REGISTER_OP("_MklLRNGrad")
    .Input("input_grads: T")
     .Input("input_image: T")
     .Input("output_image: T")
+#ifndef INTEL_MKL_DNN
     .Input("workspace: T")
+#else
+    .Input("workspace: uint8")
+#endif
     .Input("mkl_input_grads: uint8")
     .Input("mkl_input_image: uint8")
     .Input("mkl_output_image: uint8")
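For context beyond the patch itself, a minimal standalone sketch of the
MKL-DNN 0.x flow the new MklLRNOp drives (desc -> primitive_desc -> net ->
eager stream submit). The tensor shape, attribute values, and the use of
prop_kind::forward_scoring to skip the workspace are illustrative assumptions,
not part of this commit:

  // Sketch only: MKL-DNN 0.x C++ API, matching the calls used in
  // mkl_lrn_op.cc above. Shapes and LRN attributes are made up.
  #include <vector>
  #include "mkldnn.hpp"

  using namespace mkldnn;

  int main() {
    auto cpu_engine = engine(engine::cpu, 0);

    // NHWC input; MKL-DNN normalizes across the channel dimension, which
    // matches TensorFlow's LRN over the last dimension.
    memory::dims dims = {1, 7, 7, 96};  // N, H, W, C
    std::vector<float> src_buf(1 * 7 * 7 * 96, 1.0f);
    std::vector<float> dst_buf(src_buf.size());

    auto md = memory::desc(dims, memory::data_type::f32, memory::format::nhwc);
    auto src_mem = memory({md, cpu_engine}, src_buf.data());
    auto dst_mem = memory({md, cpu_engine}, dst_buf.data());

    // TF depth_radius = 2 maps to kernel_size 5; alpha is rescaled because
    // MKL-DNN divides alpha by the window size internally.
    int depth_radius = 2;
    int kernel_size = 2 * depth_radius + 1;
    float alpha = 0.0001f, beta = 0.75f, bias = 1.0f;
    float new_alpha = alpha * kernel_size;

    auto lrn_desc = lrn_forward::desc(prop_kind::forward_scoring,
                                      algorithm::lrn_across_channels, md,
                                      kernel_size, new_alpha, beta, bias);
    auto lrn_pd = lrn_forward::primitive_desc(lrn_desc, cpu_engine);

    // forward_scoring (inference) needs no workspace; the kernel above uses
    // prop_kind::forward and therefore also allocates a workspace tensor.
    std::vector<primitive> net;
    net.push_back(lrn_forward(lrn_pd, src_mem, dst_mem));
    stream(stream::kind::eager).submit(net).wait();
    return 0;
  }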