diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-05-04 07:46:46 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-04 08:52:26 -0700 |
commit | 6a187ccddaebb741ea77fc3201c6e36625f0aadb (patch) | |
tree | 48097a7dbc49a3256de30ef0d9ae631e940af83e /tensorflow/core/util/tensor_format.h | |
parent | e5df6adc63653f31e9d5d6c539f799539cfbbed1 (diff) |
Add support for 3d convolutions and pooling. CPU kernels use Eigen, GPU kernels use CuDNN.
Change: 121484787
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 52 |
1 files changed, 35 insertions, 17 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index ee5f3703ce..4115afb2b1 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -36,18 +36,26 @@ bool FormatFromString(const string& format_str, TensorFormat* format); string ToString(TensorFormat format); // Return the position index from a format given a dimension specification with -// a char. +// a char. The chars can be N (batch), C (channels), H (y), W (x), or +// 0 .. (NDIMS-1). +template <int NDIMS> inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { if (format == FORMAT_NHWC) { switch (dimension) { case 'N': return 0; - case 'H': + case '0': return 1; - case 'W': + case '1': return 2; - case 'C': + case '2': return 3; + case 'H': + return NDIMS - 1; + case 'W': + return NDIMS; + case 'C': + return 1 + NDIMS; default: LOG(FATAL) << "Invalid dimension: " << dimension; } @@ -57,10 +65,16 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { return 0; case 'C': return 1; - case 'H': + case '0': return 2; - case 'W': + case '1': return 3; + case '2': + return 4; + case 'H': + return NDIMS; + case 'W': + return NDIMS + 1; default: LOG(FATAL) << "Invalid dimension: " << dimension; } @@ -69,11 +83,15 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { } } +inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { + return GetTensorDimIndex<2>(format, dimension); +} + // Return the given tensor dimension from a tensor. The tensor is interpretted // using the specified format, and a dimension specification using a char. inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, char dimension) { - int index = GetTensorDimIndex(format, dimension); + int index = GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < tensor.dims()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -86,7 +104,7 @@ inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, // specification using a char. inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format, char dimension) { - int index = GetTensorDimIndex(format, dimension); + int index = GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < tensor_shape.dims()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -99,7 +117,7 @@ inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format, template <typename T> T GetTensorDim(const std::vector<T>& attributes, TensorFormat format, char dimension) { - int index = GetTensorDimIndex(format, dimension); + int index = GetTensorDimIndex<2>(format, dimension); CHECK(index >= 0 && index < attributes.size()) << "Invalid index from the dimension: " << index << ", " << format << ", " << dimension; @@ -113,10 +131,10 @@ string GetConvnetDataFormatAttrString(); inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H, int64 W, int64 C) { std::vector<int64> dim_sizes(4); - dim_sizes[GetTensorDimIndex(format, 'N')] = N; - dim_sizes[GetTensorDimIndex(format, 'H')] = H; - dim_sizes[GetTensorDimIndex(format, 'W')] = W; - dim_sizes[GetTensorDimIndex(format, 'C')] = C; + dim_sizes[GetTensorDimIndex<2>(format, 'N')] = N; + dim_sizes[GetTensorDimIndex<2>(format, 'H')] = H; + dim_sizes[GetTensorDimIndex<2>(format, 'W')] = W; + dim_sizes[GetTensorDimIndex<2>(format, 'C')] = C; return TensorShape(dim_sizes); } @@ -128,13 +146,13 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format, return src_shape; } std::vector<int64> dim_sizes(4); - dim_sizes[GetTensorDimIndex(dst_format, 'N')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'N')] = GetTensorDim(src_shape, src_format, 'N'); - dim_sizes[GetTensorDimIndex(dst_format, 'H')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'H')] = GetTensorDim(src_shape, src_format, 'H'); - dim_sizes[GetTensorDimIndex(dst_format, 'W')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'W')] = GetTensorDim(src_shape, src_format, 'W'); - dim_sizes[GetTensorDimIndex(dst_format, 'C')] = + dim_sizes[GetTensorDimIndex<2>(dst_format, 'C')] = GetTensorDim(src_shape, src_format, 'C'); return TensorShape(dim_sizes); } |