aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 06:54:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 06:59:45 -0700
commit3d30dd424c0404ea5349c0d2acdde2acd4e0aa97 (patch)
tree37e6bd9c7c643d139a314b3df41e29ea4aa27f4d /tensorflow/contrib/lite/kernels
parent234229b014cb0cfe4bf8e9466db79d596085faba (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214767788
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc113
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc51
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc27
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc32
-rw-r--r--tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc28
-rw-r--r--tensorflow/contrib/lite/kernels/log_softmax_test.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/softmax_test.cc12
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc5
10 files changed, 180 insertions, 120 deletions
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index b2d9b84979..cf9441aee3 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -348,18 +348,22 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
} break;
case kTfLiteInt16: {
- optimized_ops::Tanh(GetTensorData<int16_t>(input), GetTensorShape(input),
- data->input_left_shift,
- GetTensorData<int16_t>(output),
- GetTensorShape(output));
+ TanhParams params;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<int16_t>(input), GetTensorShape(output),
+ GetTensorData<int16_t>(output));
return kTfLiteOk;
} break;
case kTfLiteUInt8: {
- optimized_ops::Tanh(GetTensorData<uint8_t>(input), GetTensorShape(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ TanhParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
+ optimized_ops::Tanh(params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
return kTfLiteOk;
} break;
default:
@@ -385,17 +389,21 @@ TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
break;
}
case kTfLiteInt16: {
+ LogisticParams params;
optimized_ops::Logistic(
- GetTensorData<int16>(input), GetTensorShape(input),
- GetTensorData<int16_t>(output), GetTensorShape(output));
+ params, GetTensorShape(input), GetTensorData<int16_t>(input),
+ GetTensorShape(output), GetTensorData<int16_t>(output));
break;
}
case kTfLiteUInt8: {
+ LogisticParams params;
+ params.input_zero_point = input->params.zero_point;
+ params.input_range_radius = data->input_range_radius;
+ params.input_multiplier = data->input_multiplier;
+ params.input_left_shift = data->input_left_shift;
optimized_ops::Logistic(
- GetTensorData<uint8_t>(input), GetTensorShape(input),
- input->params.zero_point, data->input_range_radius,
- data->input_multiplier, data->input_left_shift,
- GetTensorData<uint8_t>(output), GetTensorShape(output));
+ params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
break;
}
default:
@@ -459,11 +467,13 @@ void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int intermediate_size = input->dims->data[1];
const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
GetTensorData<float>(input),
GetTensorShape({batch_size, intermediate_size, 1, input_size}),
- params->beta, GetTensorData<float>(output),
- GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+ GetTensorData<float>(output));
}
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
@@ -473,10 +483,14 @@ void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// tensor is 4D in a special way. We will convert a (Y) shape into a (1,
// 1, 1, Y) shape.
const int input_size = input->dims->data[0];
- optimized_ops::Softmax(
- GetTensorData<uint8_t>(input), GetTensorShape({1, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift, data->diff_min,
- GetTensorData<uint8_t>(output), GetTensorShape({1, 1, 1, input_size}));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({1, 1, 1, input_size}),
+ GetTensorData<uint8_t>(output));
}
void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
@@ -486,11 +500,15 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
// 1, 1, Y) shape.
const int batch_size = input->dims->data[0];
const int input_size = input->dims->data[1];
- optimized_ops::Softmax(GetTensorData<uint8_t>(input),
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params,
+ GetTensorShape({batch_size, 1, 1, input_size}),
+ GetTensorData<uint8_t>(input),
GetTensorShape({batch_size, 1, 1, input_size}),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape({batch_size, 1, 1, input_size}));
+ GetTensorData<uint8_t>(output));
}
void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
@@ -498,28 +516,36 @@ void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
const int batch_size = input->dims->data[0];
const int intermediate_size = input->dims->data[1];
const int input_size = input->dims->data[2];
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::Softmax(
+ op_params, GetTensorShape({batch_size, intermediate_size, 1, input_size}),
GetTensorData<uint8_t>(input),
GetTensorShape({batch_size, intermediate_size, 1, input_size}),
- data->input_multiplier, data->input_left_shift, data->diff_min,
- GetTensorData<uint8_t>(output),
- GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+ GetTensorData<uint8_t>(output));
}
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
- optimized_ops::Softmax(GetTensorData<float>(input), GetTensorShape(input),
- params->beta, GetTensorData<float>(output),
- GetTensorShape(output));
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<float>(input), GetTensorShape(output),
+ GetTensorData<float>(output));
}
void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
- optimized_ops::Softmax(GetTensorData<uint8_t>(input), GetTensorShape(input),
- data->input_multiplier, data->input_left_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ optimized_ops::Softmax(op_params, GetTensorShape(input),
+ GetTensorData<uint8_t>(input), GetTensorShape(output),
+ GetTensorData<uint8_t>(output));
}
TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
@@ -591,17 +617,20 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
case kTfLiteFloat32:
+ SoftmaxParams op_params;
optimized_ops::LogSoftmax(
- GetTensorData<float>(input), GetTensorShape(input),
- GetTensorData<float>(output), GetTensorShape(output));
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
case kTfLiteUInt8:
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
+ op_params.reverse_scaling_right_shift = data->reverse_scaling_right_shift;
+ op_params.diff_min = data->diff_min;
optimized_ops::LogSoftmax(
- GetTensorData<uint8_t>(input), GetTensorShape(input),
- data->input_multiplier, data->input_left_shift,
- data->reverse_scaling_divisor, data->reverse_scaling_right_shift,
- data->diff_min, GetTensorData<uint8_t>(output),
- GetTensorShape(output));
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
default:
context->ReportError(context, "Only float32 supported currently., got %d",
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 4cd96348a2..f765235e04 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -83,20 +83,24 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
&input2_multiplier, &input2_shift); \
\
+ ComparisonParams op_params; \
+ op_params.left_shift = left_shift; \
+ op_params.input1_offset = input1_offset; \
+ op_params.input1_multiplier = input1_multiplier; \
+ op_params.input1_shift = -input1_shift; \
+ op_params.input2_offset = input2_offset; \
+ op_params.input2_multiplier = input2_multiplier; \
+ op_params.input2_shift = -input2_shift; \
if (requires_broadcast) { \
- reference_ops::Broadcast##opname( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- input2_offset, input2_multiplier, input2_shift, \
- GetTensorData<bool>(output), GetTensorDims(output)); \
+ reference_ops::Broadcast4DSlow##opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
} else { \
- reference_ops::opname( \
- left_shift, GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, input1_multiplier, input1_shift, \
- GetTensorData<uint8_t>(input2), GetTensorDims(input2), \
- input2_offset, input2_multiplier, input2_shift, \
- GetTensorData<bool>(output), GetTensorDims(output)); \
+ reference_ops::opname##WithScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
+ GetTensorShape(input2), GetTensorData<uint8_t>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
} \
} \
}
@@ -108,16 +112,19 @@ TF_LITE_QUANTIZE_COMPARISON(Less);
TF_LITE_QUANTIZE_COMPARISON(LessEqual);
#undef TF_LITE_QUANTIZE_COMPARISON
-#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
- requires_broadcast \
- ? reference_ops::Broadcast##opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output)) \
- : reference_ops::opname( \
- GetTensorData<type>(input1), GetTensorDims(input1), \
- GetTensorData<type>(input2), GetTensorDims(input2), \
- GetTensorData<bool>(output), GetTensorDims(output));
+#define TF_LITE_COMPARISON(type, opname, requires_broadcast) \
+ { \
+ ComparisonParams op_params; \
+ requires_broadcast \
+ ? reference_ops::Broadcast4DSlow##opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)) \
+ : reference_ops::opname##NoScaling( \
+ op_params, GetTensorShape(input1), GetTensorData<type>(input1), \
+ GetTensorShape(input2), GetTensorData<type>(input2), \
+ GetTensorShape(output), GetTensorData<bool>(output)); \
+ }
TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 3a08f48b00..59bf64e0af 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -77,13 +77,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
- auto zero_point = op_context.input->params.zero_point;
- auto scale = op_context.input->params.scale;
-
- optimized_ops::Dequantize(GetTensorData<uint8_t>(op_context.input),
- GetTensorDims(op_context.input), zero_point, scale,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = op_context.input->params.zero_point;
+ op_params.scale = op_context.input->params.scale;
+ optimized_ops::Dequantize(op_params, GetTensorShape(op_context.input),
+ GetTensorData<uint8_t>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
if (IsConstantTensor(op_context.input)) {
op_data->float_dequantized_weights_initialized = true;
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index 7945c095b1..8d4bb51006 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -81,24 +81,27 @@ template <KernelType kernel_type>
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_DIV(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_DIV(type, opname, data_type) \
+ tflite::ArithmeticParams op_params; \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv, int32_t);
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, int32_t);
} else {
TF_LITE_DIV(reference_ops, Div, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv, int32_t);
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, int32_t);
} else {
TF_LITE_DIV(optimized_ops, Div, int32_t);
}
@@ -106,13 +109,13 @@ void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_DIV(reference_ops, BroadcastDiv, float);
+ TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, float);
} else {
TF_LITE_DIV(reference_ops, Div, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_DIV(optimized_ops, BroadcastDiv, float);
+ TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, float);
} else {
TF_LITE_DIV(optimized_ops, Div, float);
}
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index f9bc3747cb..b51af72fe6 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -68,11 +68,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params =
reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data);
- reference_ops::FakeQuant(GetTensorData<float>(op_context.input),
- GetTensorDims(op_context.input), params->min,
- params->max, params->num_bits,
- GetTensorData<float>(op_context.output),
- GetTensorDims(op_context.output));
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = params->num_bits;
+ op_params.minmax.min = params->min;
+ op_params.minmax.max = params->max;
+ reference_ops::FakeQuant(op_params, GetTensorShape(op_context.input),
+ GetTensorData<float>(op_context.input),
+ GetTensorShape(op_context.output),
+ GetTensorData<float>(op_context.output));
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
index 3624c20ae3..2252ca1bcc 100644
--- a/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/logsoftmax_quantized_test.cc
@@ -43,11 +43,15 @@ void RunLogSoftmaxFloatReference(const uint8* input_data,
// Reference data generated via Dequant of input into float, and then applying
// float LogSoftmax.
- reference_ops::Dequantize(
- input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
- reference_dequant_data.data(), ToRuntimeDims(shape_common));
- optimized_ops::LogSoftmax(reference_dequant_data.data(), shape_common,
- reference_output_float_data.data(), shape_common);
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ optimized_ops::LogSoftmax(sm_params, shape_common,
+ reference_dequant_data.data(), shape_common,
+ reference_output_float_data.data());
// Work with quantized scaling for LogSoftmax, under which 255 represents 0,
// and -16 gets nudged up to 0.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -129,14 +133,16 @@ void RunOneLogSoftmaxTest(const uint8* input_data,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::LogSoftmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, reverse_scaling_divisor,
- reverse_scaling_right_shift, diff_min,
- optimized_logsoftmax_output.data(), shape_common);
- reference_ops::LogSoftmax(
- input_data, shape_common, input_beta_multiplier, input_beta_left_shift,
- reverse_scaling_divisor, reverse_scaling_right_shift, diff_min,
- reference_quant_logsoftmax_output.data(), shape_common);
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ optimized_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ optimized_logsoftmax_output.data());
+ reference_ops::LogSoftmax(params, shape_common, input_data, shape_common,
+ reference_quant_logsoftmax_output.data());
CheckOutputData(optimized_logsoftmax_output.data(),
reference_float_logsoftmax_output.data(), shape_common,
diff --git a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
index ca94e7740e..831fb3c243 100644
--- a/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/softmax_quantized_test.cc
@@ -43,11 +43,15 @@ void RunSoftmaxFloatReference(const uint8* input_data,
// Reference data generated via Dequant of input into float, and then applying
// float Softmax.
- reference_ops::Dequantize(
- input_data, ToRuntimeDims(shape_common), input_offset, input_scale,
- reference_dequant_data.data(), ToRuntimeDims(shape_common));
- optimized_ops::Softmax(reference_dequant_data.data(), shape_common, beta,
- reference_output_float_data.data(), shape_common);
+ DequantizationParams dq_params;
+ dq_params.zero_point = input_offset;
+ dq_params.scale = input_scale;
+ reference_ops::Dequantize(dq_params, shape_common, input_data, shape_common,
+ reference_dequant_data.data());
+ SoftmaxParams sm_params;
+ sm_params.beta = beta;
+ optimized_ops::Softmax(sm_params, shape_common, reference_dequant_data.data(),
+ shape_common, reference_output_float_data.data());
// Work with quantized scaling for Softmax, under which 256 represents 1, but
// we limit this to 255.
for (int i = 0; i < ref_buffer_size; i++) {
@@ -116,12 +120,14 @@ void RunOneSoftmaxTest(const uint8* input_data,
const int diff_min = -tflite::CalculateInputRadius(kScaledDiffIntegerBits,
input_beta_left_shift);
- optimized_ops::Softmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, diff_min,
- optimized_softmax_output.data(), shape_common);
- reference_ops::Softmax(input_data, shape_common, input_beta_multiplier,
- input_beta_left_shift, diff_min,
- reference_quant_softmax_output.data(), shape_common);
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ optimized_ops::Softmax(params, shape_common, input_data, shape_common,
+ optimized_softmax_output.data());
+ reference_ops::Softmax(params, shape_common, input_data, shape_common,
+ reference_quant_softmax_output.data());
CheckOutputData(optimized_softmax_output.data(),
reference_float_softmax_output.data(), shape_common,
diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
index 9a8d35e82c..1acc966cdc 100644
--- a/tensorflow/contrib/lite/kernels/log_softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc
@@ -91,8 +91,9 @@ TEST(LogSoftmaxOpTest, CompareWithTFmini) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::LogSoftmax(input_buffer, input_shape,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ tflite::reference_ops::LogSoftmax(params, input_shape, input_buffer,
+ input_shape, output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc
index 727822f6be..bd66980226 100644
--- a/tensorflow/contrib/lite/kernels/softmax_test.cc
+++ b/tensorflow/contrib/lite/kernels/softmax_test.cc
@@ -93,8 +93,10 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
@@ -120,8 +122,10 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) {
std::unique_ptr<float[]> output_buffer(new float[input_size * batch_size]);
auto input_shape = RuntimeShape({batch_size, 1, 1, input_size});
- tflite::reference_ops::Softmax(input_buffer, input_shape, beta,
- output_buffer.get(), input_shape);
+ SoftmaxParams params;
+ params.beta = beta;
+ tflite::reference_ops::Softmax(params, input_shape, input_buffer, input_shape,
+ output_buffer.get());
std::vector<float> expected;
expected.insert(expected.end(), output_buffer.get(),
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index 178568e07c..349fa0bd28 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -210,8 +210,9 @@ TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) {
&indices_vector));
reference_ops::SparseToDense(indices_vector, GetTensorData<T>(values),
*GetTensorData<T>(default_value),
- GetTensorData<T>(output), GetTensorDims(output),
- value_is_scalar);
+ value_is_scalar, GetTensorShape(output),
+ GetTensorData<T>(output));
+
return kTfLiteOk;
}