aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/types.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-11 09:16:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-11 09:19:14 -0700
commit01c27242128a55aa4aaf47c674642dd950beda1d (patch)
treef39001e31494a74a70a1d9cf7b682e974a79937c /tensorflow/contrib/lite/kernels/internal/types.h
parent56104e275348c377f765c49dc677c0a34440d5c5 (diff)
Add interim runtime utility function for use during refactoring out of Dims.
PiperOrigin-RevId: 200061346
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/types.h')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 0c7fb7a76a..1086c5b092 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -142,6 +142,22 @@ class RuntimeShape {
};
};
+// Converts inference-style shape to legacy tflite::Dims<4>.
+inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
+ tflite::Dims<4> result;
+ const int dimensions_count = array_shape.DimensionsCount();
+ TFLITE_CHECK_LE(dimensions_count, 4);
+ int cum_prod = 1;
+ for (int i = 0; i < 4; i++) {
+ const int new_dim =
+ (i < dimensions_count) ? array_shape.Dims(dimensions_count - 1 - i) : 1;
+ result.sizes[i] = new_dim;
+ result.strides[i] = cum_prod;
+ cum_prod *= new_dim;
+ }
+ return result;
+}
+
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
TFLITE_DCHECK_GT(num_dims, 0);