aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 19:00:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 19:00:50 -0700
commitc4156ee08bed83ce54ab14a606af498dc8ebdbe6 (patch)
treea41fbe5865114bb3a1650a5173dd0e244f0896b9 /tensorflow/core/util
parentfa607e7e9224b4d88ead0a81fc65c7884d25950a (diff)
parent0fb7fcaa22c7d4167b4586c8a44f08b8830c0471 (diff)
Merge pull request #21586 from Intel-tensorflow:pooling3d
PiperOrigin-RevId: 210474549
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/mkl_util.h16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 422be9356d..0a96a603d0 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -66,7 +66,6 @@ using mkldnn::reorder;
typedef unsigned int uint;
#endif
-
namespace tensorflow {
// The file contains a number of utility classes and functions used by MKL
@@ -645,6 +644,7 @@ class MklDnnShape {
}
}
+
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
SetTfDimOrder(dimension, data_format);
@@ -2059,16 +2059,20 @@ class FactoryKeyCreator {
}
};
-static inline memory::format get_desired_format(int channel) {
+
+static inline memory::format get_desired_format(int channel,
+ bool is_2d = true) {
memory::format fmt_desired = memory::format::any;
- if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) {
- fmt_desired = memory::format::nChw16c;
+ if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
+ fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = memory::format::nChw8c;
+ fmt_desired = is_2d
+ ? memory::format::nChw8c
+ : memory::format::ncdhw; //not support avx2 for 3d yet.
} else {
- fmt_desired = memory::format::nchw;
+ fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
return fmt_desired;
}