about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
author TensorFlower Gardener <gardener@tensorflow.org> 2018-08-06 16:17:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-06 16:17:26 -0700
commit 7bceb06eae02094c7b6a5c4fed010c72962f43ce (patch)
tree 8610a0fa5f5b8cc98f4db05ce1f2e18b922e5f87 /tensorflow/core/util
parent a0c653b684afa859e7c468fea4b5e600271eec0f (diff)
parent 6fdc6be324df7e3f7e3162e161ef4e869bd888fb (diff)
Merge pull request #19403 from Intel-tensorflow:primreuse_pooling
PiperOrigin-RevId: 207625392
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- tensorflow/core/util/mkl_util.h 94
1 file changed, 59 insertions, 35 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index d90f85e422..a66b1215bd 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,9 +17,10 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
-#include <vector>
+#include <memory>
#include <unordered_map>
#include <utility>
+#include <vector>
#ifdef INTEL_MKL_ML
#include "mkl_dnn.h"
@@ -34,11 +35,11 @@ limitations under the License.
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
-
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -1503,7 +1504,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
-
+ /// Operations temp buffer
+ void* allocated_buffer_;
/// CPU engine on which operation will be executed
const engine* cpu_engine_;
@@ -1512,6 +1514,7 @@ class MklDnnData {
: user_memory_(nullptr),
reorder_memory_(nullptr),
op_md_(nullptr),
+ allocated_buffer_(nullptr),
cpu_engine_(e) {}
~MklDnnData() {
@@ -1652,6 +1655,14 @@ class MklDnnData {
user_memory_->set_data_handle(GetTensorBuffer(tensor));
}
+  /// Requests `size` bytes from cpu_allocator() and stores the returned
+  /// pointer in allocated_buffer_.
+  /// NOTE(review): any pointer already held in allocated_buffer_ is simply
+  /// overwritten here; where the buffer is released is not visible in this
+  /// hunk -- confirm ownership before calling this more than once.
+  inline void AllocateBuffer(size_t size) {
+    // 64-byte alignment, chosen for AVX512 memory alignment.
+    const int64 kMemoryAlignment = 64;
+    allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
+  }
+
+  /// Returns the pointer currently held in allocated_buffer_; this is
+  /// nullptr until AllocateBuffer() has stored one (the constructor
+  /// initialises the member to nullptr).
+  inline void* GetAllocatedBuffer() {
+    return allocated_buffer_;
+  }
+
/// Get the memory primitive for input and output of an op. If inputs
/// to an op require reorders, then this function returns memory primitive
/// for reorder. Otherwise, it will return memory primitive for user memory.
@@ -1873,7 +1884,6 @@ class MklDnnData {
net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
stream(stream::kind::eager).submit(net).wait();
}
-
};
/// Base class for operations with reuse of primitives
@@ -1956,11 +1966,25 @@ class FactoryKeyCreator {
}
};
+/// Selects the memory format to use for data with `channel` channels, based
+/// on the CPU features reported by port::TestCPUFeature:
+///   - AVX512F available and `channel` divisible by 16 -> nChw16c
+///   - AVX2 available and `channel` divisible by 8     -> nChw8c
+///   - otherwise                                       -> nchw
+static inline memory::format get_desired_format(int channel) {
+  if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) {
+    return memory::format::nChw16c;
+  }
+  if (port::TestCPUFeature(port::CPUFeature::AVX2) && (channel % 8) == 0) {
+    return memory::format::nChw8c;
+  }
+  // Neither blocked layout applies; fall back to the plain layout.
+  return memory::format::nchw;
+}
+
class MklReorderPrimitive : public MklPrimitive {
- public:
- explicit MklReorderPrimitive(const memory* from, const memory* to) {
- Setup(from, to);
- }
+ public:
+ explicit MklReorderPrimitive(const memory* from, const memory* to) {
+ Setup(from, to);
+ }
~MklReorderPrimitive() {}
std::shared_ptr<primitive> GetPrimitive() {
@@ -1972,7 +1996,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -1996,28 +2020,27 @@ class MklReorderPrimitive : public MklPrimitive {
template <typename T>
class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
- public:
- static MklReorderPrimitive* Get(const memory* from,
- const memory* to) {
- auto reorderPrim = static_cast<MklReorderPrimitive*>(
+ public:
+ static MklReorderPrimitive* Get(const memory* from, const memory* to) {
+ auto reorderPrim = static_cast<MklReorderPrimitive*>(
MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
- if (reorderPrim == nullptr) {
- reorderPrim = new MklReorderPrimitive(from, to);
- MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(
- from, to, reorderPrim);
- }
- reorderPrim->SetMemory(from, to);
- return reorderPrim;
+ if (reorderPrim == nullptr) {
+ reorderPrim = new MklReorderPrimitive(from, to);
+ MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
+ reorderPrim);
}
+ reorderPrim->SetMemory(from, to);
+ return reorderPrim;
+ }
static MklReorderPrimitiveFactory & GetInstance() {
static MklReorderPrimitiveFactory instance_;
return instance_;
}
- private:
- MklReorderPrimitiveFactory() {};
- ~MklReorderPrimitiveFactory() {};
+ private:
+ MklReorderPrimitiveFactory() {}
+ ~MklReorderPrimitiveFactory() {}
static string CreateKey(const memory* from, const memory* to) {
string prefix = "reorder";
@@ -2047,18 +2070,19 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
}
};
- /// Fuction to find(or create) a reorder from memory pointed by from to memory pointed
- /// by to, it will created primitive or get primitive from pool if it is cached.
- /// Returns the primitive.
- template <typename T>
- inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
- CHECK_NOTNULL(from);
- CHECK_NOTNULL(to);
- MklReorderPrimitive *reorder_prim =
- MklReorderPrimitiveFactory<T>::Get(from, to);
- return *reorder_prim->GetPrimitive();
- }
-
+/// Finds (or creates) a reorder primitive from the memory pointed to by
+/// `from` to the memory pointed to by `to`. The primitive is obtained through
+/// MklReorderPrimitiveFactory<T>::Get, which returns a cached primitive when
+/// one exists and otherwise creates and caches a new one.
+/// Both `from` and `to` must be non-null (enforced with CHECK_NOTNULL).
+/// Returns the primitive by value.
+template <typename T>
+inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
+  CHECK_NOTNULL(from);
+  CHECK_NOTNULL(to);
+  auto* cached_reorder = MklReorderPrimitiveFactory<T>::Get(from, to);
+  return *cached_reorder->GetPrimitive();
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow