diff options
author | Jingyue Wu <jingyue@google.com> | 2017-07-13 16:03:28 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-13 16:07:30 -0700 |
commit | 9d45160c76bdb2aa7cf3cc5b318512fc7d7337a2 (patch) | |
tree | e291f3b20936f6f6e8b7a8e1a205300b9f6b1ddc /tensorflow/core/util/tensor_format.h | |
parent | c959e9b023127d0fb42b8e735301bc70fd20af74 (diff) |
Extend shape_inference::Conv2DShape to handle NCHW_VECT_C format.
Tested Conv2DShape with NCHW_VECT_C format.
PiperOrigin-RevId: 161879362
Diffstat (limited to 'tensorflow/core/util/tensor_format.h')
-rw-r--r-- | tensorflow/core/util/tensor_format.h | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h index 9923428a34..83f9500490 100644 --- a/tensorflow/core/util/tensor_format.h +++ b/tensorflow/core/util/tensor_format.h @@ -54,6 +54,17 @@ inline int GetTensorSpatialDims(int num_dims, TensorFormat format) { } } +// Returns the rank of a tensor with 'num_spatial_dims' spatial dimensions and +// tensor format 'format'. This is the inverse of GetTensorSpatialDims. +inline int GetTensorDimsFromSpatialDims(int num_spatial_dims, + TensorFormat format) { + if (format == FORMAT_NCHW_VECT_C) { + return num_spatial_dims + 3; // Include N,C,InnerC. + } else { + return num_spatial_dims + 2; // Include N,C. + } +} + // Returns the index of the batch dimension. inline int GetTensorBatchDimIndex(int num_dims, TensorFormat format) { switch (format) { |