aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_format.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-04 07:46:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-04 08:52:26 -0700
commit6a187ccddaebb741ea77fc3201c6e36625f0aadb (patch)
tree48097a7dbc49a3256de30ef0d9ae631e940af83e /tensorflow/core/util/tensor_format.h
parente5df6adc63653f31e9d5d6c539f799539cfbbed1 (diff)
Add support for 3d convolutions and pooling. CPU kernels use Eigen, GPU kernels use CuDNN.
Change: 121484787
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r--tensorflow/core/util/tensor_format.h52
1 files changed, 35 insertions, 17 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index ee5f3703ce..4115afb2b1 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -36,18 +36,26 @@ bool FormatFromString(const string& format_str, TensorFormat* format);
string ToString(TensorFormat format);
// Return the position index from a format given a dimension specification with
-// a char.
+// a char. The chars can be N (batch), C (channels), H (y), W (x), or
+// 0 .. (NDIMS-1).
+template <int NDIMS>
inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
if (format == FORMAT_NHWC) {
switch (dimension) {
case 'N':
return 0;
- case 'H':
+ case '0':
return 1;
- case 'W':
+ case '1':
return 2;
- case 'C':
+ case '2':
return 3;
+ case 'H':
+ return NDIMS - 1;
+ case 'W':
+ return NDIMS;
+ case 'C':
+ return 1 + NDIMS;
default:
LOG(FATAL) << "Invalid dimension: " << dimension;
}
@@ -57,10 +65,16 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
return 0;
case 'C':
return 1;
- case 'H':
+ case '0':
return 2;
- case 'W':
+ case '1':
return 3;
+ case '2':
+ return 4;
+ case 'H':
+ return NDIMS;
+ case 'W':
+ return NDIMS + 1;
default:
LOG(FATAL) << "Invalid dimension: " << dimension;
}
@@ -69,11 +83,15 @@ inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
}
}
+inline int32 GetTensorDimIndex(TensorFormat format, char dimension) {
+ return GetTensorDimIndex<2>(format, dimension);
+}
+
// Return the given tensor dimension from a tensor. The tensor is interpretted
// using the specified format, and a dimension specification using a char.
inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex(format, dimension);
+ int index = GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < tensor.dims())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -86,7 +104,7 @@ inline int64 GetTensorDim(const Tensor& tensor, TensorFormat format,
// specification using a char.
inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex(format, dimension);
+ int index = GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < tensor_shape.dims())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -99,7 +117,7 @@ inline int64 GetTensorDim(const TensorShape& tensor_shape, TensorFormat format,
template <typename T>
T GetTensorDim(const std::vector<T>& attributes, TensorFormat format,
char dimension) {
- int index = GetTensorDimIndex(format, dimension);
+ int index = GetTensorDimIndex<2>(format, dimension);
CHECK(index >= 0 && index < attributes.size())
<< "Invalid index from the dimension: " << index << ", " << format << ", "
<< dimension;
@@ -113,10 +131,10 @@ string GetConvnetDataFormatAttrString();
inline TensorShape ShapeFromFormat(TensorFormat format, int64 N, int64 H,
int64 W, int64 C) {
std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex(format, 'N')] = N;
- dim_sizes[GetTensorDimIndex(format, 'H')] = H;
- dim_sizes[GetTensorDimIndex(format, 'W')] = W;
- dim_sizes[GetTensorDimIndex(format, 'C')] = C;
+ dim_sizes[GetTensorDimIndex<2>(format, 'N')] = N;
+ dim_sizes[GetTensorDimIndex<2>(format, 'H')] = H;
+ dim_sizes[GetTensorDimIndex<2>(format, 'W')] = W;
+ dim_sizes[GetTensorDimIndex<2>(format, 'C')] = C;
return TensorShape(dim_sizes);
}
@@ -128,13 +146,13 @@ inline TensorShape ShapeFromFormat(TensorFormat dst_format,
return src_shape;
}
std::vector<int64> dim_sizes(4);
- dim_sizes[GetTensorDimIndex(dst_format, 'N')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'N')] =
GetTensorDim(src_shape, src_format, 'N');
- dim_sizes[GetTensorDimIndex(dst_format, 'H')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'H')] =
GetTensorDim(src_shape, src_format, 'H');
- dim_sizes[GetTensorDimIndex(dst_format, 'W')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'W')] =
GetTensorDim(src_shape, src_format, 'W');
- dim_sizes[GetTensorDimIndex(dst_format, 'C')] =
+ dim_sizes[GetTensorDimIndex<2>(dst_format, 'C')] =
GetTensorDim(src_shape, src_format, 'C');
return TensorShape(dim_sizes);
}