Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/types.h')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/types.h  139
1 file changed, 136 insertions(+), 3 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index fa2420713f..c44698b677 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -23,7 +23,12 @@ limitations under the License.
namespace tflite {
enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu };
-enum class PaddingType { kNone, kSame, kValid };
+enum class PaddingType : uint8 { kNone, kSame, kValid };
+
+struct PaddingValues {
+ int8 width;
+ int8 height;
+};
// This enumeration allows for non-default formats for the weights array
// of a fully-connected operator, allowing the use of special optimized
@@ -114,6 +119,8 @@ class RuntimeShape {
// larger shapes are separately allocated.
static constexpr int kMaxSmallSize = 4;
+ RuntimeShape& operator=(RuntimeShape const&) = delete;
+
RuntimeShape() : size_(0) {}
explicit RuntimeShape(int dimensions_count) : size_(dimensions_count) {
@@ -130,6 +137,20 @@ class RuntimeShape {
BuildFrom(init_list);
}
+ // Avoid using this constructor. We should be able to delete it when C++17
+ // rolls out.
+ RuntimeShape(RuntimeShape const& other) : size_(other.DimensionsCount()) {
+ if (size_ > kMaxSmallSize) {
+ dims_pointer_ = new int32[size_];
+ }
+ std::memcpy(DimsData(), other.DimsData(), sizeof(int32) * size_);
+ }
+
+ bool operator==(const RuntimeShape& comp) const {
+ return this->size_ == comp.size_ &&
+ std::memcmp(DimsData(), comp.DimsData(), size_ * sizeof(int32)) == 0;
+ }
+
~RuntimeShape() {
if (size_ > kMaxSmallSize) {
delete[] dims_pointer_;
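
Not part of the patch, just a quick sketch of what the new copy constructor and equality operator allow, assuming the header is included from its contrib/lite path and that the initializer_list constructor shown above builds the dims:

  #include "tensorflow/contrib/lite/kernels/internal/types.h"

  void CopyAndCompareSketch() {
    tflite::RuntimeShape a({1, 16, 16, 3});  // rank 4 fits the inline storage
    tflite::RuntimeShape b(a);               // deep copy via the new copy ctor
    const bool same = (a == b);              // true: equal rank and dims
    const bool different = (a != tflite::RuntimeShape({1, 8, 8, 3}));  // true
    (void)same;
    (void)different;
  }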
@@ -186,6 +207,16 @@ class RuntimeShape {
}
}
+ // This will probably be factored out. Old code made substantial use of 4-D
+ // shapes, and so this function is used to extend smaller shapes. Note that
+ // (a) as Dims<4>-dependent code is eliminated, the reliance on this should be
+ // reduced, and (b) some kernels are strictly 4-D, but then the shapes of their
+ // inputs should already be 4-D, so this function should not be needed.
+ inline static RuntimeShape ExtendedShape(int new_shape_size,
+ const RuntimeShape& shape) {
+ return RuntimeShape(new_shape_size, shape, 1);
+ }
+
inline void BuildFrom(const std::initializer_list<int> init_list) {
BuildFrom<const std::initializer_list<int>>(init_list);
}
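
A hedged usage sketch of ExtendedShape() (not part of the patch): it left-pads a smaller shape with 1s up to the requested rank, which is how legacy 4-D kernels can consume lower-rank inputs.

  #include "tensorflow/contrib/lite/kernels/internal/types.h"

  void ExtendedShapeSketch() {
    tflite::RuntimeShape hw({16, 16});  // a 2-D shape
    // Pads on the left with 1s: the resulting dims are {1, 1, 16, 16}.
    const tflite::RuntimeShape nhwc =
        tflite::RuntimeShape::ExtendedShape(4, hw);
  }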
@@ -203,7 +234,25 @@ class RuntimeShape {
return buffer_size;
}
+ bool operator!=(const RuntimeShape& comp) const { return !((*this) == comp); }
+
private:
+ // For use only by ExtendedShape(), written to guarantee (return-value) copy
+ // elision in C++17.
+ // This creates a shape padded to the desired size with the specified value.
+ RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
+ : size_(0) {
+ TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
+ TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
+ Resize(new_shape_size);
+ const int size_increase = new_shape_size - shape.DimensionsCount();
+ for (int i = 0; i < size_increase; ++i) {
+ SetDim(i, pad_value);
+ }
+ std::memcpy(DimsData() + size_increase, shape.DimsData(),
+ sizeof(int32) * shape.DimensionsCount());
+ }
+
int32 size_;
union {
int32 dims_[kMaxSmallSize];
@@ -229,7 +278,9 @@ inline tflite::Dims<4> ToRuntimeDims(const tflite::RuntimeShape& array_shape) {
// Gets next index to iterate through a multidimensional array.
inline bool NextIndex(const int num_dims, const int* dims, int* current) {
- TFLITE_DCHECK_GT(num_dims, 0);
+ if (num_dims == 0) {
+ return false;
+ }
TFLITE_DCHECK(dims != nullptr);
TFLITE_DCHECK(current != nullptr);
int carry = 1;
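
With the DCHECK relaxed to an early return, zero-rank arrays simply terminate the iteration. A usage sketch, assuming the usual do/while pattern used by the reference reduce kernels:

  void NextIndexSketch() {
    const int dims[3] = {2, 1, 3};
    int index[3] = {0, 0, 0};
    // Visits (0,0,0), (0,0,1), (0,0,2), (1,0,0), ..., (1,0,2), then stops.
    do {
      // ... use index ...
    } while (tflite::NextIndex(3, dims, index));
  }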
@@ -256,7 +307,9 @@ inline bool NextIndex(const int num_dims, const int* dims, int* current) {
inline size_t ReducedOutputOffset(const int num_dims, const int* dims,
const int* index, const int num_axis,
const int* axis) {
- TFLITE_DCHECK_GT(num_dims, 0);
+ if (num_dims == 0) {
+ return 0;
+ }
TFLITE_DCHECK(dims != nullptr);
TFLITE_DCHECK(index != nullptr);
size_t offset = 0;
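
Likewise, ReducedOutputOffset() now returns 0 for zero-rank input instead of tripping a DCHECK. A small worked example, assuming the function keeps its usual row-major accumulation that skips the reduced axes:

  // 4-D input {2, 3, 4, 5} reduced over axes 1 and 3; the output is {2, 4}.
  const int dims[4] = {2, 3, 4, 5};
  const int index[4] = {1, 2, 3, 4};
  const int axis[2] = {1, 3};
  // Only dims 0 and 2 contribute: offset = 1 * 4 + 3 = 7.
  const size_t offset = tflite::ReducedOutputOffset(4, dims, index, 2, axis);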
@@ -359,6 +412,7 @@ inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
// arrays.
inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -369,6 +423,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -380,6 +435,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_0,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
@@ -392,6 +448,7 @@ inline int MatchingFlatSize(const RuntimeShape& shape,
const RuntimeShape& check_shape_1,
const RuntimeShape& check_shape_2,
const RuntimeShape& check_shape_3) {
+ TFLITE_DCHECK_EQ(shape.DimensionsCount(), check_shape_0.DimensionsCount());
const int dims_count = shape.DimensionsCount();
for (int i = 0; i < dims_count; ++i) {
TFLITE_DCHECK_EQ(shape.Dims(i), check_shape_0.Dims(i));
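
Each overload now also DCHECKs that the ranks match before comparing dims. A usage sketch, assuming MatchingFlatSize() still returns shape.FlatSize() once the checks pass:

  tflite::RuntimeShape input({1, 8, 8, 3});
  tflite::RuntimeShape output({1, 8, 8, 3});
  // Debug-checks rank and per-dimension equality, then returns 1 * 8 * 8 * 3.
  const int flat_size = tflite::MatchingFlatSize(input, output);  // 192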
@@ -588,6 +645,82 @@ void ComputeStrides(Dims<N>* dims) {
}
}
+struct PoolParams {
+ FusedActivationFunctionType activation;
+ PaddingType padding_type;
+ PaddingValues padding_values;
+ int stride_height;
+ int stride_width;
+ int filter_height;
+ int filter_width;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+};
+
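
A hypothetical way a kernel might fill PoolParams for a float 2x2 max pool with stride 2 (values are illustrative only, not taken from the patch; needs <limits>):

  tflite::PoolParams params;
  params.activation = tflite::FusedActivationFunctionType::kNone;
  params.padding_type = tflite::PaddingType::kSame;
  params.padding_values.width = 0;
  params.padding_values.height = 0;
  params.stride_width = 2;
  params.stride_height = 2;
  params.filter_width = 2;
  params.filter_height = 2;
  // No fused activation, so the float bounds are left wide open.
  params.float_activation_min = std::numeric_limits<float>::lowest();
  params.float_activation_max = std::numeric_limits<float>::max();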
+enum class BroadcastableOpCategory : uint8 {
+ kNone,
+ kNonBroadcast, // Matching input shapes.
+ kFirstInputBroadcastsFast, // Fivefold nested loops.
+ kSecondInputBroadcastsFast, // Fivefold nested loops.
+ kGenericBroadcast, // Fall-back.
+};
+
+// For Add, Sub, Mul ops.
+struct ArithmeticParams {
+ // Shape dependent / common to data / op types.
+ BroadcastableOpCategory broadcast_category;
+ // uint8 inference params.
+ int32 input1_offset;
+ int32 input2_offset;
+ int32 output_offset;
+ int32 output_multiplier;
+ int output_shift;
+ // Add / Sub, not Mul, uint8 inference params.
+ int left_shift;
+ int32 input1_multiplier;
+ int input1_shift;
+ int32 input2_multiplier;
+ int input2_shift;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
+ // float activation params.
+ float float_activation_min;
+ float float_activation_max;
+
+ // Processed output dimensions.
+ // Let input "a" be the one that broadcasts in the faster-changing dimension.
+ // Then, after coalescing, for shapes {a0, a1, a2, a3, a4} and
+ // {b0, b1, b2, b3, b4},
+ // broadcast_shape[4] = b0 = a0.
+ // broadcast_shape[3] = b1; a1 = 1.
+ // broadcast_shape[2] = b2 = a2.
+ // broadcast_shape[1] = a3; b3 = 1.
+ // broadcast_shape[0] = b4 = a4.
+ int broadcast_shape[5];
+};
+
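
To make the broadcast_shape convention concrete (derived directly from the comment above, not from kernel code): if the coalesced shapes are a = {2, 1, 3, 4, 5} and b = {2, 7, 3, 1, 5}, then

  // broadcast_shape[4] = 2  (= a0 = b0)
  // broadcast_shape[3] = 7  (= b1, with a1 == 1)
  // broadcast_shape[2] = 3  (= a2 = b2)
  // broadcast_shape[1] = 4  (= a3, with b3 == 1)
  // broadcast_shape[0] = 5  (= a4 = b4)
  const int broadcast_shape[5] = {5, 4, 3, 7, 2};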
+template <typename T>
+inline void SetActivationParams(T min, T max, ArithmeticParams* params);
+
+template <>
+inline void SetActivationParams(float min, float max,
+ ArithmeticParams* params) {
+ params->float_activation_min = min;
+ params->float_activation_max = max;
+}
+
+template <>
+inline void SetActivationParams(int32 min, int32 max,
+ ArithmeticParams* params) {
+ params->quantized_activation_min = min;
+ params->quantized_activation_max = max;
+}
+
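
Usage sketch for the two specializations (not part of the patch; assumes tflite::int32 is the int32_t typedef used elsewhere in this header):

  tflite::ArithmeticParams op_params;
  // Float path: fills float_activation_min / float_activation_max.
  tflite::SetActivationParams(0.0f, 6.0f, &op_params);  // e.g. a fused Relu6
  // Quantized path: fills quantized_activation_min / quantized_activation_max.
  tflite::SetActivationParams(static_cast<int32_t>(0),
                              static_cast<int32_t>(255), &op_params);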
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_