Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc  67
1 file changed, 35 insertions(+), 32 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
index b7531ea2e2..d2f1103e14 100644
--- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -32,19 +32,21 @@ namespace tflite {
namespace {
void RunLogSoftmaxFloatReference(const uint8* input_data,
- const Dims<4>& dims_common, int32 input_offset,
- const double input_scale, int stride,
- float beta, uint8* reference_output_data) {
- const int ref_buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ int32 input_offset, const double input_scale,
+ int stride, float beta,
+ uint8* reference_output_data) {
+ const int ref_buffer_size = shape_common.FlatSize();
std::vector<float> reference_dequant_data(ref_buffer_size);
std::vector<float> reference_output_float_data(ref_buffer_size);
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
- reference_ops::Dequantize(input_data, dims_common, input_offset, input_scale,
- reference_dequant_data.data(), dims_common);
- optimized_ops::LogSoftmax(reference_dequant_data.data(), dims_common,
- reference_output_float_data.data(), dims_common);
+ reference_ops::Dequantize(
+ input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
+ reference_dequant_data.data(), ToRuntimeDims(shape_common));
+ optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data(), shape_common);
// Work with quantized scaling for LogSoftmax, under which 255 represents 0,
// and -16 gets nudged up to 0.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -55,9 +57,9 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
}
void CheckOutputData(const uint8* test_output, const uint8* reference_output,
- const Dims<4>& dims_common, const string& check_label,
- bool be_exacting) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ const RuntimeShape& shape_common,
+ const string& check_label, bool be_exacting) {
+ const int buffer_size = shape_common.FlatSize();
// While calculating some metrics in floating point, we work with quantized
// scaling.
std::vector<int> diff(buffer_size);
@@ -99,15 +101,15 @@ void CheckOutputData(const uint8* test_output, const uint8* reference_output,
// Runs the LogSoftmax and compares against the float reference implementation
// and the quantized reference implementation.
-void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
- int32 input_offset, const double input_scale,
- int stride, float beta) {
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+void RunOneLogSoftmaxTest(const uint8* input_data,
+ const RuntimeShape& shape_common, int32 input_offset,
+ const double input_scale, int stride, float beta) {
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> optimized_logsoftmax_output(buffer_size);
std::vector<uint8> reference_float_logsoftmax_output(buffer_size);
std::vector<uint8> reference_quant_logsoftmax_output(buffer_size);
- RunLogSoftmaxFloatReference(input_data, dims_common, input_offset,
+ RunLogSoftmaxFloatReference(input_data, shape_common, input_offset,
input_scale, stride, beta,
reference_float_logsoftmax_output.data());
@@ -116,32 +118,33 @@ void RunOneLogSoftmaxTest(const uint8* input_data, const Dims<4>& dims_common,
int32 reverse_scaling_divisor;
int reverse_scaling_right_shift;
static const int kScaledDiffIntegerBits = 5;
- tflite::PreprocessLogSoftmaxScaling(
+ tflite::PreprocessLogSoftmaxScalingExp(
beta, input_scale, kScaledDiffIntegerBits, &input_beta_multiplier,
&input_beta_left_shift, &reverse_scaling_divisor,
&reverse_scaling_right_shift);
+ reverse_scaling_right_shift *= -1;
// diff_min has a negative value, and is used to limit the maximum magnitude
// of the diffs, which are <= 0.
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::LogSoftmax(input_data, dims_common, input_beta_multiplier,
+ optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier,
input_beta_left_shift, reverse_scaling_divisor,
reverse_scaling_right_shift, diff_min,
- optimized_logsoftmax_output.data(), dims_common);
+ optimized_logsoftmax_output.data(), shape_common);
reference_ops::LogSoftmax(
- input_data, dims_common, input_beta_multiplier, input_beta_left_shift,
+ input_data, shape_common, input_beta_multiplier, input_beta_left_shift,
reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
- reference_quant_logsoftmax_output.data(), dims_common);
+ reference_quant_logsoftmax_output.data(), shape_common);
CheckOutputData(optimized_logsoftmax_output.data(),
- reference_float_logsoftmax_output.data(), dims_common,
+ reference_float_logsoftmax_output.data(), shape_common,
"Optimized vs float reference", false);
CheckOutputData(optimized_logsoftmax_output.data(),
- reference_quant_logsoftmax_output.data(), dims_common,
+ reference_quant_logsoftmax_output.data(), shape_common,
"Optimized vs quant reference", true);
CheckOutputData(reference_quant_logsoftmax_output.data(),
- reference_float_logsoftmax_output.data(), dims_common,
+ reference_float_logsoftmax_output.data(), shape_common,
"Quant reference vs float reference", false);
}
@@ -164,13 +167,13 @@ bool TryOneUniformLogSoftmax() {
const int32 input_offset = UniformRandomInt(-256, 0);
static constexpr float beta = 1.0f;
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandom(&input_data);
- RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
return true;
}
@@ -202,14 +205,14 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) {
const int middle_min = UniformRandomInt(0, 255);
const int sides_max = UniformRandomInt(0, middle_min);
- Dims<4> dims_common =
- MakeDimsForInference(input_depth, input_width, input_height, batch);
- const int buffer_size = RequiredBufferSizeForDims(dims_common);
+ auto shape_common =
+ RuntimeShape({batch, input_height, input_width, input_depth});
+ const int buffer_size = shape_common.FlatSize();
std::vector<uint8> input_data(buffer_size);
FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min,
sides_max);
- RunOneLogSoftmaxTest(input_data.data(), dims_common, input_offset,
+ RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset,
input_scale, stride, beta);
return true;
}
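
For context, the core of this patch is swapping the legacy Dims<4> descriptor for RuntimeShape in the test helpers, using FlatSize() instead of RequiredBufferSizeForDims() and bridging back with ToRuntimeDims() where an op still takes the old layout. A minimal sketch of the two idioms, assuming the TF Lite internal types header is available at the contrib path used above and with made-up shape values for illustration:

#include "tensorflow/contrib/lite/kernels/internal/types.h"

int main() {
  // New-style shape, built as {batch, height, width, depth} the same way the
  // updated tests construct shape_common.
  tflite::RuntimeShape shape_common({2, 4, 4, 8});

  // FlatSize() replaces RequiredBufferSizeForDims() for sizing test buffers.
  const int buffer_size = shape_common.FlatSize();

  // Ops that still take the legacy Dims<4> layout (e.g. reference_ops::Dequantize
  // in this patch) are bridged explicitly with ToRuntimeDims().
  const tflite::Dims<4> legacy_dims = tflite::ToRuntimeDims(shape_common);

  (void)buffer_size;
  (void)legacy_dims;
  return 0;
}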