diff options
author | Xiaoqiang Zheng <zhengxq@google.com> | 2016-02-25 15:27:10 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-25 18:09:01 -0800 |
commit | 01a6f5e504d9299395888a786e52c589c16af529 (patch) | |
tree | c0cd02df394f44b297440cccba08734a2271e8c4 /tensorflow/core/util/tensor_format.h | |
parent | cdd0f2eeef9a11a48433156e41c95a5fd6f4e1ee (diff) |
Multiple layout support for pooling operations.
Change: 115611259
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 41 |
1 files changed, 30 insertions, 11 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index c829e7ca51..ee5f3703ce 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -69,17 +69,6 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) { } } -// Return a tensor shape from the given format, and tensor dimensions. -inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H, - int64 W, int64 C) { - std::vector<int64> dim_sizes(4); - dim_sizes[GetTensorDimIndex(format, 'N')] = N; - dim_sizes[GetTensorDimIndex(format, 'H')] = H; - dim_sizes[GetTensorDimIndex(format, 'W')] = W; - dim_sizes[GetTensorDimIndex(format, 'C')] = C; - return TensorShape(dim_sizes); -} - // Return the given tensor dimension from a tensor. The tensor is interpretted // using the specified format, and a dimension specification using a char. inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format, @@ -120,6 +109,36 @@ T GetTensorDim(const std::vector<T>& attributes, TensorFormat format, // Return the string that specifies the data format for convnet operations. string GetConvnetDataFormatAttrString(); +// Return a tensor shape from the given format, and tensor dimensions. +inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H, + int64 W, int64 C) { + std::vector<int64> dim_sizes(4); + dim_sizes[GetTensorDimIndex(format, 'N')] = N; + dim_sizes[GetTensorDimIndex(format, 'H')] = H; + dim_sizes[GetTensorDimIndex(format, 'W')] = W; + dim_sizes[GetTensorDimIndex(format, 'C')] = C; + return TensorShape(dim_sizes); +} + +// Return a tensor shape from the given format, and tensor dimensions. +inline TensorShape ShapeFromFormat(TensorFormat dst_format, + const TensorShape& src_shape, + TensorFormat src_format) { + if (src_format == dst_format) { + return src_shape; + } + std::vector<int64> dim_sizes(4); + dim_sizes[GetTensorDimIndex(dst_format, 'N')] = + GetTensorDim(src_shape, src_format, 'N'); + dim_sizes[GetTensorDimIndex(dst_format, 'H')] = + GetTensorDim(src_shape, src_format, 'H'); + dim_sizes[GetTensorDimIndex(dst_format, 'W')] = + GetTensorDim(src_shape, src_format, 'W'); + dim_sizes[GetTensorDimIndex(dst_format, 'C')] = + GetTensorDim(src_shape, src_format, 'C'); + return TensorShape(dim_sizes); +} + } // namespace tensorflow #endif // TENSORFLOW_UTIL_TENSOR_FORMAT_H_ |