path: root/tensorflow/contrib/lite/kernels/add.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-26 23:41:11 -0700
committer Gunhan Gulsoy <gunan@google.com>                  2018-06-28 21:37:43 -0700
commit    c30c57bd0792c50397883252ee5b2960988846d3 (patch)
tree      5412a82ddd4ef7e3dde8c8389c88321e49608561 /tensorflow/contrib/lite/kernels/add.cc
parent    92cc6352abaf2442c0d29755f87f6dbcd514a684 (diff)
add int32 support for add
PiperOrigin-RevId: 202259189
Diffstat (limited to 'tensorflow/contrib/lite/kernels/add.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/add.cc  60
1 file changed, 37 insertions, 23 deletions
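
The change below folds the float-only EvalAddFloat into a single EvalAdd: the TF_LITE_ADD macro gains a data_type parameter, the activation range is computed for that type via CalculateActivationRange, and Eval routes both kTfLiteFloat32 and kTfLiteInt32 outputs through the shared path. As a rough standalone illustration of that dispatch pattern (hypothetical AddWithClamp/ElementType names and plain std::vector buffers, not the kernel itself):

// Illustrative sketch only: pick the element type at runtime, then run one
// templated add-with-activation-clamp routine for both float and int32,
// which is the shape of the dispatch the rewritten EvalAdd performs.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

enum class ElementType { kFloat32, kInt32 };  // stands in for TfLiteType

// Element-wise add with activation clamping, generic over the element type,
// analogous to what TF_LITE_ADD(..., data_type) expands to in each branch.
template <typename T>
std::vector<T> AddWithClamp(const std::vector<T>& a, const std::vector<T>& b,
                            T activation_min, T activation_max) {
  std::vector<T> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    T sum = a[i] + b[i];
    if (sum < activation_min) sum = activation_min;
    if (sum > activation_max) sum = activation_max;
    out[i] = sum;
  }
  return out;
}

int main() {
  ElementType type = ElementType::kInt32;  // plays the role of output->type
  if (type == ElementType::kInt32) {
    std::vector<int32_t> a = {1, 2, 3}, b = {10, 20, 30};
    for (int32_t v :
         AddWithClamp<int32_t>(a, b, std::numeric_limits<int32_t>::lowest(),
                               std::numeric_limits<int32_t>::max()))
      std::cout << v << ' ';
  } else {
    std::vector<float> a = {0.25f, 0.75f}, b = {0.5f, 0.5f};
    for (float v : AddWithClamp<float>(a, b, 0.0f, 6.0f))  // e.g. fused Relu6
      std::cout << v << ' ';
  }
  std::cout << '\n';
  return 0;
}
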
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index ccb957ebc5..f44d531cbf 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -170,29 +170,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalAddFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteAddParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRangeFloat(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_ADD(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_ADD(reference_ops, BroadcastAdd);
+void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_ADD(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
+ GetTensorData<data_type>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd, int32_t);
+ } else {
+ TF_LITE_ADD(reference_ops, Add, int32_t);
+ }
} else {
- TF_LITE_ADD(reference_ops, Add);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd, int32_t);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_ADD(optimized_ops, BroadcastAdd);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(reference_ops, BroadcastAdd, float);
+ } else {
+ TF_LITE_ADD(reference_ops, Add, float);
+ }
} else {
- TF_LITE_ADD(optimized_ops, Add);
+ if (data->requires_broadcast) {
+ TF_LITE_ADD(optimized_ops, BroadcastAdd, float);
+ } else {
+ TF_LITE_ADD(optimized_ops, Add, float);
+ }
}
}
#undef TF_LITE_ADD
@@ -251,9 +266,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalAddFloat<kernel_type>(context, node, params, data, input1, input2,
- output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
TF_LITE_ENSURE_OK(context,
EvalAddQuantized<kernel_type>(context, node, params, data,
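
One detail worth noting in the hunk above: the macro now declares its min/max variables with the macro's data_type and calls CalculateActivationRange, where the old code used the float-only CalculateActivationRangeFloat. For intuition, here is a minimal type-generic sketch of such a helper; the local FusedActivation enum and ActivationRangeSketch name are hypothetical stand-ins, not TFLite's API. A fused Relu6, for example, clamps to [0, 6] whether the element type is float or int32_t:

// Standalone sketch of a type-generic activation-range helper.
#include <cstdint>
#include <iostream>
#include <limits>

enum class FusedActivation { kNone, kRelu, kRelu1, kRelu6 };

template <typename T>
void ActivationRangeSketch(FusedActivation act, T* act_min, T* act_max) {
  switch (act) {
    case FusedActivation::kRelu:
      *act_min = 0;
      *act_max = std::numeric_limits<T>::max();
      break;
    case FusedActivation::kRelu1:
      *act_min = -1;
      *act_max = 1;
      break;
    case FusedActivation::kRelu6:
      *act_min = 0;
      *act_max = 6;
      break;
    case FusedActivation::kNone:
    default:
      *act_min = std::numeric_limits<T>::lowest();
      *act_max = std::numeric_limits<T>::max();
      break;
  }
}

int main() {
  int32_t imin, imax;
  float fmin, fmax;
  ActivationRangeSketch(FusedActivation::kRelu6, &imin, &imax);
  ActivationRangeSketch(FusedActivation::kRelu6, &fmin, &fmax);
  std::cout << "int32 Relu6 range: [" << imin << ", " << imax << "]\n"
            << "float Relu6 range: [" << fmin << ", " << fmax << "]\n";
  return 0;
}
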