aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_lrn_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_lrn_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_lrn_op.cc34
1 files changed, 21 insertions, 13 deletions
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index a8f28202f4..95e0404ba8 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -43,7 +43,7 @@ limitations under the License.
using mkldnn::lrn_forward;
using mkldnn::lrn_backward;
using mkldnn::prop_kind;
-using mkldnn::algorithm::lrn_across_channels;
+using mkldnn::lrn_across_channels;
using mkldnn::stream;
#endif
@@ -910,17 +910,23 @@ class MklLRNOp : public OpKernel {
Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
GetBandMatrix<T>(depth, depth_radius_, &multiplier);
- Tensor *output_dnn_data, *workspace;
- MklDnnShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
+ Tensor *output_dnn_data = nullptr;
+ MklDnnShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
mkl_output_mkl_shape.SetDimensions(4);
AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
input.shape(), mkl_output_mkl_shape);
+ CHECK_NOTNULL(output_dnn_data);
- mkl_workspace_mkl_shape.SetMklTensor(false);
- mkl_workspace_mkl_shape.SetDimensions(4);
- AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace,
- input.shape(), mkl_workspace_mkl_shape);
+ Tensor* workspace_tensor = nullptr;
+ MklDnnShape workspace_mkl_shape;
+ workspace_mkl_shape.SetMklTensor(false);
+ TensorShape workspace_tf_shape;
+ workspace_tf_shape.AddDim(0);
+ AllocateOutputSetMklShape(context, kIdxWorkspace,
+ &workspace_tensor,
+ workspace_tf_shape, workspace_mkl_shape);
+ CHECK_NOTNULL(workspace_tensor);
auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
@@ -1344,12 +1350,14 @@ class MklLRNGradOp : public OpKernel {
errors::InvalidArgument("Output image must be 4-dimensional"));
}
- if (workspace_dnn_shape.IsMklTensor()) {
- OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false,
- errors::InvalidArgument("Workspace should not be MKL Tensor."));
- } else {
- OP_REQUIRES(context, workspace_tensor.dims() == 1,
- errors::InvalidArgument("Workspace must be 1-dimensional"));
+ if (workspace_enabled_) {
+ if (workspace_dnn_shape.IsMklTensor()) {
+ OP_REQUIRES(context, workspace_dnn_shape.IsMklTensor() == false,
+ errors::InvalidArgument("Workspace should not be MKL Tensor."));
+ } else {
+ OP_REQUIRES(context, workspace_tensor.dims() == 1,
+ errors::InvalidArgument("Workspace must be 1-dimensional"));
+ }
}
}