aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/ops_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/ops_util.h')
-rw-r--r--tensorflow/core/kernels/ops_util.h14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h
index 68a9c37406..d3d1b56c9d 100644
--- a/tensorflow/core/kernels/ops_util.h
+++ b/tensorflow/core/kernels/ops_util.h
@@ -84,6 +84,20 @@ bool IsDim0SliceAligned(const TensorShape& s, int64 start, int64 end_or_size) {
// Returns <suffix> sanitized to have only [a-zA-Z0-9-_].
string SanitizeThreadSuffix(string suffix);
+// Helper to compute 'strides' given a tensor 'shape'. I.e.,
+// strides[i] = prod(shape.dim_size[(i+1):])
+template <typename T>
+gtl::InlinedVector<T, 8> ComputeStride(const TensorShape& shape) {
+ const int ndims = shape.dims();
+ gtl::InlinedVector<T, 8> strides(ndims);
+ T stride = 1;
+ for (int i = ndims - 1; i >= 0; --i) {
+ strides[i] = stride;
+ stride *= static_cast<T>(shape.dim_size(i));
+ }
+ return strides;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_