diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_lrn_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_lrn_op.cc | 34 |
1 files changed, 21 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index a8f28202f4..95e0404ba8 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -43,7 +43,7 @@ limitations under the License. using mkldnn::lrn_forward; using mkldnn::lrn_backward; using mkldnn::prop_kind; -using mkldnn::algorithm::lrn_across_channels; +using mkldnn::lrn_across_channels; using mkldnn::stream; #endif @@ -910,17 +910,23 @@ class MklLRNOp : public OpKernel { Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth); GetBandMatrix<T>(depth, depth_radius_, &multiplier); - Tensor *output_dnn_data, *workspace; - MklDnnShape mkl_output_mkl_shape, mkl_workspace_mkl_shape; + Tensor *output_dnn_data = nullptr; + MklDnnShape mkl_output_mkl_shape; mkl_output_mkl_shape.SetMklTensor(false); mkl_output_mkl_shape.SetDimensions(4); AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data, input.shape(), mkl_output_mkl_shape); + CHECK_NOTNULL(output_dnn_data); - mkl_workspace_mkl_shape.SetMklTensor(false); - mkl_workspace_mkl_shape.SetDimensions(4); - AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace, - input.shape(), mkl_workspace_mkl_shape); + Tensor* workspace_tensor = nullptr; + MklDnnShape workspace_mkl_shape; + workspace_mkl_shape.SetMklTensor(false); + TensorShape workspace_tf_shape; + workspace_tf_shape.AddDim(0); + AllocateOutputSetMklShape(context, kIdxWorkspace, + &workspace_tensor, + workspace_tf_shape, workspace_mkl_shape); + CHECK_NOTNULL(workspace_tensor); auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth}); Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}}; @@ -1344,12 +1350,14 @@ class MklLRNGradOp : public OpKernel { errors::InvalidArgument("Output image must be 4-dimensional")); } - if (workspace_dnn_shape.IsMklTensor()) { - OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false, - errors::InvalidArgument("Workspace should not be MKL Tensor.")); - } else { - OP_REQUIRES(context, workspace_tensor.dims() == 1, - errors::InvalidArgument("Workspace must be 1-dimensional")); + if (workspace_enabled_) { + if (workspace_dnn_shape.IsMklTensor()) { + OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false, + errors::InvalidArgument("Workspace should not be MKL Tensor.")); + } else { + OP_REQUIRES(context, workspace_tensor.dims() == 1, + errors::InvalidArgument("Workspace must be 1-dimensional")); + } } } |