diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-08 15:16:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 15:20:45 -0700 |
commit | a1915c5f008cd7e6f01d563f83b36de783a76a0a (patch) | |
tree | 57b74f093d9b8b7b2f1f8b97af5cc520da9fe155 /tensorflow/contrib/lite/kernels/concatenation.cc | |
parent | 60bb01f1a7871958646669863a289960231be374 (diff) |
Added int32 concatenation to TFLite.
PiperOrigin-RevId: 207954437
Diffstat (limited to 'tensorflow/contrib/lite/kernels/concatenation.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/concatenation.cc | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index ad211e9c67..605a20ac3e 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -57,7 +57,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE(context, t0->dims->size <= 4); TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); TF_LITE_ENSURE(context, - input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || + input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || + input_type == kTfLiteInt64); // Output dimensions will match input dimensions, except 'axis', which // will be the sum of inputs @@ -121,6 +123,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION(optimized_ops, float); } break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int32); + } else { + TF_LITE_CONCATENATION(optimized_ops, int32); + } + break; case kTfLiteUInt8: if (kernel_type == kReference) { TF_LITE_CONCATENATION_QUANTIZED(reference_ops); @@ -128,6 +137,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_CONCATENATION_QUANTIZED(optimized_ops); } break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, int64_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, int64_t); + } + break; + default: context->ReportError(context, "Only float32 and uint8 are currently supported."); |