diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 19:00:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 19:00:50 -0700 |
commit | c4156ee08bed83ce54ab14a606af498dc8ebdbe6 (patch) | |
tree | a41fbe5865114bb3a1650a5173dd0e244f0896b9 /tensorflow/core/util | |
parent | fa607e7e9224b4d88ead0a81fc65c7884d25950a (diff) | |
parent | 0fb7fcaa22c7d4167b4586c8a44f08b8830c0471 (diff) |
Merge pull request #21586 from Intel-tensorflow:pooling3d
PiperOrigin-RevId: 210474549
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 422be9356d..0a96a603d0 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -66,7 +66,6 @@ using mkldnn::reorder; typedef unsigned int uint; #endif - namespace tensorflow { // The file contains a number of utility classes and functions used by MKL @@ -645,6 +644,7 @@ class MklDnnShape { } } + inline void SetTfDimOrder(const size_t dimension, memory::format format) { TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format); SetTfDimOrder(dimension, data_format); @@ -2059,16 +2059,20 @@ class FactoryKeyCreator { } }; -static inline memory::format get_desired_format(int channel) { + +static inline memory::format get_desired_format(int channel, + bool is_2d = true) { memory::format fmt_desired = memory::format::any; - if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) { - fmt_desired = memory::format::nChw16c; + if (port::TestCPUFeature(port::CPUFeature::AVX512F)) { + fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c; } else if (port::TestCPUFeature(port::CPUFeature::AVX2) && (channel % 8) == 0) { - fmt_desired = memory::format::nChw8c; + fmt_desired = is_2d + ? memory::format::nChw8c + : memory::format::ncdhw; //not support avx2 for 3d yet. } else { - fmt_desired = memory::format::nchw; + fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw; } return fmt_desired; } |