aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/sub.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-15 20:04:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-15 20:08:38 -0700
commiteadcdf91aa9e8ba6a196791ee349fd3474ffab76 (patch)
tree9d37beae82c55ea9c02d1c9aac014f2cfb996806 /tensorflow/contrib/lite/kernels/sub.cc
parent6c3c766dcabff3b5fa41dbfd491c9e8062a77b07 (diff)
add int32 support for sub
PiperOrigin-RevId: 204681037
Diffstat (limited to 'tensorflow/contrib/lite/kernels/sub.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc62
1 files changed, 39 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 1247525d41..541c85f756 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -78,29 +78,44 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
template <KernelType kernel_type>
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
- TfLiteSubParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- float output_activation_min, output_activation_max;
- CalculateActivationRange(params->activation, &output_activation_min,
- &output_activation_max);
-#define TF_LITE_SUB(type, opname) \
- type::opname(GetTensorData<float>(input1), GetTensorDims(input1), \
- GetTensorData<float>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<float>(output), GetTensorDims(output))
- if (kernel_type == kReference) {
- if (data->requires_broadcast) {
- TF_LITE_SUB(reference_ops, BroadcastSub);
+void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
+ const OpData* data, const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+#define TF_LITE_SUB(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
+ GetTensorData<data_type>(input2), GetTensorDims(input2), \
+ output_activation_min, output_activation_max, \
+ GetTensorData<data_type>(output), GetTensorDims(output))
+ if (output->type == kTfLiteInt32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(reference_ops, BroadcastSub, int32_t);
+ } else {
+ TF_LITE_SUB(reference_ops, Sub, int32_t);
+ }
} else {
- TF_LITE_SUB(reference_ops, Sub);
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(optimized_ops, BroadcastSub, int32_t);
+ } else {
+ TF_LITE_SUB(optimized_ops, Sub, int32_t);
+ }
}
- } else {
- if (data->requires_broadcast) {
- TF_LITE_SUB(optimized_ops, BroadcastSub);
+ } else if (output->type == kTfLiteFloat32) {
+ if (kernel_type == kReference) {
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(reference_ops, BroadcastSub, float);
+ } else {
+ TF_LITE_SUB(reference_ops, Sub, float);
+ }
} else {
- TF_LITE_SUB(optimized_ops, Sub);
+ if (data->requires_broadcast) {
+ TF_LITE_SUB(optimized_ops, BroadcastSub, float);
+ } else {
+ TF_LITE_SUB(optimized_ops, Sub, float);
+ }
}
}
#undef TF_LITE_SUB
@@ -171,14 +186,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (output->type == kTfLiteFloat32) {
- EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
+ EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
context->ReportError(
- context, "output type %d is not supported, requires float|uint8 types.",
+ context,
+ "output type %d is not supported, requires float|uint8|int32 types.",
output->type);
return kTfLiteError;
}