From 135ac89cae38464a9c6ea21af244e4a1bda255ed Mon Sep 17 00:00:00 2001 From: Guozhong Zhuang Date: Mon, 13 Aug 2018 15:52:43 -0700 Subject: enable pooling3D op --- tensorflow/core/util/mkl_util.h | 114 ++++++++++++++++++++++++++++++++++------ 1 file changed, 99 insertions(+), 15 deletions(-) (limited to 'tensorflow/core/util') diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 159a787d05..79fc7500fc 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -66,7 +66,6 @@ using mkldnn::reorder; typedef unsigned int uint; #endif - namespace tensorflow { // The file contains a number of utility classes and functions used by MKL @@ -87,6 +86,17 @@ typedef enum { Dim_I = 1 } MklDnnDims; +typedef enum { + Dim3d_N = 0, + Dim3d_C = 1, + Dim3d_D = 2, + Dim3d_H = 3, + Dim3d_W = 4, + Dim3d_O = 0, + Dim3d_I = 1 +} MklDnnDims3D; + + #ifdef INTEL_MKL_ML_ONLY class MklShape { public: @@ -453,6 +463,14 @@ class MklDnnShape { return this->DimSize(index); } + inline size_t GetDimension3D(char dimension) const { + int index = GetMklDnnTensor3DDimIndex(dimension); + CHECK(index >= 0 && index < this->GetDimension()) + << "Invalid index from the dimension: " << index << ", " << dimension; + return this->DimSize(index); + } + + inline int32 GetMklDnnTensorDimIndex(char dimension) const { switch (dimension) { case 'N': @@ -469,6 +487,24 @@ class MklDnnShape { } } + inline int32 GetMklDnnTensor3DDimIndex(char dimension) const { + switch (dimension) { + case 'N': + return MklDnnDims3D::Dim3d_N; + case 'C': + return MklDnnDims3D::Dim3d_C; + case 'D': + return MklDnnDims3D::Dim3d_D; + case 'H': + return MklDnnDims3D::Dim3d_H; + case 'W': + return MklDnnDims3D::Dim3d_W; + default: + LOG(FATAL) << "Invalid dimension: " << dimension; + return -1; // Avoid compiler warning about missing return value + } + } + inline size_t GetDimension() const { return data_.dimension_; } inline const int* GetSizes() const { return 
reinterpret_cast(&data_.sizes_[0]); @@ -587,15 +623,29 @@ class MklDnnShape { } inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) { - // TODO(nhasabni): Why do we restrict this to 4D? - CHECK_EQ(dimension, 4); - CHECK(dimension == data_.dimension_); - data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; - data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; - data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; - data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; + if (dimension == 5) { + CHECK(dimension == data_.dimension_); + data_.map_[GetTensorDimIndex<3>(data_format, '0')] = + MklDnnDims3D::Dim3d_D; + data_.map_[GetTensorDimIndex<3>(data_format, '1')] = + MklDnnDims3D::Dim3d_H; + data_.map_[GetTensorDimIndex<3>(data_format, '2')] = + MklDnnDims3D::Dim3d_W; + data_.map_[GetTensorDimIndex<3>(data_format, 'C')] = + MklDnnDims3D::Dim3d_C; + data_.map_[GetTensorDimIndex<3>(data_format, 'N')] = + MklDnnDims3D::Dim3d_N; + } else { + CHECK_EQ(dimension, 4); + CHECK(dimension == data_.dimension_); + data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; + data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; + data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; + data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; + } } + inline void SetTfDimOrder(const size_t dimension, memory::format format) { TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format); SetTfDimOrder(dimension, data_format); @@ -1329,6 +1379,19 @@ memory::data_type MklDnnType() { return memory::data_type::f32; } +/// Map TensorFlow's data format into MKL-DNN 3D data format +/// @input: TensorFlow data format +/// @return: memory::format corresponding to TensorFlow data format; +/// Fails with an error if invalid data format. 
+inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { + if (format == FORMAT_NHWC) + return memory::format::ndhwc; + else if (format == FORMAT_NCHW) + return memory::format::ncdhw; + TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); + return memory::format::format_undef; +} + /// Map TensorFlow's data format into MKL-DNN data format /// /// @input: TensorFlow data format @@ -1350,9 +1413,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { /// @return: Tensorflow data format corresponding to memory::format /// Fails with an error if invalid data format. inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) { - if (format == memory::format::nhwc) + if (format == memory::format::nhwc || format == memory::format::ndhwc) return FORMAT_NHWC; - else if (format == memory::format::nchw) + else if (format == memory::format::nchw || format == memory::format::ncdhw) return FORMAT_NCHW; TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); @@ -1402,6 +1465,23 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape, return memory::dims({n, c, h, w}); } +inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape, + TensorFormat format) { + // Check validity of format. + CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format), + memory::format::format_undef); + + int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N')); + int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C')); + int d = shape.dim_size(GetTensorDimIndex<3>(format, '0')); + int h = shape.dim_size(GetTensorDimIndex<3>(format, '1')); + int w = shape.dim_size(GetTensorDimIndex<3>(format, '2')); + + // MKL-DNN requires dimensions in NCDHW format. + return memory::dims({n, c, d, h, w}); +} + + /// Overloaded version of function above. Input parameters are /// self-explanatory. 
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims, @@ -1976,16 +2056,20 @@ class FactoryKeyCreator { } }; -static inline memory::format get_desired_format(int channel) { + +static inline memory::format get_desired_format(int channel, + bool is_2d = true) { memory::format fmt_desired = memory::format::any; - if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) { - fmt_desired = memory::format::nChw16c; + if (port::TestCPUFeature(port::CPUFeature::AVX512F)) { + fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c; } else if (port::TestCPUFeature(port::CPUFeature::AVX2) && (channel % 8) == 0) { - fmt_desired = memory::format::nChw8c; + fmt_desired = is_2d + ? memory::format::nChw8c + : memory::format::ncdhw; // AVX2 not supported for 3D yet. } else { - fmt_desired = memory::format::nchw; + fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw; } return fmt_desired; } -- cgit v1.2.3