diff options
author | 2018-07-30 10:53:43 -0700 | |
---|---|---|
committer | 2018-07-30 10:57:45 -0700 | |
commit | 50ba36f1662dc61cb1b60353a2a09aa3ea72bb59 (patch) | |
tree | 61383a4a9792353e7699fb6d995ac724acc7026d /tensorflow/contrib/lite/kernels/reshape.cc | |
parent | 538c7198fdba4ac9b71ba2ceb8c2fb0bb31d20e5 (diff) |
Fix issue with scalar reshape parameters
PiperOrigin-RevId: 206609783
Diffstat (limited to 'tensorflow/contrib/lite/kernels/reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/reshape.cc | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc index 99ecc16093..49ba0571e2 100644 --- a/tensorflow/contrib/lite/kernels/reshape.cc +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -37,10 +37,7 @@ TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node, // special -1 value, meaning it will be calculated automatically based on the // input. Here we calculate what that dimension should be so that the number // of output elements in the same as the number of input elements. - int num_input_elements = 1; - for (int i = 0; i < NumDimensions(input); ++i) { - num_input_elements *= SizeOfDimension(input, i); - } + int num_input_elements = NumElements(input); int num_output_elements = 1; int stretch_dim = -1; @@ -96,9 +93,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // The function is returned above this line if the shape tensor is usable. // Now fallback to the shape parameter in `TfLiteReshapeParams`. - - TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions); - for (int i = 0; i < params->num_dimensions; ++i) { + int num_dimensions = params->num_dimensions; + if (num_dimensions == 1 && params->shape[0] == 0) { + // Legacy tflite models use a shape parameter of [0] to indicate scalars, + // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during + // toco conversion. + num_dimensions = 0; + } + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions; ++i) { output_shape->data[i] = params->shape[i]; } return ResizeOutput(context, node, output_shape); |