aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 09:30:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 09:35:24 -0700
commitc0b63bef59bd2a94de2d1925259d1499d3ad04ea (patch)
tree9ebbe74ff125a2d205fa031334aea00cb7dd8626 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent7cd7a2e3877641da18182424bc7ea114fd7702ba (diff)
Allow empty arrays to occur as the first input to the concat op.
The conversion process fails for graphs that use tf.boolean_mask(..., axis=0) -- this op calls tf.concat with an empty array as the first argument. PiperOrigin-RevId: 214451470
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc15
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index f943da6d85..d056a8add7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -659,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
}
}
auto& output_array = model->GetArray(op->outputs[0]);
- // Use 0 input as basis for output dimensions.
- const auto& first_input_array = model->GetArray(op->inputs[0]);
- output_array.copy_shape(first_input_array.shape());
- // Negative axis means the count starts at the back of the dims().
- if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
+ // Use first non-empty input as basis for output dimensions.
+ for (const auto& input_name : op->inputs) {
+ const auto& input_array = model->GetArray(input_name);
+ if (input_array.shape().dimensions_count() > 0) {
+ output_array.copy_shape(input_array.shape());
+ // Negative axis means the count starts at the back of the dims().
+ if (op->axis < 0) op->axis += input_array.shape().dims().size();
+ break;
+ }
+ }
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;