| field | value |
|---|---|
| author | 2018-07-15 20:04:46 -0700 |
| committer | 2018-07-15 20:08:38 -0700 |
| commit | eadcdf91aa9e8ba6a196791ee349fd3474ffab76 (patch) |
| tree | 9d37beae82c55ea9c02d1c9aac014f2cfb996806 /tensorflow/contrib/lite/kernels/sub.cc |
| parent | 6c3c766dcabff3b5fa41dbfd491c9e8062a77b07 (diff) |
add int32 support for sub
PiperOrigin-RevId: 204681037
Diffstat (limited to 'tensorflow/contrib/lite/kernels/sub.cc')
| -rw-r--r-- | tensorflow/contrib/lite/kernels/sub.cc | 62 |

1 file changed, 39 insertions, 23 deletions
```diff
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 1247525d41..541c85f756 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -78,29 +78,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 }
 
 template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
-               TfLiteSubParams* params, const OpData* data,
-               const TfLiteTensor* input1, const TfLiteTensor* input2,
-               TfLiteTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-#define TF_LITE_SUB(type, opname)                                   \
-  type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
-               GetTensorData<float>(input2), GetTensorDims(input2), \
-               output_activation_min, output_activation_max,        \
-               GetTensorData<float>(output), GetTensorDims(output))
-  if (kernel_type == kReference) {
-    if (data->requires_broadcast) {
-      TF_LITE_SUB(reference_ops, BroadcastSub);
+void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
+             const OpData* data, const TfLiteTensor* input1,
+             const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_SUB(type, opname, data_type)                             \
+  data_type output_activation_min, output_activation_max;                \
+  CalculateActivationRange(params->activation, &output_activation_min,   \
+                           &output_activation_max);                      \
+  type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1),  \
+               GetTensorData<data_type>(input2), GetTensorDims(input2),  \
+               output_activation_min, output_activation_max,             \
+               GetTensorData<data_type>(output), GetTensorDims(output))
+  if (output->type == kTfLiteInt32) {
+    if (kernel_type == kReference) {
+      if (data->requires_broadcast) {
+        TF_LITE_SUB(reference_ops, BroadcastSub, int32_t);
+      } else {
+        TF_LITE_SUB(reference_ops, Sub, int32_t);
+      }
     } else {
-      TF_LITE_SUB(reference_ops, Sub);
+      if (data->requires_broadcast) {
+        TF_LITE_SUB(optimized_ops, BroadcastSub, int32_t);
+      } else {
+        TF_LITE_SUB(optimized_ops, Sub, int32_t);
+      }
     }
-  } else {
-    if (data->requires_broadcast) {
-      TF_LITE_SUB(optimized_ops, BroadcastSub);
+  } else if (output->type == kTfLiteFloat32) {
+    if (kernel_type == kReference) {
+      if (data->requires_broadcast) {
+        TF_LITE_SUB(reference_ops, BroadcastSub, float);
+      } else {
+        TF_LITE_SUB(reference_ops, Sub, float);
+      }
     } else {
-      TF_LITE_SUB(optimized_ops, Sub);
+      if (data->requires_broadcast) {
+        TF_LITE_SUB(optimized_ops, BroadcastSub, float);
+      } else {
+        TF_LITE_SUB(optimized_ops, Sub, float);
+      }
     }
   }
 #undef TF_LITE_SUB
@@ -171,14 +186,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
 
-  if (output->type == kTfLiteFloat32) {
-    EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+    EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
   } else if (output->type == kTfLiteUInt8) {
     EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
                                output);
   } else {
     context->ReportError(
-        context, "output type %d is not supported, requires float|uint8 types.",
+        context,
+        "output type %d is not supported, requires float|uint8|int32 types.",
         output->type);
     return kTfLiteError;
   }
```
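For context, the kernels dispatched through `TF_LITE_SUB` perform an elementwise subtraction followed by clamping to the activation range, now parameterized over the element type so `int32_t` works the same way as `float`. Below is a minimal standalone sketch of that behavior; `NaiveSub` is an illustrative stand-in, not TFLite's actual `reference_ops::Sub` (which additionally takes `Dims<4>` shape arguments and has a broadcasting variant):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for the subtraction kernel after this change:
// templated on the element type (float or int32_t), it subtracts
// elementwise over flattened buffers and clamps each result to the
// activation range computed by CalculateActivationRange.
template <typename T>
void NaiveSub(const T* input1, const T* input2, int size, T activation_min,
              T activation_max, T* output) {
  for (int i = 0; i < size; ++i) {
    output[i] = std::min(std::max(input1[i] - input2[i], activation_min),
                         activation_max);
  }
}

int main() {
  const int32_t a[3] = {10, 20, 30};
  const int32_t b[3] = {1, 25, 3};
  int32_t out[3];
  // With no fused activation (kTfLiteActNone), the clamp range is the
  // full range of the type, so the subtraction passes through unchanged.
  NaiveSub<int32_t>(a, b, 3, INT32_MIN, INT32_MAX, out);
  std::printf("%d %d %d\n", static_cast<int>(out[0]),
              static_cast<int>(out[1]), static_cast<int>(out[2]));  // 9 -5 27
  return 0;
}
```

This is also why the `float`-typed locals moved into the macro in the diff above: with `int32_t` outputs, `output_activation_min`/`output_activation_max` must be declared with the matching `data_type` rather than hard-coded as `float`.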