diff options
author | 2018-06-14 17:59:25 -0700 | |
---|---|---|
committer | 2018-06-14 18:01:20 -0700 | |
commit | 889833b5f145079d4837a5da73ffb2a997014764 (patch) | |
tree | f3a09e6befa101c88acc3c6c6456ea2d3e966703 | |
parent | 99d48bdec4605cdd21f09d2dfcfc70139cbe4ebd (diff) |
Add HWNC and HWCN data format support
PiperOrigin-RevId: 200650683
-rw-r--r-- | tensorflow/core/util/tensor_format.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 47 | ||||
-rw-r--r-- | tensorflow/core/util/tensor_format_test.cc | 25 |
3 files changed, 76 insertions, 8 deletions
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc index d4311d1ab0..a5f7ecf0d1 100644 --- a/tensorflow/core/util/tensor_format.cc +++ b/tensorflow/core/util/tensor_format.cc @@ -43,6 +43,10 @@ string ToString(TensorFormat format) { return "NCHW_VECT_C"; case FORMAT_NHWC_VECT_W: return "NHWC_VECT_W"; + case FORMAT_HWNC: + return "HWNC"; + case FORMAT_HWCN: + return "HWCN"; default: LOG(FATAL) << "Invalid Format: " << static_cast<int32>(format); return "INVALID_FORMAT"; @@ -80,6 +84,14 @@ bool FormatFromString(const string& format_str, TensorFormat* format) { *format = FORMAT_NHWC_VECT_W; return true; } + if (format_str == "HWNC") { + *format = FORMAT_HWNC; + return true; + } + if (format_str == "HWCN") { + *format = FORMAT_HWCN; + return true; + } return false; } diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index d3d5602f92..918835e1fb 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -59,6 +59,12 @@ enum TensorFormat { // In the future we may change the meaning of these enums to include vectors // of other types such as int16x2, with op implementations automatically // determining which format is implied based on the datatype. + + // FORMAT_HWNC is for TPUs. + FORMAT_HWNC = 4, + + // FORMAT_HWCN is for TPUs. + FORMAT_HWCN = 5, }; // Tensor format for convolutional filters. @@ -105,11 +111,11 @@ string ToString(FilterTensorFormat format); inline int GetTensorSpatialDims(int num_dims, TensorFormat format) { switch (format) { case FORMAT_NHWC: - return num_dims - 2; // Exclude N,C. case FORMAT_NCHW: + case FORMAT_HWNC: + case FORMAT_HWCN: return num_dims - 2; // Exclude N,C. case FORMAT_NCHW_VECT_C: - return num_dims - 3; // Exclude N,C,VectDim. case FORMAT_NHWC_VECT_W: // Note: the VECT_W is not counted as an independent spatial dim here, // since it just a component of the width dimension. @@ -132,6 +138,8 @@ inline int GetTensorDimsFromSpatialDims(int num_spatial_dims, switch (format) { case FORMAT_NHWC: case FORMAT_NCHW: + case FORMAT_HWNC: + case FORMAT_HWCN: return num_spatial_dims + 2; // Include N,C. case FORMAT_NCHW_VECT_C: case FORMAT_NHWC_VECT_W: @@ -158,6 +166,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) { case FORMAT_NCHW_VECT_C: case FORMAT_NHWC_VECT_W: return 0; + case FORMAT_HWNC: + return num_dims - 2; + case FORMAT_HWCN: + return num_dims - 1; default: LOG(FATAL) << "Unknown format " << format; return -1; // Avoid compiler warning about missing return value @@ -170,8 +182,10 @@ inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) { inline int GetTensorFeatureDimIndex(int num_dims, TensorFormat format) { switch (format) { case FORMAT_NHWC: + case FORMAT_HWNC: return num_dims - 1; case FORMAT_NHWC_VECT_W: + case FORMAT_HWCN: return num_dims - 2; case FORMAT_NCHW: case FORMAT_NCHW_VECT_C: @@ -210,6 +224,9 @@ inline int GetTensorSpatialDimIndex(int num_dims, TensorFormat format, case FORMAT_NCHW: case FORMAT_NCHW_VECT_C: return spatial_dim + 2; + case FORMAT_HWNC: + case FORMAT_HWCN: + return spatial_dim; default: LOG(FATAL) << "Unknown format " << format; return -1; // Avoid compiler warning about missing return value @@ -310,6 +327,32 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { LOG(FATAL) << "Invalid dimension: " << dimension; return -1; // Avoid compiler warning about missing return value } + } else if (format == FORMAT_HWNC) { + switch (dimension) { + case '0': return 0; + case '1': return 1; + case '2': return 2; + case 'H': return NUM_SPATIAL_DIMS - 2; + case 'W': return NUM_SPATIAL_DIMS - 1; + case 'N': return NUM_SPATIAL_DIMS; + case 'C': return NUM_SPATIAL_DIMS + 1; + default: + LOG(FATAL) << "Invalid dimension: " << dimension; + return -1; // Avoid compiler warning about missing return value + } + } else if (format == FORMAT_HWCN) { + switch (dimension) { + case '0': return 0; + case '1': return 1; + case '2': return 2; + case 'H': return NUM_SPATIAL_DIMS - 2; + case 'W': return NUM_SPATIAL_DIMS - 1; + case 'C': return NUM_SPATIAL_DIMS; + case 'N': return NUM_SPATIAL_DIMS + 1; + default: + LOG(FATAL) << "Invalid dimension: " << dimension; + return -1; // Avoid compiler warning about missing return value + } } else { LOG(FATAL) << "Invalid format: " << static_cast<int>(format); return -1; // Avoid compiler warning about missing return value diff --git a/tensorflow/core/util/tensor_format_test.cc b/tensorflow/core/util/tensor_format_test.cc index 93902290eb..07cdce998a 100644 --- a/tensorflow/core/util/tensor_format_test.cc +++ b/tensorflow/core/util/tensor_format_test.cc @@ -26,10 +26,9 @@ namespace tensorflow { { val, #val } std::pair<TensorFormat, const char*> test_data_formats[] = { - EnumStringPair(FORMAT_NHWC), - EnumStringPair(FORMAT_NCHW), - EnumStringPair(FORMAT_NCHW_VECT_C), - EnumStringPair(FORMAT_NHWC_VECT_W), + EnumStringPair(FORMAT_NHWC), EnumStringPair(FORMAT_NCHW), + EnumStringPair(FORMAT_NCHW_VECT_C), EnumStringPair(FORMAT_NHWC_VECT_W), + EnumStringPair(FORMAT_HWNC), EnumStringPair(FORMAT_HWCN), }; std::pair<FilterTensorFormat, const char*> test_filter_formats[] = { @@ -85,6 +84,16 @@ struct DimMaps { { 0, 2, 3, 1, { 2, 3, -1 } }, { 0, 3, 4, 1, { 2, 3, 4 } } }; + StaCoExTensorDm kTdmHWNC[4] = { kTdmInvalid, + { 1, -1, 0, 2, { 0, -1, -1 } }, + { 2, 0, 1, 3, { 0, 1, -1 } }, + { 3, 1, 2, 4, { 0, 1, 2 } } + }; + StaCoExTensorDm kTdmHWCN[4] = { kTdmInvalid, + { 2, -1, 0, 1, { 0, -1, -1 } }, + { 3, 0, 1, 2, { 0, 1, -1 } }, + { 4, 1, 2, 3, { 0, 1, 2 } } + }; #undef StaCoExTensorDm #define StaCoExFilterDm static constexpr FilterDimMap // 'H', 'W', 'I', 'O' 0 1 2 @@ -108,8 +117,10 @@ GetTensorDimMap(const int num_spatial_dims, const TensorFormat format) { (format == FORMAT_NHWC || format == FORMAT_NHWC_VECT_W) ? DimMaps::kTdmNHWC[num_spatial_dims] : (format == FORMAT_NCHW || - format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] - : DimMaps::kTdmInvalid; + format == FORMAT_NCHW_VECT_C) ? DimMaps::kTdmNCHW[num_spatial_dims] : + (format == FORMAT_HWNC) ? DimMaps::kTdmHWNC[num_spatial_dims] : + (format == FORMAT_HWCN) ? DimMaps::kTdmHWCN[num_spatial_dims] + : DimMaps::kTdmInvalid; } inline constexpr const FilterDimMap& @@ -126,6 +137,8 @@ GetFilterDimMap(const int num_spatial_dims, constexpr TensorDimMap DimMaps::kTdmInvalid; constexpr TensorDimMap DimMaps::kTdmNHWC[4]; constexpr TensorDimMap DimMaps::kTdmNCHW[4]; +constexpr TensorDimMap DimMaps::kTdmHWNC[4]; +constexpr TensorDimMap DimMaps::kTdmHWCN[4]; constexpr FilterDimMap DimMaps::kFdmInvalid; constexpr FilterDimMap DimMaps::kFdmHWIO[4]; constexpr FilterDimMap DimMaps::kFdmOIHW[4]; |