aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 13:04:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 13:05:05 -0700
commit9c50882415cb87a7eb81048d42401c64bf0617ef (patch)
treec550925b2d9e7f6997ace0e3bb3268572f7066b7 /tensorflow/core/util
parent19cafed2ae69ce5cbc4d2b2fc9176fb4c550040f (diff)
parent62191da0819b25906c1b2ed96159cfe36ba00383 (diff)
Merge pull request #21324 from Intel-tensorflow:conv3d
PiperOrigin-RevId: 209032082
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/mkl_util.h103
-rw-r--r--tensorflow/core/util/tensor_format.cc4
-rw-r--r--tensorflow/core/util/tensor_format.h1
3 files changed, 98 insertions, 10 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 159a787d05..422be9356d 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -87,6 +87,16 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+typedef enum {
+ Dim3d_N = 0,
+ Dim3d_C = 1,
+ Dim3d_D = 2,
+ Dim3d_H = 3,
+ Dim3d_W = 4,
+ Dim3d_O = 0,
+ Dim3d_I = 1
+} MklDnnDims3D;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -351,6 +361,7 @@ class MklShape {
#else
// Forward decl
+TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
@@ -453,6 +464,13 @@ class MklDnnShape {
return this->DimSize(index);
}
+ inline size_t GetDimension3D(char dimension) const {
+ int index = GetMklDnnTensor3DDimIndex(dimension);
+ CHECK(index >= 0 && index < this->GetDimension())
+ << "Invalid index from the dimension: " << index << ", " << dimension;
+ return this->DimSize(index);
+ }
+
inline int32 GetMklDnnTensorDimIndex(char dimension) const {
switch (dimension) {
case 'N':
@@ -469,6 +487,24 @@ class MklDnnShape {
}
}
+ inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
+ switch (dimension) {
+ case 'N':
+ return MklDnnDims3D::Dim3d_N;
+ case 'C':
+ return MklDnnDims3D::Dim3d_C;
+ case 'D':
+ return MklDnnDims3D::Dim3d_D;
+ case 'H':
+ return MklDnnDims3D::Dim3d_H;
+ case 'W':
+ return MklDnnDims3D::Dim3d_W;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ }
+
inline size_t GetDimension() const { return data_.dimension_; }
inline const int* GetSizes() const {
return reinterpret_cast<const int*>(&data_.sizes_[0]);
@@ -587,13 +623,26 @@ class MklDnnShape {
}
inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
- // TODO(nhasabni): Why do we restrict this to 4D?
- CHECK_EQ(dimension, 4);
- CHECK(dimension == data_.dimension_);
- data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
- data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
- data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
- data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ if (dimension == 5) {
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
+ MklDnnDims3D::Dim3d_D;
+ data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
+ MklDnnDims3D::Dim3d_H;
+ data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
+ MklDnnDims3D::Dim3d_W;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
+ MklDnnDims3D::Dim3d_C;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
+ MklDnnDims3D::Dim3d_N;
+ } else {
+ CHECK_EQ(dimension, 4);
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ }
}
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
@@ -1329,6 +1378,19 @@ memory::data_type MklDnnType<float>() {
return memory::data_type::f32;
}
+/// Map TensorFlow's data format into MKL-DNN 3D data format
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if invalid data format.
+inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::ndhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::ncdhw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ return memory::format::format_undef;
+}
+
/// Map TensorFlow's data format into MKL-DNN data format
///
/// @input: TensorFlow data format
@@ -1340,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
- // Return to get rid of compiler warning
return memory::format::format_undef;
}
@@ -1350,9 +1411,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return: Tensorflow data format corresponding to memory::format
/// Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
- if (format == memory::format::nhwc)
+ if (format == memory::format::nhwc || format == memory::format::ndhwc)
return FORMAT_NHWC;
- else if (format == memory::format::nchw)
+ else if (format == memory::format::nchw || format == memory::format::ncdhw)
return FORMAT_NCHW;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
@@ -1402,6 +1463,22 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
return memory::dims({n, c, h, w});
}
+inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
+ int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
+ int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
+ int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
+
+ // MKL-DNN requires dimensions in NCDHW format.
+ return memory::dims({n, c, d, h, w});
+}
+
/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1514,6 +1591,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
+ // flat to indicate if data is 3D or not.
+ bool bIs3D;
/// Operations temp buffer
void* allocated_buffer_;
/// CPU engine on which operation will be executed
@@ -1540,6 +1619,10 @@ class MklDnnData {
static_cast<const void*>(tensor->flat<T>().data()));
}
+ void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+ bool GetIs3D() { return bIs3D; }
+
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index a5f7ecf0d1..f331973f5c 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() {
return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' ";
}
+string GetConvnetDataFormat2D3DAttrString() {
+ return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' ";
+}
+
string GetConvnetFilterFormatAttrString() {
return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' ";
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 918835e1fb..b0c349dd90 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString();
// Return the string that specifies the filter format for convnet operations.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
+string GetConvnetDataFormat2D3DAttrString();
// Returns a tensor shape for the specified format and dimension sizes.
// Works for both 2D and 3D operations. The output shapes are as follows: