aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/types.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-13 12:36:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 12:39:15 -0700
commit47b1c9396aef567b839c2c5ad91aa37ba0cb68ca (patch)
treebf85c8aeeddecd88411fcac64ef559e549944b32 /tensorflow/contrib/lite/kernels/internal/types.h
parentd40ca72ff692d21e7965b3b17445bca873510941 (diff)
Initial application of runtime shapes to runtime kernels.
PiperOrigin-RevId: 200435608
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/types.h')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h97
1 files changed, 97 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 3ecef15271..64f4881a46 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -65,6 +65,10 @@ class RuntimeShape {
ReplaceWith(dimensions_count, dims_data);
}
+ RuntimeShape(const std::initializer_list<int> init_list) : size_(0) {
+ BuildFrom(init_list);
+ }
+
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
delete[] dims_pointer_;
@@ -214,6 +218,15 @@ inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
return offset;
}
+inline int Offset(const RuntimeShape& shape, int i0, int i1, int i2, int i3) {
+ TFLITE_DCHECK(i0 >= 0 && i0 < shape.Dims(0));
+ TFLITE_DCHECK(i1 >= 0 && i1 < shape.Dims(1));
+ TFLITE_DCHECK(i2 >= 0 && i2 < shape.Dims(2));
+ TFLITE_DCHECK(i3 >= 0 && i3 < shape.Dims(3));
+ const int* dims_data = shape.DimsData();
+ return ((i0 * dims_data[1] + i1) * dims_data[2] + i2) * dims_data[3] + i3;
+}
+
inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) {
TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]);
TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]);
@@ -228,6 +241,9 @@ inline int Offset(const Dims<4>& dims, int* index) {
}
// Get array size, DCHECKing that the dim index is in range.
+//
+// Note that this will be phased out with Dims<4>, since RuntimeShape::Dims()
+// already performs this check.
template <int N>
int ArraySize(const Dims<N>& array, int index) {
TFLITE_DCHECK(index >= 0 && index < N);
@@ -249,6 +265,21 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
return MatchingArraySize(array1, index1, args...);
}
+// Get common shape dim, DCHECKing that they all agree.
+inline int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return shape1.Dims(index1);
+}
+
+template <typename... Args>
+int MatchingDim(const RuntimeShape& shape1, int index1,
+ const RuntimeShape& shape2, int index2, Args... args) {
+ TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
+ return MatchingDim(shape1, index1, args...);
+}
+
+// Will be phased out with Dims<4>, replaced by RuntimeShape::FlatSize().
template <int N>
inline int FlatSize(const Dims<N>& dims) {
int flat_size = 1;
@@ -368,6 +399,72 @@ inline int MatchingFlatSizeSkipDim(const Dims<N>& dims, int skip_dim,
check_dims_3);
}
+// Data is required to be contiguous, and so many operators can use either the
+// full array flat size or the flat size with one dimension skipped (commonly
+// the depth).
+inline int FlatSizeSkipDim(const RuntimeShape& shape, int skip_dim) {
+ const int dims_count = shape.DimensionsCount();
+ TFLITE_DCHECK(skip_dim >= 0 && skip_dim < dims_count);
+ const auto* dims_data = shape.DimsData();
+ int flat_size = 1;
+ for (int i = 0; i < dims_count; ++i) {
+ flat_size *= (i == skip_dim) ? 1 : dims_data[i];
+ }
+ return flat_size;
+}
+
+// A combination of MatchingFlatSize() and FlatSizeSkipDim().
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return FlatSizeSkipDim(shape, skip_dim);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2);
+}
+
+inline int MatchingFlatSizeSkipDim(const RuntimeShape& shape, int skip_dim,
+ const RuntimeShape& check_shape_0,
+ const RuntimeShape& check_shape_1,
+ const RuntimeShape& check_shape_2,
+ const RuntimeShape& check_shape_3) {
+ const int dims_count = shape.DimensionsCount();
+ for (int i = 0; i < dims_count; ++i) {
+ if (i != skip_dim) {
+ TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
+ }
+ }
+ return MatchingFlatSizeSkipDim(shape, skip_dim, check_shape_1, check_shape_2,
+ check_shape_3);
+}
+
template <int N>
bool IsPackedWithoutStrides(const Dims<N>& dims) {
int expected_stride = 1;