aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/concatenation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-05 16:32:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 16:36:34 -0800
commit179795c0067f05abe54904797288efebf6958b35 (patch)
tree3e75d52e1b35eef812f986418a5ab90b3198d4d2 /tensorflow/contrib/lite/kernels/concatenation.cc
parentdcefe9bc65de55b6b3bc81c01adc3e41ec2d33aa (diff)
Support negative axis in concatenation
PiperOrigin-RevId: 184605786
Diffstat (limited to 'tensorflow/contrib/lite/kernels/concatenation.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 9e7a1233da..7ff9075318 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -49,6 +49,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// dimensions except 'axis' must be equal.
TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]];
TfLiteType input_type = t0->type;
+ if (axis < 0) axis += t0->dims->size;
TF_LITE_ENSURE(context, axis >= 0);
TF_LITE_ENSURE(context, axis < t0->dims->size);
@@ -131,8 +132,9 @@ template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
-
+ int axis = params->axis;
TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ if (axis < 0) axis += output->dims->size;
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
@@ -141,7 +143,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
#define TF_LITE_CONCATENATION(type, scalar) \
VectorOfInputs<scalar> all_inputs(*context, *node->inputs); \
type::Concatenation<FusedActivationFunctionType::kNone, scalar>( \
- RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \
+ RemapDim(NumDimensions(output), axis), all_inputs.data(), \
all_inputs.dims(), node->inputs->size, GetTensorData<scalar>(output), \
GetTensorDims(output))