diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-05 16:32:38 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-05 16:36:34 -0800 |
commit | 179795c0067f05abe54904797288efebf6958b35 (patch) | |
tree | 3e75d52e1b35eef812f986418a5ab90b3198d4d2 /tensorflow/contrib/lite/kernels/concatenation.cc | |
parent | dcefe9bc65de55b6b3bc81c01adc3e41ec2d33aa (diff) |
Support negative axis in concatenation
PiperOrigin-RevId: 184605786
Diffstat (limited to 'tensorflow/contrib/lite/kernels/concatenation.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/concatenation.cc | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 9e7a1233da..7ff9075318 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -49,6 +49,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // dimensions except 'axis' must be equal. TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; TfLiteType input_type = t0->type; + if (axis < 0) axis += t0->dims->size; TF_LITE_ENSURE(context, axis >= 0); TF_LITE_ENSURE(context, axis < t0->dims->size); @@ -131,8 +132,9 @@ template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); - + int axis = params->axis; TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + if (axis < 0) axis += output->dims->size; // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should // allocate and populate these during Prepare(). @@ -141,7 +143,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { #define TF_LITE_CONCATENATION(type, scalar) \ VectorOfInputs<scalar> all_inputs(*context, *node->inputs); \ type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \ - RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + RemapDim(NumDimensions(output), axis), all_inputs.data(), \ all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \ GetTensorDims(output)) |