aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-08-17 14:35:32 -0700
committerGravatar Guozhong Zhuang <guozhong.zhuang@intel.com>2018-08-17 14:35:32 -0700
commit2459fd5f01b2a135335b588803fd8946ea761387 (patch)
treef465cac06b3229242c91cf5830e29a21563011a3 /tensorflow/core/util
parent86d2a1ef43f883bafe9276bcba14eac2bb0cb637 (diff)
disable primitive reuse for conv2d fwd/bwd (some cases) to avoid big memory caching
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/mkl_util.h41
1 files changed, 38 insertions, 3 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 422be9356d..ee02debddf 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <string>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -33,6 +34,12 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
+// Using pragma message since #warning doesn't work with all compilers
+#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.")
+#pragma message("Please use MKL DNN (the default option for --config=mkl)")
+#endif
+
+#ifdef INTEL_MKL_ML_ONLY
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "mkl_service.h"
@@ -50,6 +57,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/env_var.h"
#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
@@ -1994,7 +2002,9 @@ const mkldnn::memory::dims NONE_DIMS = {};
template <typename T>
class MklPrimitiveFactory {
public:
- MklPrimitiveFactory() {}
+ MklPrimitiveFactory() {
+ }
+
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
@@ -2017,6 +2027,22 @@ class MklPrimitiveFactory {
map[key] = op;
}
+ /// Function to decide whether the HW has AVX512 or AVX2.
+ /// For legacy devices (without AVX512 and AVX2),
+ /// MKL-DNN GEMM will be used.
+ static inline bool IsLegacyPlatform() {
+ return (!port::TestCPUFeature(port::CPUFeature::AVX512F)
+ && !port::TestCPUFeature(port::CPUFeature::AVX2));
+ }
+
+ /// Function to check whether primitive memory optimization is enabled
+ static inline bool IsPrimitiveMemOptEnabled() {
+ bool is_primitive_mem_opt_enabled = true;
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
+ return is_primitive_mem_opt_enabled;
+ }
+
private:
static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<string, MklPrimitive*> map_;
@@ -2089,7 +2115,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -2131,7 +2157,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- private:
+ private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
@@ -2176,6 +2202,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
return *reorder_prim->GetPrimitive();
}
+// Utility function to determine whether the convolution is 1x1 with stride != 1,
+// for the purpose of temporarily disabling primitive reuse.
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+ if (filter_dims.size() != 4 || strides.size() != 2) return false;
+
+ return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
+ ((strides[0] != 1) || (strides[1] != 1)));
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow