diff options
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r-- | tensorflow/core/kernels/mkl_maxpooling_op.cc | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index de4d7d2e72..82c5229bab 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -517,7 +517,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { MklDnnData<T> dnn_data_input(&cpu_engine); MklDnnData<T> dnn_data_output(&cpu_engine); - MklDnnData<T> dnn_data_wksp(&cpu_engine); + MklDnnData<uint8> dnn_data_wksp(&cpu_engine); // initialize variables for the pooling op MklPoolParameters pool_params; @@ -588,16 +588,16 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { void AllocateWorkspaceTensor(OpKernelContext* context, const pooling_forward::primitive_desc& pool_fwd_prim_desc, - MklDnnData<T>* dnn_data_wksp) { + MklDnnData<uint8>* dnn_data_wksp) { CHECK_NOTNULL(dnn_data_wksp); Tensor* workspace_tensor = nullptr; memory::primitive_desc workspace_pd = pool_fwd_prim_desc.workspace_primitive_desc(); - size_t workspace_t_elems = this->GetNumTElements(workspace_pd); + size_t workspace_bytes = workspace_pd.get_size(); MklDnnShape workspace_mkl_shape; workspace_mkl_shape.SetMklTensor(false); TensorShape workspace_tf_shape; - workspace_tf_shape.AddDim(workspace_t_elems); + workspace_tf_shape.AddDim(workspace_bytes); AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace, &workspace_tensor, workspace_tf_shape, workspace_mkl_shape); @@ -651,7 +651,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { if (!context->status().ok()) return; MklDnnData<T> grad_dnn_data(&cpu_engine); - MklDnnData<T> workspace_dnn_data(&cpu_engine); + MklDnnData<uint8> workspace_dnn_data(&cpu_engine); MklDnnData<T> output_dnn_data(&cpu_engine); Tensor* output_tensor = nullptr; MklPoolParameters pool_params; @@ -770,7 +770,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { void ConfigureWorkspace(const Tensor& workspace_tensor, memory::primitive_desc workspace_pd, - MklDnnData<T> *workspace_dnn_data) { + MklDnnData<uint8> *workspace_dnn_data) { CHECK_NOTNULL(workspace_dnn_data); workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor); @@ -811,7 +811,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { errors::InvalidArgument("Gradient must be " "4-dimensional")); } - if (this->workspace_enabled_){ + if (this->workspace_enabled_) { // The workspace should not be an MKL tensor OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false, errors::InvalidArgument("Workspace tensor should not" |