about summary refs log tree commit diff homepage
path: root/tensorflow/core/util/mkl_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/mkl_util.h')
-rw-r--r-- tensorflow/core/util/mkl_util.h 38
1 files changed, 37 insertions, 1 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 6a37256ea9..67468bdc3f 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -23,7 +23,7 @@ limitations under the License.
#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
#include "third_party/mkl/include/mkl_service.h"
-
+#include "third_party/mkl/include/mkl_trans.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -616,6 +616,42 @@ inline void ForwarMklTensorInToOut(OpKernelContext* context,
}
}
+// Copies a float tensor laid out as NHWC (dims 0..3 of `input` are read as
+// N, H, W, C) into `*output` in NCHW order. Each of the N images is handled
+// independently: its (H*W) x C row-major matrix is transposed into a
+// C x (H*W) matrix by one mkl_somatcopy call, with a scale factor of 1.
+//
+// `*output` must already own a float buffer; this routine reads the dimensions
+// from `input` only and does not check the size of the output buffer.
+//
+// TODO(intel_tf): Remove this routine when faster MKL layout conversion is
+// out.
+inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
+  const float* src = input.flat<float>().data();
+  float* dst = (*output)->flat<float>().data();
+
+  const int64 batch = input.dim_size(0);
+  const int64 height = input.dim_size(1);
+  const int64 width = input.dim_size(2);
+  const int64 channels = input.dim_size(3);
+  // Number of spatial positions per image, and number of floats per image.
+  const int64 spatial = height * width;
+  const int64 image_size = spatial * channels;
+#pragma omp parallel for num_threads(16)
+  for (int64 b = 0; b < batch; ++b) {
+    // Source and destination images start at the same element offset because
+    // the transpose only permutes elements within one image.
+    const int64 offset = b * image_size;
+    mkl_somatcopy('R', 'T', spatial, channels, 1, src + offset, channels,
+                  dst + offset, spatial);
+  }
+}
+
+// Copies a float tensor laid out as NCHW from `input` into `*output` in NHWC
+// order. The dimensions are read from `*output` (dims 0..3 are taken as
+// N, H, W, C), so the output tensor must already have its NHWC shape set.
+// Each of the N images is handled independently: its C x (H*W) row-major
+// matrix is transposed into an (H*W) x C matrix by one mkl_somatcopy call,
+// with a scale factor of 1.
+//
+// This routine does not check that `input` holds N*H*W*C elements.
+//
+// TODO(intel_tf): Remove this routine when faster MKL layout conversion is
+// out.
+inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
+  const float* src = input.flat<float>().data();
+  float* dst = (*output)->flat<float>().data();
+
+  const int64 batch = (*output)->dim_size(0);
+  const int64 height = (*output)->dim_size(1);
+  const int64 width = (*output)->dim_size(2);
+  const int64 channels = (*output)->dim_size(3);
+  // Number of spatial positions per image, and number of floats per image.
+  const int64 spatial = height * width;
+  const int64 image_size = spatial * channels;
+#pragma omp parallel for num_threads(16)
+  for (int64 b = 0; b < batch; ++b) {
+    // Source and destination images start at the same element offset because
+    // the transpose only permutes elements within one image.
+    const int64 offset = b * image_size;
+    mkl_somatcopy('R', 'T', channels, spatial, 1, src + offset, spatial,
+                  dst + offset, channels);
+  }
+}
+
namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";