diff options
author | Guozhong Zhuang <guozhong.zhuang@intel.com> | 2018-08-17 14:35:32 -0700 |
---|---|---|
committer | Guozhong Zhuang <guozhong.zhuang@intel.com> | 2018-08-17 14:35:32 -0700 |
commit | 2459fd5f01b2a135335b588803fd8946ea761387 (patch) | |
tree | f465cac06b3229242c91cf5830e29a21563011a3 /tensorflow/core/util | |
parent | 86d2a1ef43f883bafe9276bcba14eac2bb0cb637 (diff) |
disable primitive reuse for conv2d fwd/bwd (some cases) to avoid big memory caching
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 41 |
1 files changed, 38 insertions, 3 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 422be9356d..ee02debddf 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ #ifdef INTEL_MKL +#include <string> #include <memory> #include <unordered_map> #include <utility> @@ -33,6 +34,12 @@ limitations under the License. #endif #ifdef INTEL_MKL_ML_ONLY +// Using pragma message since #warning doesn't work with all compilers +#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.") +#pragma message("Please use MKL DNN (the default option for --config=mkl)") +#endif + +#ifdef INTEL_MKL_ML_ONLY #include "mkl_dnn.h" #include "mkl_dnn_types.h" #include "mkl_service.h" @@ -50,6 +57,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/env_var.h" #ifndef INTEL_MKL_ML_ONLY #include "mkldnn.hpp" @@ -1994,7 +2002,9 @@ const mkldnn::memory::dims NONE_DIMS = {}; template <typename T> class MklPrimitiveFactory { public: - MklPrimitiveFactory() {} + MklPrimitiveFactory() { + } + ~MklPrimitiveFactory() {} MklPrimitive* GetOp(const string& key) { @@ -2017,6 +2027,22 @@ class MklPrimitiveFactory { map[key] = op; } + /// Function to decide whether HW has AVX512 or AVX2 + /// For those legacy devices (w/o AVX512 and AVX2), + /// MKL-DNN GEMM will be used. 
+ static inline bool IsLegacyPlatform() { + return (!port::TestCPUFeature(port::CPUFeature::AVX512F) + && !port::TestCPUFeature(port::CPUFeature::AVX2)); + } + + /// Function to check whether primitive memory optimization is enabled + static inline bool IsPrimitiveMemOptEnabled() { + bool is_primitive_mem_opt_enabled = true; + TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true, + &is_primitive_mem_opt_enabled)); + return is_primitive_mem_opt_enabled; + } + private: static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() { static thread_local std::unordered_map<string, MklPrimitive*> map_; @@ -2089,7 +2115,7 @@ class MklReorderPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(to->get_data_handle()); } - private: + private: struct ReorderContext { std::shared_ptr<mkldnn::memory> src_mem; std::shared_ptr<mkldnn::memory> dst_mem; @@ -2131,7 +2157,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> { return instance_; } - private: + private: MklReorderPrimitiveFactory() {} ~MklReorderPrimitiveFactory() {} @@ -2176,6 +2202,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) { return *reorder_prim->GetPrimitive(); } +// Utility function to determine if it is conv 1x1 and stride != 1 +// for purpose of temporarily disabling primitive reuse +inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) { + if (filter_dims.size() != 4 || strides.size() != 2) return false; + + return ((filter_dims[2] == 1) && (filter_dims[3] == 1) && + ((strides[0] != 1) || (strides[1] != 1))); + } + #endif  // INTEL_MKL_DNN } // namespace tensorflow |