aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 12:37:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-31 12:37:51 -0700
commitb115e1905cfaca4ee4236ab5f9bfb82bf0b0b691 (patch)
tree269a204604d696db5c5d53e70f466d33c2ee9cab /tensorflow/core/util
parent886bc2c290701547ecb09f9f8d2bd8304d24254f (diff)
parent420bac9053f59c3650ce7cbc9291d90feca5c47b (diff)
Merge pull request #21690 from Intel-tensorflow:prim_reuse_disable
PiperOrigin-RevId: 211126300
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/mkl_util.h37
1 files changed, 34 insertions, 3 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 6474319370..2fa4b73e59 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <string>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -56,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/env_var.h"
#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
@@ -102,6 +104,8 @@ typedef enum {
Dim3d_I = 1
} MklDnnDims3D;
+static const int kSmallBatchSize = 32;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -2000,7 +2004,9 @@ const mkldnn::memory::dims NONE_DIMS = {};
template <typename T>
class MklPrimitiveFactory {
public:
- MklPrimitiveFactory() {}
+ MklPrimitiveFactory() {
+ }
+
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
@@ -2023,6 +2029,22 @@ class MklPrimitiveFactory {
map[key] = op;
}
+ /// Function to decide whether HW has AVX512 or AVX2
+ /// For those legacy device(w/o AVX512 and AVX2),
+ /// MKL-DNN GEMM will be used.
+ static inline bool IsLegacyPlatform() {
+ return (!port::TestCPUFeature(port::CPUFeature::AVX512F)
+ && !port::TestCPUFeature(port::CPUFeature::AVX2));
+ }
+
+ /// Fuction to check whether primitive memory optimization is enabled
+ static inline bool IsPrimitiveMemOptEnabled() {
+ bool is_primitive_mem_opt_enabled = true;
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
+ return is_primitive_mem_opt_enabled;
+ }
+
private:
static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<string, MklPrimitive*> map_;
@@ -2099,7 +2121,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -2141,7 +2163,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- private:
+ private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
@@ -2186,6 +2208,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
return *reorder_prim->GetPrimitive();
}
+// utility function to determine if it is conv 1x1 and stride != 1
+// for purpose of temporarily disabling primitive reuse
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+ if (filter_dims.size() != 4 || strides.size() != 2) return false;
+
+ return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
+ ((strides[0] != 1) || (strides[1] != 1)));
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow