Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
 tensorflow/core/util/mkl_util.h | 691
 1 file changed, 624 insertions(+), 67 deletions(-)
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 1bfa4f83a3..118ff0d0d6 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -26,18 +26,23 @@ limitations under the License.
#include "mkl_trans.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
-
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
+
+using mkldnn::memory;
+using mkldnn::reorder;
+using mkldnn::primitive;
+using mkldnn::padding_kind;
+using mkldnn::engine;
#endif
// The file contains a number of utility classes and functions used by MKL
@@ -51,6 +56,8 @@ namespace tensorflow {
// Tensorflow tensor.
typedef enum { W = 0, H = 1, C = 2, N = 3 } MklDims;
+typedef enum { Dim_N = 0, Dim_C = 1, Dim_H = 2, Dim_W = 3,
+ Dim_O = 0, Dim_I = 1 } MklDnnDims;
class MklShape {
public:
@@ -143,7 +150,9 @@ class MklShape {
size_t GetDimension() const { return dimension_; }
const size_t* GetSizes() const { return sizes_; }
int64 dim_size(int index) const { return sizes_[index]; }
- int64 tf_dim_size(int index) const { return sizes_[tf_to_mkl_dim_map_[index]]; }
+ int64 tf_dim_size(int index) const {
+ return sizes_[tf_to_mkl_dim_map_[index]];
+ }
const size_t* GetStrides() const { return strides_; }
const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
@@ -227,7 +236,8 @@ class MklShape {
(IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_
// Location of sizes. Note dim is not used here, left here
// to make macros consistent.
-#define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
+#define SIZES_OFFSET(dims) \
+ (DIMS_OFFSET + sizeof(size_t))
#define STRIDES_OFFSET(dims) \
(SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides
#define MKL_LAYOUT_OFFSET(dims) \
@@ -309,6 +319,266 @@ class MklShape {
nullptr; // TF dimension corresponding to this MKL dimension
};
+#ifdef INTEL_MKL_DNN
+
+// Forward decl
+TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
+
+class MklDnnShape {
+ private:
+ typedef struct {
+ /// Flag to indicate if the tensor is an MKL tensor or not
+ bool is_mkl_tensor_ = false;
+ /// Number of dimensions in Tensorflow format
+ size_t dimension_ = 0;
+    /// Sizes of the tensor in each dimension; required by MKL-DNN for
+    /// conversions
+    mkldnn_dims_t sizes_;
+    /// TensorFlow data format of the tensor
+    memory::format tf_data_format_ = memory::format::format_undef;
+    /// Data type of the tensor elements
+    memory::data_type T_ = memory::data_type::data_undef;
+ // MKL layout
+ mkldnn_memory_desc_t mkl_md_;
+ /// TF dimension corresponding to this MKL dimension
+ mkldnn_dims_t map_;
+ } MklShapeData;
+ MklShapeData data_;
+
+ typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
+#define INVALID_DIM_SIZE -1
+
+ public:
+ MklDnnShape() {
+    for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
+         ++i) {
+      data_.sizes_[i] = INVALID_DIM_SIZE;
+    }
+    for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
+      data_.map_[i] = INVALID_DIM_SIZE;
+    }
+ }
+
+ ~MklDnnShape() {}
+ TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape); // Cannot copy
+
+  inline bool IsMklTensor() const { return data_.is_mkl_tensor_; }
+ inline void SetMklTensor(bool is_mkl_tensor) {
+ data_.is_mkl_tensor_ = is_mkl_tensor;
+ }
+
+ inline void SetDimensions(const size_t dimension) {
+ data_.dimension_ = dimension;
+ }
+  inline size_t GetDimension(char dimension) const {
+ int index = GetMklDnnTensorDimIndex(dimension);
+ CHECK(index >= 0 && index < this->GetDimension())
+ << "Invalid index from the dimension: " << index << ", " << dimension;
+ return this->DimSize(index);
+ }
+
+  inline int32 GetMklDnnTensorDimIndex(char dimension) const {
+ switch (dimension) {
+ case 'N':
+ return MklDnnDims::Dim_N;
+ case 'C':
+ return MklDnnDims::Dim_C;
+ case 'H':
+ return MklDnnDims::Dim_H;
+ case 'W':
+ return MklDnnDims::Dim_W;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ }
+
+ inline size_t GetDimension() const { return data_.dimension_; }
+ inline const int* GetSizes() const {
+ return reinterpret_cast<const int*>(&data_.sizes_[0]);
+ }
+
+ // Returns an mkldnn::memory::dims object that contains the sizes of this
+ // MklDnnShape object.
+ inline memory::dims GetSizesAsMklDnnDims() const {
+ memory::dims retVal;
+ if (data_.is_mkl_tensor_) {
+      int dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
+      for (int i = 0; i < dimensions; i++) {
+        if (data_.sizes_[i] != INVALID_DIM_SIZE)
+          retVal.push_back(data_.sizes_[i]);
+      }
+ } else {
+ CHECK_EQ(data_.is_mkl_tensor_, true);
+ }
+ return retVal;
+ }
+
+ inline int64 DimSize(int index) const {
+ CHECK_LT(index, sizeof(data_.sizes_)/sizeof(data_.sizes_[0]));
+ return data_.sizes_[index];
+ }
+
+ /// Return TensorShape that describes the Tensorflow shape of the tensor
+ /// represented by this MklShape.
+ inline TensorShape GetTfShape() {
+ CHECK_EQ(data_.is_mkl_tensor_, true);
+
+ std::vector<int32> shape(data_.dimension_, -1);
+ for (size_t idx = 0; idx < data_.dimension_; ++idx) {
+ shape[idx] = data_.sizes_[TfDimIdx(idx)];
+ }
+
+ TensorShape ts;
+ bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
+ CHECK_EQ(ret, true);
+ return ts;
+ }
+
+ inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
+  inline memory::data_type GetElemType() const { return data_.T_; }
+
+ inline void SetMklLayout(memory::primitive_desc* pd) {
+ CHECK_NOTNULL(pd);
+ data_.mkl_md_ = pd->desc().data;
+ }
+ inline const memory::desc GetMklLayout() const {
+ return memory::desc(data_.mkl_md_);
+ }
+
+ inline memory::format GetTfDataFormat() const {
+ return data_.tf_data_format_;
+ }
+ /// We don't create primitive_descriptor for TensorFlow layout now.
+ /// We use lazy evaluation and create it only when needed.
+ inline void SetTfLayout(size_t dims, const memory::dims& sizes,
+ memory::format format) {
+ CHECK_EQ(dims, sizes.size());
+ data_.dimension_ = dims;
+ for (size_t ii = 0; ii < dims; ii++) {
+ data_.sizes_[ii] = sizes[ii];
+ }
+ data_.tf_data_format_ = format;
+ SetTfDimOrder(dims, format);
+ }
+ inline const memory::desc GetTfLayout() const {
+ memory::dims dims;
+ for (size_t ii = 0; ii < data_.dimension_; ii++) {
+ dims.push_back(data_.sizes_[ii]);
+ }
+ return memory::desc(dims, data_.T_, data_.tf_data_format_);
+ }
+ inline const memory::desc GetCurLayout() const {
+ return IsMklTensor() ? GetMklLayout() : GetTfLayout();
+ }
+
+  // Note: The MKL-ML version of SetTfDimOrder that set a default dimension
+  // order has been removed. A default order is not needed because an operator
+  // that does not get the data_format attribute receives all its inputs in
+  // Tensorflow format and will produce its output in Tensorflow format.
+ inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
+ CHECK(dimension == data_.dimension_);
+ for (size_t ii = 0; ii < dimension; ii++) {
+ data_.map_[ii] = map[ii];
+ }
+ }
+
+ inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
+ // TODO(nhasabni): Why do we restrict this to 4D?
+ CHECK_EQ(dimension, 4);
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ }
+
+ inline void SetTfDimOrder(const size_t dimension, memory::format format) {
+ TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
+ SetTfDimOrder(dimension, data_format);
+ }
+
+ inline const mkldnn_dim_t* GetTfToMklDimMap() const {
+ return &data_.map_[0];
+ }
+ inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
+ inline int64 TfDimSize(int index) const {
+ return data_.sizes_[TfDimIdx(index)];
+ }
+
+ /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ /// corresponds to MKL's Channel dimension.
+ inline bool IsMklChannelDim(int d) const {
+ return TfDimIdx(d) == MklDnnDims::Dim_C;
+ }
+ /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ /// corresponds to MKL's Batch dimension.
+ inline bool IsMklBatchDim(int d) const {
+ return TfDimIdx(d) == MklDnnDims::Dim_N;
+ }
+ /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ /// corresponds to MKL's Width dimension.
+ inline bool IsMklWidthDim(int d) const {
+ return TfDimIdx(d) == MklDnnDims::Dim_W;
+ }
+ /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
+ /// corresponds to MKL's Height dimension.
+ inline bool IsMklHeightDim(int d) const {
+ return TfDimIdx(d) == MklDnnDims::Dim_H;
+ }
+
+ /// Check if the TF-Mkl dimension ordering map specifies if the input
+ /// tensor is in NCHW format.
+ inline bool IsTensorInNCHWFormat() const {
+ TensorFormat data_format = FORMAT_NCHW;
+ return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
+ IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
+ IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
+ IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
+ }
+
+ /// Check if the TF-Mkl dimension ordering map specifies if the input
+ /// tensor is in NHWC format.
+ inline bool IsTensorInNHWCFormat() const {
+ TensorFormat data_format = FORMAT_NHWC;
+ return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
+ IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
+ IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
+ IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
+ }
+
+  /// The following methods are used for serializing and de-serializing the
+  /// contents of the MklDnnShape object.
+  /// The data is serialized in this order:
+  /// is_mkl_tensor_ : dimension_ : sizes_ : map_ : tf_data_format_ : T_ :
+  /// mkl_md_
+
+  /// Size of the buffer required to hold the serialized object; computed by
+  /// following the order above.
+ inline size_t GetSerializeBufferSize() const {
+ return sizeof(MklShapeData);
+ }
+
+ void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
+ CHECK(buf_size >= GetSerializeBufferSize())
+ << "Buffer size is too small to SerializeMklDnnShape";
+ *reinterpret_cast<MklShapeData*>(buf) = data_;
+ }
+
+ void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
+ // Make sure buffer holds at least is_mkl_tensor_.
+ CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
+ << "Buffer size is too small in DeSerializeMklDnnShape";
+
+ const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
+ if (is_mkl_tensor) { // If it is an MKL Tensor then read the rest
+ CHECK(buf_size >= GetSerializeBufferSize())
+ << "Buffer size is too small in DeSerializeMklDnnShape";
+ data_ = *reinterpret_cast<const MklShapeData*>(buf);
+ }
+ }
+};
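+// Usage sketch (illustrative; a 4D float tensor in NCHW layout is assumed):
+//   MklDnnShape shape;
+//   shape.SetMklTensor(true);
+//   shape.SetElemType(MklDnnType<float>());
+//   shape.SetTfLayout(4, memory::dims({8, 3, 32, 32}), memory::format::nchw);
+//   std::vector<unsigned char> buf(shape.GetSerializeBufferSize());
+//   shape.SerializeMklDnnShape(buf.data(), buf.size());
+//   MklDnnShape restored;
+//   restored.DeSerializeMklDnnShape(buf.data(), buf.size());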
+
+#endif
+
// List of MklShape objects. Used in Concat/Split layers.
typedef std::vector<MklShape> MklShapeList;
@@ -347,6 +617,36 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
return output_tensor;
}
+#ifdef INTEL_MKL_DNN
+template <typename T>
+inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
+ const MklDnnShape& mkl_shape) {
+ Tensor output_tensor;
+ TensorShape output_shape;
+
+#if 0
+ // TODO(nhasabni): need to implement
+ for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
+ // Outermost to innermost dimension
+ output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
+ }
+
+ // Allocate output tensor.
+ context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
+
+ dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
+ void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
+ void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
+
+ if (mkl_tensor.NumElements() != 0) {
+ mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
+ }
+#endif
+
+ return output_tensor;
+}
+#endif
+
// Get the MKL shape from the second string tensor
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape(
@@ -359,6 +659,20 @@ inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
sizeof(uint8));
}
+#ifdef INTEL_MKL_DNN
+inline void GetMklShape(OpKernelContext* ctext, int n,
+ MklDnnShape* mklshape) {
+ mklshape->DeSerializeMklDnnShape(
+ ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+ .flat<uint8>()
+ .data(),
+ ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
+ .flat<uint8>()
+ .size() *
+ sizeof(uint8));
+}
+#endif
+
// Gets the actual input
inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
@@ -382,6 +696,27 @@ inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
}
}
+#ifdef INTEL_MKL_DNN
+/// Get the shape of the input tensor pointed to by 'input_idx' as a
+/// TensorShape. If the input tensor is in MKL layout, then obtain the
+/// TensorShape from the MklDnnShape.
+inline TensorShape GetTfShape(OpKernelContext* context,
+ size_t input_idx) {
+ // Sanity check.
+ CHECK_NOTNULL(context);
+ CHECK_LT(input_idx, context->num_inputs());
+
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, input_idx, &input_mkl_shape);
+ if (input_mkl_shape.IsMklTensor()) {
+ return input_mkl_shape.GetTfShape();
+ } else {
+ const Tensor& t = MklGetInput(context, input_idx);
+ return t.shape();
+ }
+}
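+// Usage sketch (illustrative): inside an op's Compute(),
+//   TensorShape s = GetTfShape(context, 0);
+// yields the Tensorflow shape of input 0 regardless of its layout.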
+#endif
+
// Allocate the second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -397,6 +732,23 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
+#ifdef INTEL_MKL_DNN
+// Allocate the second output tensor that will contain
+// the MKL shape serialized
+inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
+ const MklDnnShape& mkl_shape) {
+ Tensor* second_tensor = nullptr;
+ TensorShape second_shape;
+ second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
+ OP_REQUIRES_OK(ctext, ctext->allocate_output(
+ GetTensorMetaDataIndex(n, ctext->num_outputs()),
+ second_shape, &second_tensor));
+ mkl_shape.SerializeMklDnnShape(
+ second_tensor->flat<uint8>().data(),
+ second_tensor->flat<uint8>().size() * sizeof(uint8));
+}
+#endif
+
// Allocate the output tensor, create a second output tensor that will contain
// the MKL shape serialized
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
@@ -417,9 +769,43 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
second_tensor->flat<uint8>().size() * sizeof(uint8));
}
+#ifdef INTEL_MKL_DNN
+// Allocate the output tensor, create a second output tensor that will contain
+// the MKL shape serialized
+inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
+ Tensor** output,
+ const TensorShape& tf_shape,
+ const MklDnnShape& mkl_shape) {
+ Tensor* second_tensor = nullptr;
+ TensorShape second_shape;
+ second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
+ OP_REQUIRES_OK(
+ ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
+ tf_shape, output));
+ OP_REQUIRES_OK(ctext, ctext->allocate_output(
+ GetTensorMetaDataIndex(n, ctext->num_outputs()),
+ second_shape, &second_tensor));
+ mkl_shape.SerializeMklDnnShape(
+ second_tensor->flat<uint8>().data(),
+ second_tensor->flat<uint8>().size() * sizeof(uint8));
+}
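+// Usage sketch (illustrative; tf_shape_out and mkl_shape_out are assumed to
+// have been set up for the op's output):
+//   Tensor* output = nullptr;
+//   AllocateOutputSetMklShape(ctext, 0, &output, tf_shape_out, mkl_shape_out);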
+#endif
+
// Allocates a temp tensor and returns the data buffer for temporary storage.
-// Currently
-// we only support F32, will need to templatize if other types are added
+#ifdef INTEL_MKL_DNN
+template <typename T>
+inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
+ const memory::primitive_desc& pd, void** buf_out) {
+ TensorShape tf_shape;
+
+ tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
+ OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
+ tf_shape, tensor_out));
+ *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
+}
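+// Usage sketch (illustrative; 'pd' is assumed to be the primitive_desc of the
+// reorder target):
+//   Tensor tmp;
+//   void* buf = nullptr;
+//   AllocTmpBuffer<float>(context, &tmp, pd, &buf);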
+#endif
+
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
dnnLayout_t lt_buff, void** buf_out) {
TensorShape tf_shape;
@@ -435,7 +821,7 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
- TensorShape tf_shape) {
+ TensorShape tf_shape) {
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
tf_shape, tensor_out));
}
@@ -669,6 +1055,8 @@ inline bool MklCompareShapes(const TensorShape* input_shape_0,
return true;
}
+// These functions do not compile with the MKL-DNN build because mkl.h is not
+// included. We may need to remove them later.
// TODO(intel_tf): Remove this routine when faster MKL layout conversion is
// out.
inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
@@ -707,18 +1095,11 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
#ifdef INTEL_MKL_DNN
-using mkldnn::engine;
-using mkldnn::memory;
-using mkldnn::padding_kind;
-using mkldnn::primitive;
-using mkldnn::reorder;
-
/// Return MKL-DNN data type (memory::data_type) for input type T
///
/// @input None
/// @return memory::data_type corresponding to type T
-template <typename T>
-static memory::data_type MklDnnType();
+template<typename T> static memory::data_type MklDnnType();
/// Instantiation for float type. Add similar instantiations for other
/// type if needed.
@@ -733,15 +1114,26 @@ memory::data_type MklDnnType<float>() {
/// @return: memory::format corresponding to TensorFlow data format;
/// Fails with an error if invalid data format.
inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
- if (format == FORMAT_NHWC)
- return memory::format::nhwc;
- else if (format == FORMAT_NCHW)
- return memory::format::nchw;
- TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ if (format == FORMAT_NHWC) return memory::format::nhwc;
+ else if (format == FORMAT_NCHW) return memory::format::nchw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT,
+ "Unsupported data format"));
// Return to get rid of compiler warning
return memory::format::format_undef;
}
+/// Map MKL-DNN data format to TensorFlow's data format
+///
+/// @input: memory::format
+/// @return: Tensorflow data format corresponding to memory::format
+/// Fails with an error if invalid data format.
+inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
+  if (format == memory::format::nhwc) return FORMAT_NHWC;
+  else if (format == memory::format::nchw) return FORMAT_NCHW;
+  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT,
+              "Unsupported data format"));
+  // Return to get rid of compiler warning; TF_CHECK_OK above fails first.
+  return FORMAT_NHWC;
+}
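+// Sanity sketch (illustrative): the two mappings are inverses of each other:
+//   CHECK_EQ(MklDnnDataFormatToTFDataFormat(
+//                TFDataFormatToMklDnnDataFormat(FORMAT_NHWC)),
+//            FORMAT_NHWC);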
+
/// Map TensorShape object into memory::dims required by MKL-DNN
///
/// This function will simply map input TensorShape into MKL-DNN dims
@@ -753,7 +1145,7 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return memory::dims corresponding to TensorShape
inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
memory::dims dims(shape.dims());
- for (unsigned int d = 0; d < shape.dims(); ++d) {
+ for (int d = 0; d < shape.dims(); ++d) {
dims[d] = shape.dim_size(d);
}
return dims;
@@ -769,7 +1161,7 @@ inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
/// @input TensorShape object in shape
/// @return memory::dims in MKL-DNN required NCHW format
inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
- TensorFormat format) {
+ TensorFormat format) {
// Check validity of format.
CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
memory::format::format_undef);
@@ -783,6 +1175,43 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
return memory::dims({n, c, h, w});
}
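+// For example (illustrative): a TF NHWC shape {8, 32, 32, 3} maps to the
+// MKL-DNN NCHW dims {8, 3, 32, 32}:
+//   memory::dims nchw =
+//       TFShapeToMklDnnDimsInNCHW(TensorShape({8, 32, 32, 3}), FORMAT_NHWC);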
+/// Map MklDnn memory::dims object into TensorShape object.
+///
+/// This function will simply map the input shape in MKL-DNN memory::dims
+/// format into Tensorflow's TensorShape object by preserving dimension order.
+///
+/// @input MKL-DNN memory::dims object
+/// @output TensorShape corresponding to memory::dims
+inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
+ std::vector<int32> shape(dims.size(), -1);
+ for (int d = 0; d < dims.size(); d++) {
+ shape[d] = dims[d];
+ }
+
+ TensorShape ret;
+ CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
+ return ret;
+}
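+// For example (illustrative): the inverse mapping back to a TensorShape:
+//   TensorShape ts = MklDnnDimsToTFShape(memory::dims({8, 3, 32, 32}));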
+
+/// Function to calculate strides given tensor shape in Tensorflow order
+/// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
+/// the dimension with size 1 is the outermost dimension and the dimension
+/// with size 4 is the innermost. So strides for this tensor would be
+/// {4 * 3 * 2, 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
+///
+/// @input Tensorflow shape in memory::dims type
+/// @return memory::dims containing strides for the tensor.
+inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
+ CHECK_GT(dims_tf_order.size(), 0);
+ memory::dims strides(dims_tf_order.size());
+ int last_dim_idx = dims_tf_order.size() - 1;
+ strides[last_dim_idx] = 1;
+ for (int d = last_dim_idx - 1; d >= 0; d--) {
+ strides[d] = strides[d + 1] * dims_tf_order[d + 1];
+ }
+ return strides;
+}
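+// For example (illustrative), following the comment above:
+//   memory::dims strides = CalculateTFStrides(memory::dims({1, 2, 3, 4}));
+//   // strides == {24, 12, 4, 1}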
+
inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
// MKL-DNN only supports zero padding.
return padding_kind::zero;
@@ -808,23 +1237,21 @@ class MklDnnData {
const engine* cpu_engine_;
public:
- explicit MklDnnData(const engine* e)
- : user_memory_(nullptr),
- reorder_memory_(nullptr),
- op_md_(nullptr),
- cpu_engine_(e) {}
+ explicit MklDnnData(const engine* e) : user_memory_(nullptr),
+ reorder_memory_(nullptr),
+ op_md_(nullptr), cpu_engine_(e) {}
~MklDnnData() {
cpu_engine_ = nullptr; // We don't own this.
- delete (user_memory_);
- delete (reorder_memory_);
- delete (op_md_);
+ delete(user_memory_);
+ delete(reorder_memory_);
+ delete(op_md_);
}
- void* GetTensorBuffer(const Tensor* tensor) {
+ inline void* GetTensorBuffer(const Tensor* tensor) const {
CHECK_NOTNULL(tensor);
- return const_cast<void*>(
- static_cast<const void*>(tensor->flat<T>().data()));
+ return const_cast<void*>(static_cast<const void*>(
+ tensor->flat<T>().data()));
}
/// Set user memory primitive using specified dimensions, memory format and
@@ -835,35 +1262,83 @@ class MklDnnData {
/// an operation. E.g., filter of Conv2D is of shape {1, 2, 3, 4}, and
/// memory format HWIO, and the buffer that contains actual values is
/// pointed by data_buffer.
- void SetUsrMem(memory::dims dim, memory::format fm, void* data_buffer) {
- CHECK_NOTNULL(data_buffer);
- CHECK_NOTNULL(cpu_engine_);
- // TODO(nhasabni): can we remove dynamic memory allocation?
- user_memory_ =
- new memory(memory::primitive_desc(
- memory::desc(dim, MklDnnType<T>(), fm), *cpu_engine_),
- data_buffer);
+ inline void SetUsrMem(const memory::dims& dim, memory::format fm,
+ void* data_buffer = nullptr) {
+ auto md = memory::desc(dim, MklDnnType<T>(), fm);
+ SetUsrMem(md, data_buffer);
}
- void SetUsrMem(memory::dims dim, memory::format fm, const Tensor* tensor) {
+ inline void SetUsrMem(const memory::dims& dim, memory::format fm,
+ const Tensor* tensor) {
CHECK_NOTNULL(tensor);
SetUsrMem(dim, fm, GetTensorBuffer(tensor));
}
+ /// Helper function to create memory descriptor in Blocked format
+ ///
+ /// @input: Tensor dimensions
+ /// @input: strides corresponding to dimensions. One can use utility
+ /// function such as CalculateTFStrides to compute strides
+ /// for given dimensions.
+ /// @return: memory::desc object corresponding to blocked memory format
+ /// for given dimensions and strides.
+ static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
+ const memory::dims& strides) {
+ CHECK_EQ(dim.size(), strides.size());
+
+ // We have to construct memory descriptor in a C style. This is not at all
+ // ideal but MKLDNN does not offer any API to construct descriptor in
+ // blocked format except a copy constructor that accepts
+ // mkldnn_memory_desc_t.
+ mkldnn_memory_desc_t md;
+ md.primitive_kind = mkldnn_memory;
+ md.ndims = dim.size();
+ md.format = mkldnn_blocked;
+ md.data_type = memory::convert_to_c(MklDnnType<T>());
+
+ for (size_t i = 0; i < dim.size(); i++) {
+ md.layout_desc.blocking.block_dims[i] = 1;
+ md.layout_desc.blocking.strides[1][i] = 1;
+ md.layout_desc.blocking.strides[0][i] = strides[i];
+ md.layout_desc.blocking.padding_dims[i] = dim[i];
+ md.layout_desc.blocking.offset_padding_to_data[i] = 0;
+ md.dims[i] = dim[i];
+ }
+ md.layout_desc.blocking.offset_padding = 0;
+
+ return memory::desc(md);
+ }
+
+  /// A version of SetUsrMem that allows the user to create memory in blocked
+  /// format. So in addition to accepting dimensions, it also accepts strides.
+  /// This allows the user to create memory for a tensor in a format that
+  /// MKLDNN does not support natively. E.g., MKLDNN does not provide a native
+  /// format for 6-dimensional tensors, but by using the blocked format a user
+  /// can create memory for a 6D tensor.
+ inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
+ void* data_buffer = nullptr) {
+ CHECK_EQ(dim.size(), strides.size());
+ auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
+ SetUsrMem(blocked_md, data_buffer);
+ }
+
+ inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
+ const Tensor* tensor) {
+ CHECK_NOTNULL(tensor);
+ SetUsrMem(dim, strides, GetTensorBuffer(tensor));
+ }
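+  // Usage sketch (illustrative; 'data' is an MklDnnData<float> and 'buffer'
+  // points to a large enough allocation): user memory for a 6D tensor via the
+  // blocked format:
+  //   memory::dims dims6 = {2, 3, 4, 5, 6, 7};
+  //   data.SetUsrMem(dims6, CalculateTFStrides(dims6), buffer);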
+
/// A version of function to set user memory primitive that accepts memory
/// descriptor directly, instead of accepting dimensions and format. This
  /// function is more generic than the one above, but the function above is
/// sufficient in most cases.
- void SetUsrMem(memory::desc md, void* data_buffer) {
- CHECK_NOTNULL(data_buffer);
- CHECK_NOTNULL(cpu_engine_);
- // TODO(nhasabni): can we remove dynamic memory allocation?
- user_memory_ =
- new memory(memory::primitive_desc(md, *cpu_engine_), data_buffer);
+ inline void SetUsrMem(const memory::desc& md, void* data_buffer = nullptr) {
+ auto pd = memory::primitive_desc(md, *cpu_engine_);
+ SetUsrMem(pd, data_buffer);
}
/// A version of SetUsrMem with memory descriptor and tensor
- void SetUsrMem(memory::desc md, const Tensor* tensor) {
+ inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
CHECK_NOTNULL(tensor);
SetUsrMem(md, GetTensorBuffer(tensor));
}
@@ -872,41 +1347,60 @@ class MklDnnData {
/// descriptor directly, instead of accepting dimensions and format. This
  /// function is more generic than the one above, but the function above is
/// sufficient in most cases.
- void SetUsrMem(memory::primitive_desc pd, void* data_buffer) {
- CHECK_NOTNULL(data_buffer);
+ inline void SetUsrMem(const memory::primitive_desc& pd,
+ void* data_buffer = nullptr) {
CHECK_NOTNULL(cpu_engine_);
// TODO(nhasabni): can we remove dynamic memory allocation?
- user_memory_ = new memory(pd, data_buffer);
+ if (data_buffer) {
+ user_memory_ = new memory(pd, data_buffer);
+ } else {
+ user_memory_ = new memory(pd);
+ }
}
/// A version of SetUsrMem with primitive descriptor and tensor
- void SetUsrMem(memory::primitive_desc pd, const Tensor* tensor) {
+ inline void SetUsrMem(const memory::primitive_desc& pd,
+ const Tensor* tensor) {
CHECK_NOTNULL(tensor);
SetUsrMem(pd, GetTensorBuffer(tensor));
}
/// Get function for user memory primitive.
- const memory* GetUsrMem() const { return user_memory_; }
+ inline const memory* GetUsrMem() const { return user_memory_; }
/// Get function for primitive descriptor of user memory primitive.
- const memory::primitive_desc GetUsrMemPrimDesc() const {
+ inline const memory::primitive_desc GetUsrMemPrimDesc() const {
CHECK_NOTNULL(user_memory_);
return user_memory_->get_primitive_desc();
}
/// Get function for descriptor of user memory.
- memory::desc GetUsrMemDesc() {
+ inline memory::desc GetUsrMemDesc() {
    // This is ugly. Why does MKL-DNN not provide a desc() method of const
    // type?
const memory::primitive_desc pd = GetUsrMemPrimDesc();
return const_cast<memory::primitive_desc*>(&pd)->desc();
}
/// Get function for data buffer of user memory primitive.
- void* GetUsrMemDataHandle() const {
+ inline void* GetUsrMemDataHandle() const {
CHECK_NOTNULL(user_memory_);
return user_memory_->get_data_handle();
}
+ /// Set function for data buffer of user memory primitive.
+  inline void SetUsrMemDataHandle(void* data_buffer) {
+    CHECK_NOTNULL(user_memory_);
+    CHECK_NOTNULL(data_buffer);
+    user_memory_->set_data_handle(data_buffer);
+  }
+
+ /// Set function for data buffer of user memory primitive.
+ inline void SetUsrMemDataHandle(const Tensor* tensor) {
+ CHECK_NOTNULL(user_memory_);
+ CHECK_NOTNULL(tensor);
+ user_memory_->set_data_handle(GetTensorBuffer(tensor));
+ }
+
/// Get the memory primitive for input and output of an op. If inputs
/// to an op require reorders, then this function returns memory primitive
/// for reorder. Otherwise, it will return memory primitive for user memory.
@@ -915,7 +1409,7 @@ class MklDnnData {
  /// execute Conv2D, we need memory primitives for I and F. But if reorder is
/// required for I and F (say I_r is reorder primitive for I; F_r is reorder
/// primitive for F), then we need I_r and F_r to perform Conv2D.
- const memory& GetOpMem() const {
+ inline const memory& GetOpMem() const {
return reorder_memory_ ? *reorder_memory_ : *user_memory_;
}
@@ -923,13 +1417,32 @@ class MklDnnData {
/// format. E.g., For Conv2D, the dimensions would be same as user dimensions
/// but memory::format would be mkldnn::any because we want MKL-DNN to choose
/// best layout/format for given input dimensions.
- void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
+ inline void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
// TODO(nhasabni): can we remove dynamic memory allocation?
op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
}
/// Get function for memory descriptor for an operation
- const memory::desc& GetOpMemDesc() const { return *op_md_; }
+ inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
+
+ /// Predicate that checks if we need to reorder user's memory into memory
+ /// pointed by op_pd.
+ ///
+ /// @input: op_pd - memory primitive descriptor of the given input of an
+ /// operation
+ /// @return: true in case reorder of input is needed; false, otherwise.
+ inline bool IsReorderNeeded(const memory::primitive_desc& op_pd) const {
+ CHECK_NOTNULL(user_memory_);
+ return op_pd != user_memory_->get_primitive_desc();
+ }
+
+  /// Function to create a reorder from the memory pointed to by 'from' to the
+  /// memory pointed to by 'to'. Returns the created primitive.
+ inline primitive CreateReorder(const memory* from, const memory* to) const {
+ CHECK_NOTNULL(from);
+ CHECK_NOTNULL(to);
+ return reorder(*from, *to);
+ }
/// Function to handle input reordering
///
@@ -945,19 +1458,62 @@ class MklDnnData {
/// operation
/// @input: net - net to which to add reorder primitive in case it is needed.
/// @return: true in case reorder of input is needed; false, otherwise.
- bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
- std::vector<primitive>* net) {
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ std::vector<primitive>* net) {
CHECK_NOTNULL(net);
CHECK_NOTNULL(user_memory_);
- if (op_pd != user_memory_->get_primitive_desc()) {
+ if (IsReorderNeeded(op_pd)) {
// TODO(nhasabni): can we remove dynamic memory allocation?
reorder_memory_ = new memory(op_pd);
- net->push_back(reorder(*user_memory_, *reorder_memory_));
+ net->push_back(CreateReorder(user_memory_, reorder_memory_));
+ return true;
+ }
+ return false;
+ }
+
+  /// Overloaded version of the above function that accepts a memory buffer
+  /// where the output of the reorder is to be stored.
+  ///
+  /// @input: op_pd - memory primitive descriptor of the given input of an
+  ///         operation
+  /// @input: reorder_data_handle - memory buffer where the output of the
+  ///         reorder is to be stored. The primitive does not check whether
+  ///         the buffer is large enough.
+ /// @input: net - net to which to add reorder primitive in case it is needed.
+ /// @return: true in case reorder of input is needed; false, otherwise.
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ void* reorder_data_handle,
+ std::vector<primitive>* net) {
+ CHECK_NOTNULL(net);
+ CHECK_NOTNULL(reorder_data_handle);
+ CHECK_NOTNULL(user_memory_);
+ if (IsReorderNeeded(op_pd)) {
+ // TODO(nhasabni): can we remove dynamic memory allocation?
+ reorder_memory_ = new memory(op_pd, reorder_data_handle);
+ net->push_back(CreateReorder(user_memory_, reorder_memory_));
return true;
}
return false;
}
+  /// Another overloaded version of CheckReorderToOpMem that accepts a Tensor
+  /// whose buffer is to be used to store the output of the reorder.
+  ///
+  /// @input: op_pd - memory primitive descriptor of the given input of an
+  ///         operation
+  /// @input: reorder_tensor - Tensor whose buffer is to be used to store the
+  ///         output of the reorder. The primitive does not check whether the
+  ///         buffer is large enough.
+ /// @input: net - net to which to add reorder primitive in case it is needed.
+ /// @return: true in case reorder of input is needed; false, otherwise.
+ inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+ Tensor* reorder_tensor,
+ std::vector<primitive>* net) {
+ CHECK_NOTNULL(net);
+ CHECK_NOTNULL(reorder_tensor);
+ return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
+ }
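+  // Usage sketch (illustrative; 'input' is an MklDnnData<T> wrapping the user
+  // tensor and 'op_pd' is the primitive_desc the operation expects):
+  //   std::vector<primitive> net;
+  //   input.CheckReorderToOpMem(op_pd, &net);
+  //   net.push_back(/* compute primitive consuming input.GetOpMem() */);
+  //   stream(stream::kind::eager).submit(net).wait();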
+
/// Function to handle output reorder
///
/// This function performs very similar functionality as input reordering
@@ -970,9 +1526,10 @@ class MklDnnData {
///
/// @input memory primitive descriptor for the given output of an operation
/// @return: true in case reorder of output is needed; false, otherwise.
- bool PrepareReorderToUserMemIfReq(const memory::primitive_desc& op_pd) {
+ inline bool PrepareReorderToUserMemIfReq(
+ const memory::primitive_desc& op_pd) {
CHECK_NOTNULL(user_memory_);
- if (op_pd != user_memory_->get_primitive_desc()) {
+ if (IsReorderNeeded(op_pd)) {
// TODO(nhasabni): can we remove dynamic memory allocation?
reorder_memory_ = new memory(op_pd);
return true;
@@ -987,11 +1544,11 @@ class MklDnnData {
/// to the user-specified output buffer.
///
/// @input: net - net to which to add reorder primitive
- void InsertReorderToUserMem(std::vector<primitive>* net) {
+ inline void InsertReorderToUserMem(std::vector<primitive>* net) {
CHECK_NOTNULL(net);
CHECK_NOTNULL(user_memory_);
CHECK_NOTNULL(reorder_memory_);
- net->push_back(reorder(*reorder_memory_, *user_memory_));
+ net->push_back(CreateReorder(reorder_memory_, user_memory_));
}
};
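+// Usage sketch (illustrative; 'output' wraps the user output tensor and
+// 'op_pd' is the primitive_desc of the format the operation produces):
+//   MklDnnData<float> output(&cpu_engine);
+//   output.SetUsrMem(out_dims, memory::format::nhwc, out_tensor);
+//   std::vector<primitive> net;
+//   if (output.PrepareReorderToUserMemIfReq(op_pd)) {
+//     net.push_back(/* op writing into output.GetOpMem() */);
+//     output.InsertReorderToUserMem(&net);
+//   } else {
+//     net.push_back(/* op writing directly into user memory */);
+//   }
+//   stream(stream::kind::eager).submit(net).wait();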