Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h | 307
1 file changed, 289 insertions(+), 18 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 878b2441b4..bcf5e4e4f6 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -63,15 +63,257 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <FusedActivationFunctionType Ac>
+inline void Add(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier, int input2_shift,
+ int32 output_offset, int32 output_multiplier, int output_shift,
+ int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
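
These legacy entry points exist so that old call sites, which pass every
quantization parameter as a separate argument, can forward to the new
implementations that take a single tflite::ArithmeticParams struct plus
RuntimeShape objects. The DimsToShape() helper used throughout lives in
tensorflow/contrib/lite/kernels/internal/types.h; as a rough sketch (not
part of this diff), it reverses the legacy innermost-first Dims<4> sizes
into an NHWC-ordered RuntimeShape:

// Sketch only: Dims<4>::sizes[0] is the fastest-varying extent, so an
// NHWC RuntimeShape is built by reading the sizes back to front.
inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
  return RuntimeShape(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}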
+
+template <FusedActivationFunctionType Ac>
+void Add(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ gemmlowp::ScopedProfilingLabel label("Add/int32");
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<int32>::min();
+ op_params.quantized_activation_max = std::numeric_limits<int32>::max();
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAdd(int left_shift, const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const uint8* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset,
+ int32 output_multiplier, int output_shift,
+ int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+
+ tflite::ArithmeticParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
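
A legacy call site instantiates these wrappers with the fused activation
baked in as a template argument. A hypothetical caller (all variables
assumed to be set up elsewhere) would look like:

// Hypothetical call site, for illustration only: quantized uint8
// broadcast addition with a fused ReLU6, via the legacy argument-list
// interface this header preserves.
BroadcastAdd<FusedActivationFunctionType::kRelu6>(
    left_shift, input1_data, input1_dims, input1_offset, input1_multiplier,
    input1_shift, input2_data, input2_dims, input2_offset, input2_multiplier,
    input2_shift, output_offset, output_multiplier, output_shift,
    output_activation_min, output_activation_max, output_data, output_dims);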
+
+template <FusedActivationFunctionType Ac>
+void Add(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
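
The float paths recover their clamp bounds from the fused-activation
template parameter via GetActivationMinMax(). For reference, that helper
(defined in tensorflow/contrib/lite/kernels/internal/common.h, not in
this diff) behaves roughly as follows:

// Rough sketch of the existing helper: maps each fused activation kind
// to the [min, max] interval the kernel clamps its output to.
inline void GetActivationMinMax(FusedActivationFunctionType ac,
                                float* output_activation_min,
                                float* output_activation_max) {
  switch (ac) {
    case FusedActivationFunctionType::kNone:
      *output_activation_min = std::numeric_limits<float>::lowest();
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu:
      *output_activation_min = 0.f;
      *output_activation_max = std::numeric_limits<float>::max();
      break;
    case FusedActivationFunctionType::kRelu1:
      *output_activation_min = -1.f;
      *output_activation_max = 1.f;
      break;
    case FusedActivationFunctionType::kRelu6:
      *output_activation_min = 0.f;
      *output_activation_max = 6.f;
      break;
  }
}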
+
+template <typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void BroadcastAddFivefold(
+ int y0, int y1, int y2, int y3, int y4, int left_shift,
+ const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift, const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min, int32 output_activation_max,
+ uint8* output_data, const Dims<4>& output_dims) {
+ constexpr int kReverseShift = -1;
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, 0);
+ TFLITE_DCHECK_EQ(output_activation_max, 255);
+ }
+ tflite::ArithmeticParams op_params;
+ op_params.broadcast_category =
+ tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.broadcast_shape[4] = y0;
+ op_params.broadcast_shape[3] = y1;
+ op_params.broadcast_shape[2] = y2;
+ op_params.broadcast_shape[1] = y3;
+ op_params.broadcast_shape[0] = y4;
+ BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
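
Note the index reversal when packing the fivefold factors: the legacy
interface passes y0 as the innermost extent, while broadcast_shape is
stored outermost-first. The five assignments above are equivalent to:

// Equivalent formulation of the packing above, shown for clarity.
const int y[5] = {y0, y1, y2, y3, y4};
for (int i = 0; i < 5; ++i) {
  op_params.broadcast_shape[i] = y[4 - i];
}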
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
+ output_activation_min, output_activation_max, output_data,
+ output_dims);
+}
+
+template <FusedActivationFunctionType Ac>
+inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
+ int input1_shift, const int16* input2_data,
+ const Dims<4>& input2_dims, int input2_shift,
+ int16 output_activation_min, int16 output_activation_max,
+ int16* output_data, const Dims<4>& output_dims) {
+ static_assert(Ac == FusedActivationFunctionType::kNone ||
+ Ac == FusedActivationFunctionType::kRelu ||
+ Ac == FusedActivationFunctionType::kRelu6 ||
+ Ac == FusedActivationFunctionType::kRelu1,
+ "");
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ if (Ac == FusedActivationFunctionType::kNone) {
+ TFLITE_DCHECK_EQ(output_activation_min, -32768);
+ TFLITE_DCHECK_EQ(output_activation_max, 32767);
+ }
+
+  tflite::ArithmeticParams op_params;
+  constexpr int kReverseShift = -1;  // Same negation convention as above.
+  op_params.input1_shift = kReverseShift * input1_shift;
+  op_params.input2_shift = kReverseShift * input2_shift;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ Add(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
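
Every adapter in this file negates the incoming shift amounts
(kReverseShift = -1): the legacy interfaces describe a right shift with a
positive integer, whereas ArithmeticParams stores a signed left shift.
Downstream, the fixed-point rescale consumes that signed exponent roughly
like this sketch of the helper in common.h (not part of this diff):

// Sketch: 'left_shift' is the signed exponent from ArithmeticParams, so
// the negated legacy right shift ends up as a rounding right shift here.
inline int32 MultiplyByQuantizedMultiplierSmallerThanOneExp(
    int32 x, int32 quantized_multiplier, int left_shift) {
  using gemmlowp::RoundingDivideByPOT;
  using gemmlowp::SaturatingRoundingDoublingHighMul;
  return RoundingDivideByPOT(
      SaturatingRoundingDoublingHighMul(x, quantized_multiplier),
      -left_shift);
}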
+
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(FusedActivationFunctionType::kNone,
+ &output_activation_min, &output_activation_max);
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
+ const Dims<4>& input2_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = std::numeric_limits<T>::min();
+ op_params.quantized_activation_max = std::numeric_limits<T>::max();
+ Sub(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
int stride_width, int stride_height, int pad_width,
int pad_height, int kwidth, int kheight,
float output_activation_min,
float output_activation_max, float* output_data,
const Dims<4>& output_dims) {
- AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
}
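
The pooling adapters below all follow the same recipe: fold the loose
stride/pad/filter arguments into a tflite::PoolParams and call the
shape-based implementation. For comparison, a new-style call site,
equivalent to the legacy float AveragePool above, might look like this
(all values and names are illustrative):

// Illustrative new-style caller: 2x2 average pooling, stride 2, no
// padding, no activation clamp. input_shape/output_shape are
// RuntimeShape values assumed to be built by the caller.
tflite::PoolParams params;
params.stride_height = 2;
params.stride_width = 2;
params.filter_height = 2;
params.filter_width = 2;
params.padding_values.height = 0;
params.padding_values.width = 0;
params.float_activation_min = std::numeric_limits<float>::lowest();
params.float_activation_max = std::numeric_limits<float>::max();
AveragePool(params, input_shape, input_data, output_shape, output_data);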
// legacy, for compatibility with old checked-in code
@@ -104,10 +346,17 @@ inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims,
int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- AveragePool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ AveragePool(params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -148,9 +397,17 @@ inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
int pad_height, int kwidth, int kheight,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
- MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, kwidth, kheight, output_activation_min,
- output_activation_max, output_data, DimsToShape(output_dims));
+ tflite::PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = kheight;
+ params.filter_width = kwidth;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -180,10 +437,17 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
int pad_height, int filter_width, int filter_height,
int32 output_activation_min, int32 output_activation_max,
uint8* output_data, const Dims<4>& output_dims) {
- MaxPool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.quantized_activation_min = output_activation_min;
+ params.quantized_activation_max = output_activation_max;
+ MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code
@@ -223,10 +487,17 @@ inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
int pad_height, int filter_width, int filter_height,
float output_activation_min, float output_activation_max,
float* output_data, const Dims<4>& output_dims) {
- L2Pool(input_data, DimsToShape(input_dims), stride_width, stride_height,
- pad_width, pad_height, filter_width, filter_height,
- output_activation_min, output_activation_max, output_data,
- DimsToShape(output_dims));
+ PoolParams params;
+ params.stride_height = stride_height;
+ params.stride_width = stride_width;
+ params.filter_height = filter_height;
+ params.filter_width = filter_width;
+ params.padding_values.height = pad_height;
+ params.padding_values.width = pad_width;
+ params.float_activation_min = output_activation_min;
+ params.float_activation_max = output_activation_max;
+ L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
}
// legacy, for compatibility with old checked-in code