aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_format.h
diff options
context:
space:
mode:
authorGravatar Xiaoqiang Zheng <zhengxq@google.com>2016-02-25 15:27:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-25 18:09:01 -0800
commit01a6f5e504d9299395888a786e52c589c16af529 (patch)
treec0cd02df394f44b297440cccba08734a2271e8c4 /tensorflow/core/util/tensor_format.h
parentcdd0f2eeef9a11a48433156e41c95a5fd6f4e1ee (diff)
Multiple layout support for pooling operations.
Change: 115611259
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r--tensorflow/core/util/tensor_format.h41
1 files changed, 30 insertions, 11 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index c829e7ca51..ee5f3703ce 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -69,17 +69,6 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
}
}
-// Return a tensor shape from the given format, and tensor dimensions.
-inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
- int64 W, int64 C) {
- std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex(format, 'N')] = N;
- dim_sizes[GetTensorDimIndex(format, 'H')] = H;
- dim_sizes[GetTensorDimIndex(format, 'W')] = W;
- dim_sizes[GetTensorDimIndex(format, 'C')] = C;
- return TensorShape(dim_sizes);
-}
-
// Return the given tensor dimension from a tensor. The tensor is interpretted
// using the specified format, and a dimension specification using a char.
inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
@@ -120,6 +109,36 @@ T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
// Return the string that specifies the data format for convnet operations.
string GetConvnetDataFormatAttrString();
+// Return a tensor shape from the given format, and tensor dimensions.
+inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
+ int64 W, int64 C) {
+ std::vector<int64> dim_sizes(4);
+ dim_sizes[GetTensorDimIndex(format, 'N')] = N;
+ dim_sizes[GetTensorDimIndex(format, 'H')] = H;
+ dim_sizes[GetTensorDimIndex(format, 'W')] = W;
+ dim_sizes[GetTensorDimIndex(format, 'C')] = C;
+ return TensorShape(dim_sizes);
+}
+
+// Return a tensor shape from the given format, and tensor dimensions.
+inline TensorShape ShapeFromFormat(TensorFormat dst_format,
+ const TensorShape& src_shape,
+ TensorFormat src_format) {
+ if (src_format == dst_format) {
+ return src_shape;
+ }
+ std::vector<int64> dim_sizes(4);
+ dim_sizes[GetTensorDimIndex(dst_format, 'N')] =
+ GetTensorDim(src_shape, src_format, 'N');
+ dim_sizes[GetTensorDimIndex(dst_format, 'H')] =
+ GetTensorDim(src_shape, src_format, 'H');
+ dim_sizes[GetTensorDimIndex(dst_format, 'W')] =
+ GetTensorDim(src_shape, src_format, 'W');
+ dim_sizes[GetTensorDimIndex(dst_format, 'C')] =
+ GetTensorDim(src_shape, src_format, 'C');
+ return TensorShape(dim_sizes);
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_UTIL_TENSOR_FORMAT_H_