diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 09:30:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 09:35:24 -0700 |
commit | c0b63bef59bd2a94de2d1925259d1499d3ad04ea (patch) | |
tree | 9ebbe74ff125a2d205fa031334aea00cb7dd8626 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 7cd7a2e3877641da18182424bc7ea114fd7702ba (diff) |
Allow empty arrays to occur as the first input to the concat op.
The conversion process fails for graphs that use tf.boolean_mask(..., axis=0) --
this op calls tf.concat with an empty array as the first argument.
PiperOrigin-RevId: 214451470
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 15 |
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index f943da6d85..d056a8add7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -659,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { } } auto& output_array = model->GetArray(op->outputs[0]); - // Use 0 input as basis for output dimensions. - const auto& first_input_array = model->GetArray(op->inputs[0]); - output_array.copy_shape(first_input_array.shape()); - // Negative axis means the count starts at the back of the dims(). - if (op->axis < 0) op->axis += first_input_array.shape().dims().size(); + // Use first non-empty input as basis for output dimensions. + for (const auto& input_name : op->inputs) { + const auto& input_array = model->GetArray(input_name); + if (input_array.shape().dimensions_count() > 0) { + output_array.copy_shape(input_array.shape()); + // Negative axis means the count starts at the back of the dims(). + if (op->axis < 0) op->axis += input_array.shape().dims().size(); + break; + } + } // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; |