diff options
Diffstat (limited to 'tensorflow/core/kernels/ops_util.h')
-rw-r--r-- | tensorflow/core/kernels/ops_util.h | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h index 68a9c37406..d3d1b56c9d 100644 --- a/tensorflow/core/kernels/ops_util.h +++ b/tensorflow/core/kernels/ops_util.h @@ -84,6 +84,20 @@ bool IsDim0SliceAligned(const TensorShape& s, int64 start, int64 end_or_size) { // Returns <suffix> sanitized to have only [a-zA-Z0-9-_]. string SanitizeThreadSuffix(string suffix); +// Helper to compute 'strides' given a tensor 'shape'. I.e., +// strides[i] = prod(shape.dim_size[(i+1):]) +template <typename T> +gtl::InlinedVector<T, 8> ComputeStride(const TensorShape& shape) { + const int ndims = shape.dims(); + gtl::InlinedVector<T, 8> strides(ndims); + T stride = 1; + for (int i = ndims - 1; i >= 0; --i) { + strides[i] = stride; + stride *= static_cast<T>(shape.dim_size(i)); + } + return strides; +} + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_OPS_UTIL_H_ |