aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h102
1 files changed, 80 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 00f9616cc2..a027a47726 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -3398,10 +3398,12 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ int32 zero_point = op_params.zero_point;
+ double scale = op_params.scale;
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int32 val = input_data[i];
@@ -3410,9 +3412,25 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = zero_point;
+ op_params.scale = scale;
+
+ Dequantize(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void FakeQuant(const tflite::FakeQuantParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ float rmin = op_params.minmax.min;
+ float rmax = op_params.minmax.max;
+ int num_bits = op_params.num_bits;
// 0 should always be a representable value. Let's assume that the initial
// min,max range contains 0.
TFLITE_DCHECK_LE(rmin, 0.0f);
@@ -3425,11 +3443,25 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
float nudged_min, nudged_max, nudged_scale;
NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
&nudged_max, &nudged_scale);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
output_data, flat_size);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, int num_bits, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = num_bits;
+ op_params.minmax.min = rmin;
+ op_params.minmax.max = rmax;
+
+ FakeQuant(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
@@ -4050,22 +4082,32 @@ inline bool Mean(const T* input_data, const int* input_dims,
}
template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
+inline void Mean(const tflite::MeanParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mean");
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_batch = output_shape.Dims(0);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int output_depth = output_shape.Dims(3);
+
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
// The current implementation only supports simultaneous reduction over
// width and height.
- TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
- TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
- (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+ TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+ (op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
@@ -4074,15 +4116,31 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
float value = 0;
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
- value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
}
}
- output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
value / (input_width * input_height);
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::MeanParams op_params;
+ op_params.axis_count = reduction_indices.size();
+ for (int i = 0; i < op_params.axis_count; ++i) {
+ op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+ }
+
+ Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
// Computes the mean of elements across dimensions given in axis.
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis for quantized values.