diff options
author | AG Ramesh <ag.ramesh@intel.com> | 2018-08-16 13:32:01 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-16 13:32:01 -0700 |
commit | bb5d67ae856e66dd99600fbae8973e8fd7de801a (patch) | |
tree | 0965536ec669bd5378a1e161e2f2bedf49a15145 /tensorflow/core/util | |
parent | 135ac89cae38464a9c6ea21af244e4a1bda255ed (diff) | |
parent | 9c50882415cb87a7eb81048d42401c64bf0617ef (diff) |
Merge branch 'master' into pooling3d
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r-- | tensorflow/core/util/mkl_util.h | 13 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 1 |
3 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 79fc7500fc..0a96a603d0 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -96,7 +96,6 @@ typedef enum { Dim3d_I = 1 } MklDnnDims3D; - #ifdef INTEL_MKL_ML_ONLY class MklShape { public: @@ -361,6 +360,7 @@ class MklShape { #else // Forward decl +TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format); TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format); memory::dims CalculateTFStrides(const memory::dims& dims_tf_order); memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, @@ -470,7 +470,6 @@ class MklDnnShape { return this->DimSize(index); } - inline int32 GetMklDnnTensorDimIndex(char dimension) const { switch (dimension) { case 'N': @@ -1403,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) { else if (format == FORMAT_NCHW) return memory::format::nchw; TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format")); - // Return to get rid of compiler warning return memory::format::format_undef; } @@ -1466,7 +1464,7 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape, } inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape, - TensorFormat format) { + TensorFormat format) { // Check validity of format. CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format), memory::format::format_undef); @@ -1481,7 +1479,6 @@ inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape, return memory::dims({n, c, d, h, w}); } - /// Overloaded version of function above. Input parameters are /// self-explanatory. inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims, @@ -1594,6 +1591,8 @@ class MklDnnData { /// Operations memory descriptor memory::desc* op_md_; + // flat to indicate if data is 3D or not. + bool bIs3D; /// Operations temp buffer void* allocated_buffer_; /// CPU engine on which operation will be executed @@ -1620,6 +1619,10 @@ class MklDnnData { static_cast<const void*>(tensor->flat<T>().data())); } + void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; } + + bool GetIs3D() { return bIs3D; } + /// Set user memory primitive using specified dimensions, memory format and /// data_buffer. Function automatically uses element data type by using /// input type T used for creating call object. diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc index a5f7ecf0d1..f331973f5c 100644 --- a/tensorflow/core/util/tensor_format.cc +++ b/tensorflow/core/util/tensor_format.cc @@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() { return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' "; } +string GetConvnetDataFormat2D3DAttrString() { + return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' "; +} + string GetConvnetFilterFormatAttrString() { return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' "; } diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 918835e1fb..b0c349dd90 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString(); // Return the string that specifies the filter format for convnet operations. string GetConvnetFilterFormatAttrString(); string GetConvnet3dFilterFormatAttrString(); +string GetConvnetDataFormat2D3DAttrString(); // Returns a tensor shape for the specified format and dimension sizes. // Works for both 2D and 3D operations. The output shapes are as follows: |