diff options
author | 2017-03-21 14:26:26 -0800 | |
---|---|---|
committer | 2017-03-21 15:48:06 -0700 | |
commit | 61f30222eba5e3f1f51dedb3c5493f5f8eb331c8 (patch) | |
tree | 622b69c4e91a1be79bb33681d54842c08b17ee4e /tensorflow/core/util/tensor_format.h | |
parent | a6116e04dbcf397337de8a4c37b531c4f373da04 (diff) |
Add support for the NCHW data_format for 3d operations (convolution, pooling).
This brings NCHW support for 3d in sync with the corresponding 2d ops.
Change: 150811076
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 62 |
1 file changed, 43 insertions, 19 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 7d8e4b11c8..fe89fe852e 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -16,9 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_UTIL_TENSOR_FORMAT_H_ #define TENSORFLOW_UTIL_TENSOR_FORMAT_H_ +#include <array> #include <vector> #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -127,7 +129,8 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { // using the specified format, and a dimension specification using a char. inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, char dimension) { - int index = GetTensorDimIndex<2>(format, dimension); + int index = (tensor.dims() == 5) ? GetTensorDimIndex<3>(format, dimension) + : GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < tensor.dims()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -140,7 +143,9 @@ inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, // specification using a char. inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format, char dimension) { - int index = GetTensorDimIndex<2>(format, dimension); + int index = (tensor_shape.dims() == 5) + ? GetTensorDimIndex<3>(format, dimension) + : GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < tensor_shape.dims()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -153,7 +158,9 @@ inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format, template <typename T> T GetTensorDim(const std::vector<T>& attributes, TensorFormat format, char dimension) { - int index = GetTensorDimIndex<2>(format, dimension); + int index = (attributes.size() == 5) + ? 
GetTensorDimIndex<3>(format, dimension) + : GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < attributes.size()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -162,16 +169,27 @@ T GetTensorDim(const std::vector<T>& attributes, TensorFormat format, // Return the string that specifies the data format for convnet operations. string GetConvnetDataFormatAttrString(); +string GetConvnet3dDataFormatAttrString(); + +// Return a tensor shape for the given format. Works for both 2D and 3D +// operations. +inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, + gtl::ArraySlice<int64> spatial, int64 C) { + gtl::InlinedVector<int64, 5> dim_sizes(spatial.size() + 2); + dim_sizes[GetTensorBatchDimIndex(dim_sizes.size(), format)] = N; + for (int dim = 0; dim < spatial.size(); dim++) { + dim_sizes[GetTensorSpatialDimIndex(dim_sizes.size(), format, dim)] = + spatial[dim]; + } + dim_sizes[GetTensorFeatureDimIndex(dim_sizes.size(), format)] = C; + + return TensorShape(dim_sizes); +} // Return a tensor shape from the given format, and tensor dimensions. inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H, int64 W, int64 C) { - std::vector<int64> dim_sizes(4); - dim_sizes[GetTensorDimIndex<2>(format, 'N')] = N; - dim_sizes[GetTensorDimIndex<2>(format, 'H')] = H; - dim_sizes[GetTensorDimIndex<2>(format, 'W')] = W; - dim_sizes[GetTensorDimIndex<2>(format, 'C')] = C; - return TensorShape(dim_sizes); + return ShapeFromFormat(format, N, {{H, W}}, C); } // Return a tensor shape from the given format, and tensor dimensions. 
@@ -181,16 +199,22 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format, if (src_format == dst_format) { return src_shape; } - std::vector<int64> dim_sizes(4); - dim_sizes[GetTensorDimIndex<2>(dst_format, 'N')] = - GetTensorDim(src_shape, src_format, 'N'); - dim_sizes[GetTensorDimIndex<2>(dst_format, 'H')] = - GetTensorDim(src_shape, src_format, 'H'); - dim_sizes[GetTensorDimIndex<2>(dst_format, 'W')] = - GetTensorDim(src_shape, src_format, 'W'); - dim_sizes[GetTensorDimIndex<2>(dst_format, 'C')] = - GetTensorDim(src_shape, src_format, 'C'); - return TensorShape(dim_sizes); + + const int64 channels = GetTensorDim(src_shape, src_format, 'C'); + const int64 batch = GetTensorDim(src_shape, src_format, 'N'); + + if (src_shape.dims() == 5) { + return ShapeFromFormat(dst_format, batch, + {{GetTensorDim(src_shape, src_format, '0'), + GetTensorDim(src_shape, src_format, '1'), + GetTensorDim(src_shape, src_format, '2')}}, + channels); + } + + return ShapeFromFormat(dst_format, batch, + {{GetTensorDim(src_shape, src_format, 'H'), + GetTensorDim(src_shape, src_format, 'W')}}, + channels); } } // namespace tensorflow |