diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-26 23:41:11 -0700 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-06-28 21:37:43 -0700 |
commit | c30c57bd0792c50397883252ee5b2960988846d3 (patch) | |
tree | 5412a82ddd4ef7e3dde8c8389c88321e49608561 /tensorflow/contrib/lite/kernels/add.cc | |
parent | 92cc6352abaf2442c0d29755f87f6dbcd514a684 (diff) |
add int32 support for add
PiperOrigin-RevId: 202259189
Diffstat (limited to 'tensorflow/contrib/lite/kernels/add.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/add.cc | 60 |
1 files changed, 37 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index ccb957ebc5..f44d531cbf 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -170,29 +170,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template <KernelType kernel_type> -void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, const OpData* data, - const TfLiteTensor* input1, const TfLiteTensor* input2, - TfLiteTensor* output) { - float output_activation_min, output_activation_max; - CalculateActivationRangeFloat(params->activation, &output_activation_min, - &output_activation_max); -#define TF_LITE_ADD(type, opname) \ - type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \ - GetTensorData<float>(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData<float>(output), GetTensorDims(output)) - if (kernel_type == kReference) { - if (data->requires_broadcast) { - TF_LITE_ADD(reference_ops, BroadcastAdd); +void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params, + const OpData* data, const TfLiteTensor* input1, + const TfLiteTensor* input2, TfLiteTensor* output) { +#define TF_LITE_ADD(type, opname, data_type) \ + data_type output_activation_min, output_activation_max; \ + CalculateActivationRange(params->activation, &output_activation_min, \ + &output_activation_max); \ + type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \ + GetTensorData<data_type>(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData<data_type>(output), GetTensorDims(output)) + if (output->type == kTfLiteInt32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t); + } else { + TF_LITE_ADD(reference_ops, Add, int32_t); + } } else { - TF_LITE_ADD(reference_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t); + } else { + TF_LITE_ADD(optimized_ops, Add, int32_t); + } } - } else { - if (data->requires_broadcast) { - TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else if (output->type == kTfLiteFloat32) { + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd, float); + } else { + TF_LITE_ADD(reference_ops, Add, float); + } } else { - TF_LITE_ADD(optimized_ops, Add); + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd, float); + } else { + TF_LITE_ADD(optimized_ops, Add, float); + } } } #undef TF_LITE_ADD @@ -251,9 +266,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - if (output->type == kTfLiteFloat32) { - EvalAddFloat<kernel_type>(context, node, params, data, input1, input2, - output); + if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { + EvalAdd<kernel_type>(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { TF_LITE_ENSURE_OK(context, EvalAddQuantized<kernel_type>(context, node, params, data, |