aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_format.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-21 14:26:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-21 15:48:06 -0700
commit61f30222eba5e3f1f51dedb3c5493f5f8eb331c8 (patch)
tree622b69c4e91a1be79bb33681d54842c08b17ee4e /tensorflow/core/util/tensor_format.h
parenta6116e04dbcf397337de8a4c37b531c4f373da04 (diff)
Add support for the NCHW data_format for 3d operations (convolution, pooling).
This brings NCHW support for 3d in sync with the corresponding 2d ops. Change: 150811076
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r--tensorflow/core/util/tensor_format.h62
1 files changed, 43 insertions, 19 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 7d8e4b11c8..fe89fe852e 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -16,9 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_UTIL_TENSOR_FORMAT_H_
#define TENSORFLOW_UTIL_TENSOR_FORMAT_H_
+#include <array>
#include <vector>
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -127,7 +129,8 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
// using the specified format, and a dimension specification using a char.
inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex<2>(format, dimension);
+ int index = (tensor.dims() == 5) ? GetTensorDimIndex<3>(format, dimension)
+ : GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < tensor.dims())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -140,7 +143,9 @@ inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
// specification using a char.
inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex<2>(format, dimension);
+ int index = (tensor_shape.dims() == 5)
+ ? GetTensorDimIndex<3>(format, dimension)
+ : GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < tensor_shape.dims())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -153,7 +158,9 @@ inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format,
template <typename T>
T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex<2>(format, dimension);
+ int index = (attributes.size() == 5)
+ ? GetTensorDimIndex<3>(format, dimension)
+ : GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < attributes.size())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -162,16 +169,27 @@ T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
// Return the string that specifies the data format for convnet operations.
string GetConvnetDataFormatAttrString();
+string GetConvnet3dDataFormatAttrString();
+
+// Return a tensor shape for the given format. Works for both 2D and 3D
+// operations.
+inline TensorShape ShapeFromFormat(TensorFormat format, int64 N,
+ gtl::ArraySlice<int64> spatial, int64 C) {
+ gtl::InlinedVector<int64, 5> dim_sizes(spatial.size() + 2);
+ dim_sizes[GetTensorBatchDimIndex(dim_sizes.size(), format)] = N;
+ for (int dim = 0; dim < spatial.size(); dim++) {
+ dim_sizes[GetTensorSpatialDimIndex(dim_sizes.size(), format, dim)] =
+ spatial[dim];
+ }
+ dim_sizes[GetTensorFeatureDimIndex(dim_sizes.size(), format)] = C;
+
+ return TensorShape(dim_sizes);
+}
// Return a tensor shape from the given format, and tensor dimensions.
inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
int64 W, int64 C) {
- std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex<2>(format, 'N')] = N;
- dim_sizes[GetTensorDimIndex<2>(format, 'H')] = H;
- dim_sizes[GetTensorDimIndex<2>(format, 'W')] = W;
- dim_sizes[GetTensorDimIndex<2>(format, 'C')] = C;
- return TensorShape(dim_sizes);
+ return ShapeFromFormat(format, N, {{H, W}}, C);
}
// Return a tensor shape from the given format, and tensor dimensions.
@@ -181,16 +199,22 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format,
if (src_format == dst_format) {
return src_shape;
}
- std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex<2>(dst_format, 'N')] =
- GetTensorDim(src_shape, src_format, 'N');
- dim_sizes[GetTensorDimIndex<2>(dst_format, 'H')] =
- GetTensorDim(src_shape, src_format, 'H');
- dim_sizes[GetTensorDimIndex<2>(dst_format, 'W')] =
- GetTensorDim(src_shape, src_format, 'W');
- dim_sizes[GetTensorDimIndex<2>(dst_format, 'C')] =
- GetTensorDim(src_shape, src_format, 'C');
- return TensorShape(dim_sizes);
+
+ const int64 channels = GetTensorDim(src_shape, src_format, 'C');
+ const int64 batch = GetTensorDim(src_shape, src_format, 'N');
+
+ if (src_shape.dims() == 5) {
+ return ShapeFromFormat(dst_format, batch,
+ {{GetTensorDim(src_shape, src_format, '0'),
+ GetTensorDim(src_shape, src_format, '1'),
+ GetTensorDim(src_shape, src_format, '2')}},
+ channels);
+ }
+
+ return ShapeFromFormat(dst_format, batch,
+ {{GetTensorDim(src_shape, src_format, 'H'),
+ GetTensorDim(src_shape, src_format, 'W')}},
+ channels);
}
} // namespace tensorflow