diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-24 20:39:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 20:43:58 -0700 |
commit | 626fef2af7d4bc49aeeef7ffd195dc30235bcd1e (patch) | |
tree | f81c1a5b95696897957619b5635537c73942b8fe /tensorflow/contrib/lite/kernels/fully_connected.cc | |
parent | 6ba60e051409a5346c2aab21160c9c311de1cb03 (diff) |
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214377809
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/fully_connected.cc | 66 |
1 files changed, 42 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index 7a71fcc219..f6d2f76dbe 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -281,15 +281,23 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, int32_t input_offset = -input->params.zero_point; int32_t filter_offset = -filter->params.zero_point; int32_t output_offset = output->params.zero_point; -#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \ - type::FullyConnected( \ - GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \ - GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \ - GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \ - data->output_multiplier, data->output_shift, \ - data->output_activation_min, data->output_activation_max, \ - GetTensorData<output_data_type>(output), GetTensorDims(output), \ - gemm_context) +#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \ + { \ + FullyConnectedParams op_params; \ + op_params.input_offset = input_offset; \ + op_params.weights_offset = filter_offset; \ + op_params.output_offset = output_offset; \ + op_params.output_multiplier = data->output_multiplier; \ + op_params.output_shift = -data->output_shift; \ + op_params.quantized_activation_min = data->output_activation_min; \ + op_params.quantized_activation_max = data->output_activation_max; \ + type::FullyConnected( \ + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \ + GetTensorShape(filter), GetTensorData<uint8_t>(filter), \ + GetTensorShape(bias), GetTensorData<int32_t>(bias), \ + GetTensorShape(output), GetTensorData<output_data_type>(output), \ + gemm_context); \ + } if (kernel_type == kReference) { switch (output->type) { case kTfLiteUInt8: @@ -349,15 +357,20 @@ TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node, return kTfLiteError; } -#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ - type::ShuffledFullyConnected( \ - GetTensorData<uint8_t>(input), GetTensorDims(input), \ - GetTensorData<uint8_t>(filter), GetTensorDims(filter), \ - GetTensorData<int32_t>(bias), GetTensorDims(bias), \ - data->output_multiplier, data->output_shift, \ - data->output_activation_min, data->output_activation_max, \ - GetTensorData<int16_t>(output), GetTensorDims(output), \ - GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context) +#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ + { \ + FullyConnectedParams op_params; \ + op_params.output_multiplier = data->output_multiplier; \ + op_params.output_shift = -data->output_shift; \ + op_params.quantized_activation_min = data->output_activation_min; \ + op_params.quantized_activation_max = data->output_activation_max; \ + type::ShuffledFullyConnected( \ + op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \ + GetTensorShape(filter), GetTensorData<uint8_t>(filter), \ + GetTensorShape(bias), GetTensorData<int32_t>(bias), \ + GetTensorShape(output), GetTensorData<int16_t>(output), \ + GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context); \ + } if (kernel_type == kReference) { TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); } else { @@ -376,12 +389,17 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, float output_activation_min, output_activation_max; CalculateActivationRange(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_FULLY_CONNECTED(type) \ - type::FullyConnected(GetTensorData<float>(input), GetTensorDims(input), \ - GetTensorData<float>(filter), GetTensorDims(filter), \ - GetTensorData<float>(bias), GetTensorDims(bias), \ - output_activation_min, output_activation_max, \ - GetTensorData<float>(output), GetTensorDims(output)) +#define TF_LITE_FULLY_CONNECTED(type) \ + { \ + FullyConnectedParams op_params; \ + op_params.float_activation_min = output_activation_min; \ + op_params.float_activation_max = output_activation_max; \ + type::FullyConnected(op_params, GetTensorShape(input), \ + GetTensorData<float>(input), GetTensorShape(filter), \ + GetTensorData<float>(filter), GetTensorShape(bias), \ + GetTensorData<float>(bias), GetTensorShape(output), \ + GetTensorData<float>(output)); \ + } if (kernel_type == kReference) { TF_LITE_FULLY_CONNECTED(reference_ops); } else if (kernel_type == kPie) { |