aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_format.h
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2017-07-13 16:03:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-13 16:07:30 -0700
commit9d45160c76bdb2aa7cf3cc5b318512fc7d7337a2 (patch)
treee291f3b20936f6f6e8b7a8e1a205300b9f6b1ddc /tensorflow/core/util/tensor_format.h
parentc959e9b023127d0fb42b8e735301bc70fd20af74 (diff)
Extend shape_inference::Conv2DShape to handle NCHW_VECT_C format.
Tested Conv2DShape with NCHW_VECT_C format. PiperOrigin-RevId: 161879362
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r--tensorflow/core/util/tensor_format.h11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 9923428a34..83f9500490 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -54,6 +54,17 @@ inline int GetTensorSpatialDims(int num_dims, TensorFormat format) {
}
}
+// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and
+// tensor format 'format'. This is the inverse of GetTensorSpatialDims.
+inline int GetTensorDimsFromSpatialDims(int num_spatial_dims,
+ TensorFormat format) {
+ if (format == FORMAT_NCHW_VECT_C) {
+ return num_spatial_dims + 3; // Include N,C,InnerC.
+ } else {
+ return num_spatial_dims + 2; // Include N,C.
+ }
+}
+
// Returns the index of the batch dimension.
inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) {
switch (format) {