aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/fully_connected.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-24 20:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 20:43:58 -0700
commit626fef2af7d4bc49aeeef7ffd195dc30235bcd1e (patch)
treef81c1a5b95696897957619b5635537c73942b8fe /tensorflow/contrib/lite/kernels/fully_connected.cc
parent6ba60e051409a5346c2aab21160c9c311de1cb03 (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214377809
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc66
1 files changed, 42 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index 7a71fcc219..f6d2f76dbe 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -281,15 +281,23 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
- type::FullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<output_data_type>(output), GetTensorDims(output), \
- gemm_context)
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.input_offset = input_offset; \
+ op_params.weights_offset = filter_offset; \
+ op_params.output_offset = output_offset; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::FullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<output_data_type>(output), \
+ gemm_context); \
+ }
if (kernel_type == kReference) {
switch (output->type) {
case kTfLiteUInt8:
@@ -349,15 +357,20 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
return kTfLiteError;
}
-#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
- type::ShuffledFullyConnected( \
- GetTensorData<uint8_t>(input), GetTensorDims(input), \
- GetTensorData<uint8_t>(filter), GetTensorDims(filter), \
- GetTensorData<int32_t>(bias), GetTensorDims(bias), \
- data->output_multiplier, data->output_shift, \
- data->output_activation_min, data->output_activation_max, \
- GetTensorData<int16_t>(output), GetTensorDims(output), \
- GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context)
+#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = -data->output_shift; \
+ op_params.quantized_activation_min = data->output_activation_min; \
+ op_params.quantized_activation_max = data->output_activation_max; \
+ type::ShuffledFullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<int16_t>(output), \
+ GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \
+ }
if (kernel_type == kReference) {
TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
} else {
@@ -376,12 +389,17 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
float output_activation_min, output_activation_max;
CalculateActivationRange(params->activation, &output_activation_min,
&output_activation_max);
-#define TF_LITE_FULLY_CONNECTED(type) \
- type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<float>(filter), GetTensorDims(filter), \
- GetTensorData<float>(bias), GetTensorDims(bias), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_FULLY_CONNECTED(type) \
+ { \
+ FullyConnectedParams op_params; \
+ op_params.float_activation_min = output_activation_min; \
+ op_params.float_activation_max = output_activation_max; \
+ type::FullyConnected(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(filter), \
+ GetTensorData<float>(filter), GetTensorShape(bias), \
+ GetTensorData<float>(bias), GetTensorShape(output), \
+ GetTensorData<float>(output)); \
+ }
if (kernel_type == kReference) {
TF_LITE_FULLY_CONNECTED(reference_ops);
} else if (kernel_type == kPie) {