aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/reshape.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-07-30 10:53:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 10:57:45 -0700
commit50ba36f1662dc61cb1b60353a2a09aa3ea72bb59 (patch)
tree61383a4a9792353e7699fb6d995ac724acc7026d /tensorflow/contrib/lite/kernels/reshape.cc
parent538c7198fdba4ac9b71ba2ceb8c2fb0bb31d20e5 (diff)
Fix issue with scalar reshape parameters
PiperOrigin-RevId: 206609783
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reshape.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc17
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 99ecc16093..49ba0571e2 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -37,10 +37,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node,
// special -1 value, meaning it will be calculated automatically based on the
// input. Here we calculate what that dimension should be so that the number
// of output elements in the same as the number of input elements.
- int num_input_elements = 1;
- for (int i = 0; i < NumDimensions(input); ++i) {
- num_input_elements *= SizeOfDimension(input, i);
- }
+ int num_input_elements = NumElements(input);
int num_output_elements = 1;
int stretch_dim = -1;
@@ -96,9 +93,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// The function is returned above this line if the shape tensor is usable.
// Now fallback to the shape parameter in `TfLiteReshapeParams`.
-
- TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
- for (int i = 0; i < params->num_dimensions; ++i) {
+ int num_dimensions = params->num_dimensions;
+ if (num_dimensions == 1 && params->shape[0] == 0) {
+ // Legacy tflite models use a shape parameter of [0] to indicate scalars,
+ // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
+ // toco conversion.
+ num_dimensions = 0;
+ }
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
+ for (int i = 0; i < num_dimensions; ++i) {
output_shape->data[i] = params->shape[i];
}
return ResizeOutput(context, node, output_shape);