diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc | 67 |
1 files changed, 35 insertions, 32 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc index b7531ea2e2..d2f1103e14 100644 --- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -32,19 +32,21 @@ namespace tflite { namespace { void RunLogSoftmaxFloatReference(const uint8* input_data, - const Dims<4>& dims_common, int32 input_offset, - const double input_scale, int stride, - float beta, uint8* reference_output_data) { - const int ref_buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + int32 input_offset, const double input_scale, + int stride, float beta, + uint8* reference_output_data) { + const int ref_buffer_size = shape_common.FlatSize(); std::vector<float> reference_dequant_data(ref_buffer_size); std::vector<float> reference_output_float_data(ref_buffer_size); // Reference data generated via Dequant of input into float, and then applying // float LogSoftmax. - reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale, - reference_dequant_data.data(), dims_common); - optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common, - reference_output_float_data.data(), dims_common); + reference_ops::Dequantize( + input_data, ToRuntimeDims(shape_common), input_offset, input_scale, + reference_dequant_data.data(), ToRuntimeDims(shape_common)); + optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common, + reference_output_float_data.data(), shape_common); // Work with quantized scaling for LogSoftmax, under which 255 represents 0, // and -16 gets nudged up to 0. for (int i = 0; i < ref_buffer_size; i++) { @@ -55,9 +57,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, } void CheckOutputData(const uint8* test_output, const uint8* reference_output, - const Dims<4>& dims_common, const string& check_label, - bool be_exacting) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); + const RuntimeShape& shape_common, + const string& check_label, bool be_exacting) { + const int buffer_size = shape_common.FlatSize(); // While calculating some metrics in floating point, we work with quantized // scaling. std::vector<int> diff(buffer_size); @@ -99,15 +101,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output, // Runs the LogSoftmax and compares against the float reference implementation // and the quantized reference implementation. -void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, - int32 input_offset, const double input_scale, - int stride, float beta) { - const int buffer_size = RequiredBufferSizeForDims(dims_common); +void RunOneLogSoftmaxTest(const uint8* input_data, + const RuntimeShape& shape_common, int32 input_offset, + const double input_scale, int stride, float beta) { + const int buffer_size = shape_common.FlatSize(); std::vector<uint8> optimized_logsoftmax_output(buffer_size); std::vector<uint8> reference_float_logsoftmax_output(buffer_size); std::vector<uint8> reference_quant_logsoftmax_output(buffer_size); - RunLogSoftmaxFloatReference(input_data, dims_common, input_offset, + RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_logsoftmax_output.data()); @@ -116,32 +118,33 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common, int32 reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; - tflite::PreprocessLogSoftmaxScaling( + tflite::PreprocessLogSoftmaxScalingExp( beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier, &input_beta_left_shift, &reverse_scaling_divisor, &reverse_scaling_right_shift); + reverse_scaling_right_shift *= -1; // diff_min has a negative value, and is used to limit the maximum magnitude // of the diffs, which are <= 0. const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits, input_beta_left_shift); - optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier, + optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - optimized_logsoftmax_output.data(), dims_common); + optimized_logsoftmax_output.data(), shape_common); reference_ops::LogSoftmax( - input_data, dims_common, input_beta_multiplier, input_beta_left_shift, + input_data, shape_common, input_beta_multiplier, input_beta_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, - reference_quant_logsoftmax_output.data(), dims_common); + reference_quant_logsoftmax_output.data(), shape_common); CheckOutputData(optimized_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Optimized vs float reference", false); CheckOutputData(optimized_logsoftmax_output.data(), - reference_quant_logsoftmax_output.data(), dims_common, + reference_quant_logsoftmax_output.data(), shape_common, "Optimized vs quant reference", true); CheckOutputData(reference_quant_logsoftmax_output.data(), - reference_float_logsoftmax_output.data(), dims_common, + reference_float_logsoftmax_output.data(), shape_common, "Quant reference vs float reference", false); } @@ -164,13 +167,13 @@ bool TryOneUniformLogSoftmax() { const int32 input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector<uint8> input_data(buffer_size); FillRandom(&input_data); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } @@ -202,14 +205,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { const int middle_min = UniformRandomInt(0, 255); const int sides_max = UniformRandomInt(0, middle_min); - Dims<4> dims_common = - MakeDimsForInference(input_depth, input_width, input_height, batch); - const int buffer_size = RequiredBufferSizeForDims(dims_common); + auto shape_common = + RuntimeShape({batch, input_height, input_width, input_depth}); + const int buffer_size = shape_common.FlatSize(); std::vector<uint8> input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); - RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset, + RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); return true; } |