diff options
author | Niranjan Hasabnis <niranjan.hasabnis@intel.com> | 2018-01-16 21:48:52 -0800 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-01-16 21:48:52 -0800 |
commit | ae700bb7462ee1bd4fed3c89441e962f64c89afd (patch) | |
tree | f509264b90a6ec86b39f08134e1ce3efcc2deea1 /tensorflow/core/kernels/mkl_pooling_ops_common.h | |
parent | fa0dd4436f88425639fcd589ce3261249c6499be (diff) |
Fixes for various MKLDNN unit test failures (#16059)
1. MklLayout pass changes
Making workspace type uint8 for MaxPool; Handling duplicate control edge insertion
1) Handles case of inserting duplicate control edge (fixing Mkl layout graph
pass unit test)
2) Enables uint8 as workspace tensor type (makes consistent with LRN workspace
handling)
Workspace tensor type change is also performed in MaxPool and MaxPoolGrad
operators.
2. Handling MklReshape failing case
MklReshape was failing on a unit test when Mkl layout and Tensorflow layout for
input tensors were same, but shape of input tensor and output tensor was
different. No reorder is required in such case, but reshape is needed. Before
this fix, we were asserting that reorder is performed.
3. Adding support for empty input/filter tensors in Convolution backprop operators
Diffstat (limited to 'tensorflow/core/kernels/mkl_pooling_ops_common.h')
-rw-r--r-- | tensorflow/core/kernels/mkl_pooling_ops_common.h | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index d33e91a15d..b974b2c59a 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -231,7 +231,7 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> { const pooling_forward::primitive_desc& pool_fwd_desc, const MklDnnData<T>* src, MklDnnData<T>* dst, - MklDnnData<T>* wksp = nullptr) { + MklDnnData<uint8>* wksp = nullptr) { std::vector<primitive> net; // Create pooling primitive and add it to net @@ -307,7 +307,7 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> { MklDnnData<T>* input_gradient_diff_dst, MklDnnData<T>* output_diff_src, const memory::primitive_desc& target_diff_dst_pd, - const MklDnnData<T>* workspace = nullptr) { + const MklDnnData<uint8>* workspace = nullptr) { std::vector<primitive> net; |