aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/concatenation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-08 15:16:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 15:20:45 -0700
commita1915c5f008cd7e6f01d563f83b36de783a76a0a (patch)
tree57b74f093d9b8b7b2f1f8b97af5cc520da9fe155 /tensorflow/contrib/lite/kernels/concatenation.cc
parent60bb01f1a7871958646669863a289960231be374 (diff)
Added int32 concatenation to TFLite.
PiperOrigin-RevId: 207954437
Diffstat (limited to 'tensorflow/contrib/lite/kernels/concatenation.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc19
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index ad211e9c67..605a20ac3e 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -57,7 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, t0->dims->size <= 4);
TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
TF_LITE_ENSURE(context,
- input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+ input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
+ input_type == kTfLiteInt64);
// Output dimensions will match input dimensions, except 'axis', which
// will be the sum of inputs
@@ -121,6 +123,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_CONCATENATION(optimized_ops, float);
}
break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, int32);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, int32);
+ }
+ break;
case kTfLiteUInt8:
if (kernel_type == kReference) {
TF_LITE_CONCATENATION_QUANTIZED(reference_ops);
@@ -128,6 +137,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_CONCATENATION_QUANTIZED(optimized_ops);
}
break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_CONCATENATION(reference_ops, int64_t);
+ } else {
+ TF_LITE_CONCATENATION(optimized_ops, int64_t);
+ }
+ break;
+
default:
context->ReportError(context,
"Only float32 and uint8 are currently supported.");