diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 102 |
1 files changed, 80 insertions, 22 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 00f9616cc2..a027a47726 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -3398,10 +3398,12 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, } } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, - const Dims<4>& output_dims) { - const int flat_size = MatchingFlatSize(output_dims, input_dims); +inline void Dequantize(const tflite::DequantizationParams& op_params, + const RuntimeShape& input_shape, const uint8* input_data, + const RuntimeShape& output_shape, float* output_data) { + int32 zero_point = op_params.zero_point; + double scale = op_params.scale; + const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { int32 val = input_data[i]; @@ -3410,9 +3412,25 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, } } -inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, - float rmin, float rmax, int num_bits, float* output_data, - const Dims<4>& output_dims) { +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + tflite::DequantizationParams op_params; + op_params.zero_point = zero_point; + op_params.scale = scale; + + Dequantize(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + +inline void FakeQuant(const tflite::FakeQuantParams& op_params, + const RuntimeShape& input_shape, const float* input_data, + const RuntimeShape& output_shape, float* output_data) { + float rmin = op_params.minmax.min; + float rmax = op_params.minmax.max; + int num_bits = op_params.num_bits; // 0 should always be a representable value. Let's assume that the initial // min,max range contains 0. TFLITE_DCHECK_LE(rmin, 0.0f); @@ -3425,11 +3443,25 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, float nudged_min, nudged_max, nudged_scale; NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min, &nudged_max, &nudged_scale); - const int flat_size = MatchingFlatSize(output_dims, input_dims); + const int flat_size = MatchingFlatSize(input_shape, output_shape); FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data, output_data, flat_size); } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, int num_bits, float* output_data, + const Dims<4>& output_dims) { + tflite::FakeQuantParams op_params; + op_params.num_bits = num_bits; + op_params.minmax.min = rmin; + op_params.minmax.max = rmax; + + FakeQuant(op_params, DimsToShape(input_dims), input_data, + DimsToShape(output_dims), output_data); +} + template <typename SrcT, typename DstT> inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data, const RuntimeShape& output_shape, DstT* output_data) { @@ -4050,22 +4082,32 @@ inline bool Mean(const T* input_data, const int* input_dims, } template <typename T> -inline void Mean(const T* input_data, const Dims<4>& input_dims, - const std::vector<int>& reduction_indices, T* output_data, - const Dims<4>& output_dims) { - const int output_batch = ArraySize(output_dims, 3); - const int output_height = ArraySize(output_dims, 2); - const int output_width = ArraySize(output_dims, 1); - const int output_depth = ArraySize(output_dims, 0); +inline void Mean(const tflite::MeanParams& op_params, + const RuntimeShape& unextended_input_shape, + const T* input_data, + const RuntimeShape& unextended_output_shape, T* output_data) { + gemmlowp::ScopedProfilingLabel label("Mean"); - const int input_height = ArraySize(input_dims, 2); - const int input_width = ArraySize(input_dims, 1); + TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); + TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); + RuntimeShape input_shape = + RuntimeShape::ExtendedShape(4, unextended_input_shape); + RuntimeShape output_shape = + RuntimeShape::ExtendedShape(4, unextended_output_shape); + + const int output_batch = output_shape.Dims(0); + const int output_height = output_shape.Dims(1); + const int output_width = output_shape.Dims(2); + const int output_depth = output_shape.Dims(3); + + const int input_height = input_shape.Dims(1); + const int input_width = input_shape.Dims(2); // The current implementation only supports simultaneous reduction over // width and height. - TFLITE_DCHECK_EQ(reduction_indices.size(), 2); - TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || - (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(op_params.axis_count, 2); + TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) || + (op_params.axis[0] == 2 && op_params.axis[1] == 1)); TFLITE_DCHECK_EQ(output_height, 1); TFLITE_DCHECK_EQ(output_width, 1); @@ -4074,15 +4116,31 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims, float value = 0; for (int in_h = 0; in_h < input_height; ++in_h) { for (int in_w = 0; in_w < input_width; ++in_w) { - value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)]; } } - output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + output_data[Offset(output_shape, out_b, 0, 0, out_d)] = value / (input_width * input_height); } } } +// TODO(b/80418076): Move to legacy ops file, update invocations. +// Legacy Dims<4>. +template <typename T> +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector<int>& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + tflite::MeanParams op_params; + op_params.axis_count = reduction_indices.size(); + for (int i = 0; i < op_params.axis_count; ++i) { + op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i]; + } + + Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims), + output_data); +} + // Computes the mean of elements across dimensions given in axis. // It does so in two stages, first calculates the sum of elements along the axis // then divides it by the number of element in axis for quantized values. |