diff options
author | 2018-06-11 09:16:31 -0700 | |
---|---|---|
committer | 2018-06-11 09:19:14 -0700 | |
commit | 01c27242128a55aa4aaf47c674642dd950beda1d (patch) | |
tree | f39001e31494a74a70a1d9cf7b682e974a79937c /tensorflow/contrib/lite/kernels/internal/types.h | |
parent | 56104e275348c377f765c49dc677c0a34440d5c5 (diff) |
Add interim runtime utility function for use during refactoring out of Dims.
PiperOrigin-RevId: 200061346
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/types.h')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/types.h | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 0c7fb7a76a..1086c5b092 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -142,6 +142,22 @@ class RuntimeShape { }; }; +// Converts inference-style shape to legacy tflite::Dims<4>. +inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) { + tflite::Dims<4> result; + const int dimensions_count = array_shape.DimensionsCount(); + TFLITE_CHECK_LE(dimensions_count, 4); + int cum_prod = 1; + for (int i = 0; i < 4; i++) { + const int new_dim = + (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1; + result.sizes[i] = new_dim; + result.strides[i] = cum_prod; + cum_prod *= new_dim; + } + return result; +} + // Gets next index to iterate through a multidimensional array. inline bool NextIndex(const int num_dims, const int* dims, int* current) { TFLITE_DCHECK_GT(num_dims, 0); |