aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/mkl_pooling_ops_common.h
diff options
context:
space:
mode:
authorGravatar Niranjan Hasabnis <niranjan.hasabnis@intel.com>2018-01-16 21:48:52 -0800
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-01-16 21:48:52 -0800
commitae700bb7462ee1bd4fed3c89441e962f64c89afd (patch)
treef509264b90a6ec86b39f08134e1ce3efcc2deea1 /tensorflow/core/kernels/mkl_pooling_ops_common.h
parentfa0dd4436f88425639fcd589ce3261249c6499be (diff)
Fixes for various MKLDNN unit test failures (#16059)
1. MklLayout pass changes Making workspace type uint8 for MaxPool; Handling duplicate control edge insertion 1) Handles case of inserting duplicate control edge (fixing Mkl layout graph pass unit test) 2) Enables uint8 as workspace tensor type (makes consistent with LRN workspace handling) Workspace tensor type change is also performed in MaxPool and MaxPoolGrad operators. 2. Handling MklReshape failing case MklReshape was failing on a unit test when Mkl layout and Tensorflow layout for input tensors were same, but shape of input tensor and output tensor was different. No reorder is required in such case, but reshape is needed. Before this fix, we were asserting that reorder is performed. 3. Adding support for empty input/filter tensors in Convolution backprop operators
Diffstat (limited to 'tensorflow/core/kernels/mkl_pooling_ops_common.h')
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index d33e91a15d..b974b2c59a 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -231,7 +231,7 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
const pooling_forward::primitive_desc& pool_fwd_desc,
const MklDnnData<T>* src,
MklDnnData<T>* dst,
- MklDnnData<T>* wksp = nullptr) {
+ MklDnnData<uint8>* wksp = nullptr) {
std::vector<primitive> net;
// Create pooling primitive and add it to net
@@ -307,7 +307,7 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
MklDnnData<T>* input_gradient_diff_dst,
MklDnnData<T>* output_diff_src,
const memory::primitive_desc& target_diff_dst_pd,
- const MklDnnData<T>* workspace = nullptr) {
+ const MklDnnData<uint8>* workspace = nullptr) {
std::vector<primitive> net;