diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 13:04:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 13:05:05 -0700 |
commit | 9c50882415cb87a7eb81048d42401c64bf0617ef (patch) | |
tree | c550925b2d9e7f6997ace0e3bb3268572f7066b7 /tensorflow/core/util | |
parent | 19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff) | |
parent | 62191da0819b25906c1b2ed96159cfe36ba00383 (diff) |
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 103 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 1 |
3 files changed, 98 insertions, 10 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 159a787d05..422be9356d 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -87,6 +87,16 @@ typedef enum { Dim_I = 1 } MklDnnDims; +typedef enum { + Dim3d_N = 0, + Dim3d_C = 1, + Dim3d_D = 2, + Dim3d_H = 3, + Dim3d_W = 4, + Dim3d_O = 0, + Dim3d_I = 1 +} MklDnnDims3D; + #ifdef INTEL_MKL_ML_ONLY class MklShape { public: @@ -351,6 +361,7 @@ class MklShape { #else // Forward decl +TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format); TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format); memory::dims CalculateTFStrides(const memory::dims& dims_tf_order); memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, @@ -453,6 +464,13 @@ class MklDnnShape { return this->DimSize(index); } + inline size_t GetDimension3D(char dimension) const { + int index = GetMklDnnTensor3DDimIndex(dimension); + CHECK(index >= 0 && index < this->GetDimension()) + << "Invalid index from the dimension: " << index << ", " << dimension; + return this->DimSize(index); + } + inline int32 GetMklDnnTensorDimIndex(char dimension) const { switch (dimension) { case 'N': @@ -469,6 +487,24 @@ class MklDnnShape { } } + inline int32 GetMklDnnTensor3DDimIndex(char dimension) const { + switch (dimension) { + case 'N': + return MklDnnDims3D::Dim3d_N; + case 'C': + return MklDnnDims3D::Dim3d_C; + case 'D': + return MklDnnDims3D::Dim3d_D; + case 'H': + return MklDnnDims3D::Dim3d_H; + case 'W': + return MklDnnDims3D::Dim3d_W; + default: + LOG(FATAL) << "Invalid dimension: " << dimension; + return -1; // Avoid compiler warning about missing return value + } + } + inline size_t GetDimension() const { return data_.dimension_; } inline const int* GetSizes() const { return reinterpret_cast<const int*>(&data_.sizes_[0]); @@ -587,13 +623,26 @@ class MklDnnShape { } inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) { - // TODO(nhasabni): Why do we restrict this to 4D? - CHECK_EQ(dimension, 4); - CHECK(dimension == data_.dimension_); - data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; - data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; - data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; - data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; + if (dimension == 5) { + CHECK(dimension == data_.dimension_); + data_.map_[GetTensorDimIndex<3>(data_format, '0')] = + MklDnnDims3D::Dim3d_D; + data_.map_[GetTensorDimIndex<3>(data_format, '1')] = + MklDnnDims3D::Dim3d_H; + data_.map_[GetTensorDimIndex<3>(data_format, '2')] = + MklDnnDims3D::Dim3d_W; + data_.map_[GetTensorDimIndex<3>(data_format, 'C')] = + MklDnnDims3D::Dim3d_C; + data_.map_[GetTensorDimIndex<3>(data_format, 'N')] = + MklDnnDims3D::Dim3d_N; + } else { + CHECK_EQ(dimension, 4); + CHECK(dimension == data_.dimension_); + data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W; + data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H; + data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C; + data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N; + } } inline void SetTfDimOrder(const size_t dimension, memory::format format) { @@ -1329,6 +1378,19 @@ memory::data_type MklDnnType<float>() { return memory::data_type::f32; } +/// Map TensorFlow's data format into MKL-DNN 3D data format +/// @input: TensorFlow data format +/// @return: memory::format corresponding to TensorFlow data format; +/// Fails with an error if invalid data format. +inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) { + if (format == FORMAT_NHWC) + return memory::format::ndhwc; + else if (format == FORMAT_NCHW) + return memory::format::ncdhw; + TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); + return memory::format::format_undef; +} + /// Map TensorFlow's data format into MKL-DNN data format /// /// @input: TensorFlow data format @@ -1340,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { else if (format == FORMAT_NCHW) return memory::format::nchw; TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); - // Return to get rid of compiler warning return memory::format::format_undef; } @@ -1350,9 +1411,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { /// @return: Tensorflow data format corresponding to memory::format /// Fails with an error if invalid data format. inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) { - if (format == memory::format::nhwc) + if (format == memory::format::nhwc || format == memory::format::ndhwc) return FORMAT_NHWC; - else if (format == memory::format::nchw) + else if (format == memory::format::nchw || format == memory::format::ncdhw) return FORMAT_NCHW; TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); @@ -1402,6 +1463,22 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape, return memory::dims({n, c, h, w}); } +inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape, + TensorFormat format) { + // Check validity of format. + CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format), + memory::format::format_undef); + + int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N')); + int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C')); + int d = shape.dim_size(GetTensorDimIndex<3>(format, '0')); + int h = shape.dim_size(GetTensorDimIndex<3>(format, '1')); + int w = shape.dim_size(GetTensorDimIndex<3>(format, '2')); + + // MKL-DNN requires dimensions in NCDHW format. + return memory::dims({n, c, d, h, w}); +} + /// Overloaded version of function above. Input parameters are /// self-explanatory. inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims, @@ -1514,6 +1591,8 @@ class MklDnnData { /// Operations memory descriptor memory::desc* op_md_; + // flat to indicate if data is 3D or not. + bool bIs3D; /// Operations temp buffer void* allocated_buffer_; /// CPU engine on which operation will be executed @@ -1540,6 +1619,10 @@ class MklDnnData { static_cast<const void*>(tensor->flat<T>().data())); } + void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; } + + bool GetIs3D() { return bIs3D; } + /// Set user memory primitive using specified dimensions, memory format and /// data_buffer. Function automatically uses element data type by using /// input type T used for creating call object. diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc index a5f7ecf0d1..f331973f5c 100644 --- a/tensorflow/core/util/tensor_format.cc +++ b/tensorflow/core/util/tensor_format.cc @@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() { return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' "; } +string GetConvnetDataFormat2D3DAttrString() { + return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' "; +} + string GetConvnetFilterFormatAttrString() { return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' "; } diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 918835e1fb..b0c349dd90 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString(); // Return the string that specifies the filter format for convnet operations. string GetConvnetFilterFormatAttrString(); string GetConvnet3dFilterFormatAttrString(); +string GetConvnetDataFormat2D3DAttrString(); // Returns a tensor shape for the specified format and dimension sizes. // Works for both 2D and 3D operations. The output shapes are as follows: |