aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util
diff options
context:
space:
mode:
authorGravatar AG Ramesh <ag.ramesh@intel.com>2018-08-16 13:32:01 -0700
committerGravatar GitHub <noreply@github.com>2018-08-16 13:32:01 -0700
commitbb5d67ae856e66dd99600fbae8973e8fd7de801a (patch)
tree0965536ec669bd5378a1e161e2f2bedf49a15145 /tensorflow/core/util
parent135ac89cae38464a9c6ea21af244e4a1bda255ed (diff)
parent9c50882415cb87a7eb81048d42401c64bf0617ef (diff)
Merge branch 'master' into pooling3d
Diffstat (limited to 'tensorflow/core/util')
-rw-r--r--tensorflow/core/util/mkl_util.h13
-rw-r--r--tensorflow/core/util/tensor_format.cc4
-rw-r--r--tensorflow/core/util/tensor_format.h1
3 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 79fc7500fc..0a96a603d0 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -96,7 +96,6 @@ typedef enum {
Dim3d_I = 1
} MklDnnDims3D;
-
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -361,6 +360,7 @@ class MklShape {
#else
// Forward decl
+TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
@@ -470,7 +470,6 @@ class MklDnnShape {
return this->DimSize(index);
}
-
inline int32 GetMklDnnTensorDimIndex(char dimension) const {
switch (dimension) {
case 'N':
@@ -1403,7 +1402,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
- // Return to get rid of compiler warning
return memory::format::format_undef;
}
@@ -1466,7 +1464,7 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
}
inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
- TensorFormat format) {
+ TensorFormat format) {
// Check validity of format.
CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
memory::format::format_undef);
@@ -1481,7 +1479,6 @@ inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
return memory::dims({n, c, d, h, w});
}
-
/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1594,6 +1591,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
+ // flat to indicate if data is 3D or not.
+ bool bIs3D;
/// Operations temp buffer
void* allocated_buffer_;
/// CPU engine on which operation will be executed
@@ -1620,6 +1619,10 @@ class MklDnnData {
static_cast<const void*>(tensor->flat<T>().data()));
}
+ void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+ bool GetIs3D() { return bIs3D; }
+
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index a5f7ecf0d1..f331973f5c 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() {
return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' ";
}
+string GetConvnetDataFormat2D3DAttrString() {
+ return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' ";
+}
+
string GetConvnetFilterFormatAttrString() {
return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' ";
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 918835e1fb..b0c349dd90 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString();
// Return the string that specifies the filter format for convnet operations.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
+string GetConvnetDataFormat2D3DAttrString();
// Returns a tensor shape for the specified format and dimension sizes.
// Works for both 2D and 3D operations. The output shapes are as follows: