aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_maxpooling_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/mkl_maxpooling_op.cc')
-rw-r--r--tensorflow/core/kernels/mkl_maxpooling_op.cc14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index de4d7d2e72..82c5229bab 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -517,7 +517,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
MklDnnData<T> dnn_data_input(&cpu_engine);
MklDnnData<T> dnn_data_output(&cpu_engine);
- MklDnnData<T> dnn_data_wksp(&cpu_engine);
+ MklDnnData<uint8> dnn_data_wksp(&cpu_engine);
// initialize variables for the pooling op
MklPoolParameters pool_params;
@@ -588,16 +588,16 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
void AllocateWorkspaceTensor(OpKernelContext* context,
const pooling_forward::primitive_desc& pool_fwd_prim_desc,
- MklDnnData<T>* dnn_data_wksp) {
+ MklDnnData<uint8>* dnn_data_wksp) {
CHECK_NOTNULL(dnn_data_wksp);
Tensor* workspace_tensor = nullptr;
memory::primitive_desc workspace_pd
= pool_fwd_prim_desc.workspace_primitive_desc();
- size_t workspace_t_elems = this->GetNumTElements(workspace_pd);
+ size_t workspace_bytes = workspace_pd.get_size();
MklDnnShape workspace_mkl_shape;
workspace_mkl_shape.SetMklTensor(false);
TensorShape workspace_tf_shape;
- workspace_tf_shape.AddDim(workspace_t_elems);
+ workspace_tf_shape.AddDim(workspace_bytes);
AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
&workspace_tensor,
workspace_tf_shape, workspace_mkl_shape);
@@ -651,7 +651,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
if (!context->status().ok()) return;
MklDnnData<T> grad_dnn_data(&cpu_engine);
- MklDnnData<T> workspace_dnn_data(&cpu_engine);
+ MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
MklDnnData<T> output_dnn_data(&cpu_engine);
Tensor* output_tensor = nullptr;
MklPoolParameters pool_params;
@@ -770,7 +770,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
void ConfigureWorkspace(const Tensor& workspace_tensor,
memory::primitive_desc workspace_pd,
- MklDnnData<T> *workspace_dnn_data) {
+ MklDnnData<uint8> *workspace_dnn_data) {
CHECK_NOTNULL(workspace_dnn_data);
workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
@@ -811,7 +811,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
errors::InvalidArgument("Gradient must be "
"4-dimensional"));
}
- if (this->workspace_enabled_){
+ if (this->workspace_enabled_) {
// The workspace should not be an MKL tensor
OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false,
errors::InvalidArgument("Workspace tensor should not"