aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-14 17:59:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-14 18:01:20 -0700
commit889833b5f145079d4837a5da73ffb2a997014764 (patch)
treef3a09e6befa101c88acc3c6c6456ea2d3e966703
parent99d48bdec4605cdd21f09d2dfcfc70139cbe4ebd (diff)
Add HWNC and HWCN data format support
PiperOrigin-RevId: 200650683
-rw-r--r--tensorflow/core/util/tensor_format.cc12
-rw-r--r--tensorflow/core/util/tensor_format.h47
-rw-r--r--tensorflow/core/util/tensor_format_test.cc25
3 files changed, 76 insertions, 8 deletions
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index d4311d1ab0..a5f7ecf0d1 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -43,6 +43,10 @@ string ToString(TensorFormat format) {
return "NCHW_VECT_C";
case FORMAT_NHWC_VECT_W:
return "NHWC_VECT_W";
+ case FORMAT_HWNC:
+ return "HWNC";
+ case FORMAT_HWCN:
+ return "HWCN";
default:
LOG(FATAL) << "Invalid Format: " << static_cast<int32>(format);
return "INVALID_FORMAT";
@@ -80,6 +84,14 @@ bool FormatFromString(const string& format_str, TensorFormat* format) {
*format = FORMAT_NHWC_VECT_W;
return true;
}
+ if (format_str == "HWNC") {
+ *format = FORMAT_HWNC;
+ return true;
+ }
+ if (format_str == "HWCN") {
+ *format = FORMAT_HWCN;
+ return true;
+ }
return false;
}
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index d3d5602f92..918835e1fb 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -59,6 +59,12 @@ enum TensorFormat {
// In the future we may change the meaning of these enums to include vectors
// of other types such as int16x2, with op implementations automatically
// determining which format is implied based on the datatype.
+
+ // FORMAT_HWNC is for TPUs.
+ FORMAT_HWNC = 4,
+
+ // FORMAT_HWCN is for TPUs.
+ FORMAT_HWCN = 5,
};
// Tensor format for convolutional filters.
@@ -105,11 +111,11 @@ string ToString(FilterTensorFormat format);
inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
switch (format) {
case FORMAT_NHWC:
- return num_dims - 2; // Exclude N,C.
case FORMAT_NCHW:
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
return num_dims - 2; // Exclude N,C.
case FORMAT_NCHW_VECT_C:
- return num_dims - 3; // Exclude N,C,VectDim.
case FORMAT_NHWC_VECT_W:
// Note: the VECT_W is not counted as an independent spatial dim here,
// since it just a component of the width dimension.
@@ -132,6 +138,8 @@ inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
switch (format) {
case FORMAT_NHWC:
case FORMAT_NCHW:
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
return num_spatial_dims + 2; // Include N,C.
case FORMAT_NCHW_VECT_C:
case FORMAT_NHWC_VECT_W:
@@ -158,6 +166,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
case FORMAT_NCHW_VECT_C:
case FORMAT_NHWC_VECT_W:
return 0;
+ case FORMAT_HWNC:
+ return num_dims - 2;
+ case FORMAT_HWCN:
+ return num_dims - 1;
default:
LOG(FATAL) << "Unknown format " << format;
return -1; // Avoid compiler warning about missing return value
@@ -170,8 +182,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) {
switch (format) {
case FORMAT_NHWC:
+ case FORMAT_HWNC:
return num_dims - 1;
case FORMAT_NHWC_VECT_W:
+ case FORMAT_HWCN:
return num_dims - 2;
case FORMAT_NCHW:
case FORMAT_NCHW_VECT_C:
@@ -210,6 +224,9 @@ inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format,
case FORMAT_NCHW:
case FORMAT_NCHW_VECT_C:
return spatial_dim + 2;
+ case FORMAT_HWNC:
+ case FORMAT_HWCN:
+ return spatial_dim;
default:
LOG(FATAL) << "Unknown format " << format;
return -1; // Avoid compiler warning about missing return value
@@ -310,6 +327,32 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
LOG(FATAL) << "Invalid dimension: " << dimension;
return -1; // Avoid compiler warning about missing return value
}
+ } else if (format == FORMAT_HWNC) {
+ switch (dimension) {
+ case '0': return 0;
+ case '1': return 1;
+ case '2': return 2;
+ case 'H': return NUM_SPATIAL_DIMS - 2;
+ case 'W': return NUM_SPATIAL_DIMS - 1;
+ case 'N': return NUM_SPATIAL_DIMS;
+ case 'C': return NUM_SPATIAL_DIMS + 1;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ } else if (format == FORMAT_HWCN) {
+ switch (dimension) {
+ case '0': return 0;
+ case '1': return 1;
+ case '2': return 2;
+ case 'H': return NUM_SPATIAL_DIMS - 2;
+ case 'W': return NUM_SPATIAL_DIMS - 1;
+ case 'C': return NUM_SPATIAL_DIMS;
+ case 'N': return NUM_SPATIAL_DIMS + 1;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
} else {
LOG(FATAL) << "Invalid format: " << static_cast<int>(format);
return -1; // Avoid compiler warning about missing return value
diff --git a/tensorflow/core/util/tensor_format_test.cc b/tensorflow/core/util/tensor_format_test.cc
index 93902290eb..07cdce998a 100644
--- a/tensorflow/core/util/tensor_format_test.cc
+++ b/tensorflow/core/util/tensor_format_test.cc
@@ -26,10 +26,9 @@ namespace tensorflow {
{ val, #val }
std::pair<TensorFormat, const char*> test_data_formats[] = {
- EnumStringPair(FORMAT_NHWC),
- EnumStringPair(FORMAT_NCHW),
- EnumStringPair(FORMAT_NCHW_VECT_C),
- EnumStringPair(FORMAT_NHWC_VECT_W),
+ EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW),
+ EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W),
+ EnumStringPair(FORMAT_HWNC), EnumStringPair(FORMAT_HWCN),
};
std::pair<FilterTensorFormat, const char*> test_filter_formats[] = {
@@ -85,6 +84,16 @@ struct DimMaps {
{ 0, 2, 3, 1, { 2, 3, -1 } },
{ 0, 3, 4, 1, { 2, 3, 4 } }
};
+ StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid,
+ { 1, -1, 0, 2, { 0, -1, -1 } },
+ { 2, 0, 1, 3, { 0, 1, -1 } },
+ { 3, 1, 2, 4, { 0, 1, 2 } }
+ };
+ StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid,
+ { 2, -1, 0, 1, { 0, -1, -1 } },
+ { 3, 0, 1, 2, { 0, 1, -1 } },
+ { 4, 1, 2, 3, { 0, 1, 2 } }
+ };
#undef StaCoExTensorDm
#define StaCoExFilterDm static constexpr FilterDimMap
// 'H', 'W', 'I', 'O' 0 1 2
@@ -108,8 +117,10 @@ GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) {
(format == FORMAT_NHWC ||
format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] :
(format == FORMAT_NCHW ||
- format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims]
- : DimMaps::kTdmInvalid;
+ format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] :
+ (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] :
+ (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims]
+ : DimMaps::kTdmInvalid;
}
inline constexpr const FilterDimMap&
@@ -126,6 +137,8 @@ GetFilterDimMap(const int num_spatial_dims,
constexpr TensorDimMap DimMaps::kTdmInvalid;
constexpr TensorDimMap DimMaps::kTdmNHWC[4];
constexpr TensorDimMap DimMaps::kTdmNCHW[4];
+constexpr TensorDimMap DimMaps::kTdmHWNC[4];
+constexpr TensorDimMap DimMaps::kTdmHWCN[4];
constexpr FilterDimMap DimMaps::kFdmInvalid;
constexpr FilterDimMap DimMaps::kFdmHWIO[4];
constexpr FilterDimMap DimMaps::kFdmOIHW[4];