diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index f103bb94ae..d056a8add7 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, - op->stride_height, 1, 1, op->padding.type, + op->stride_height, op->dilation_width_factor, + op->dilation_height_factor, op->padding.type, model->GetArray(output_name).mutable_shape(), &op->padding.GetOrCreateFixedPadding()); } @@ -658,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) { } } auto& output_array = model->GetArray(op->outputs[0]); - // Use 0 input as basis for output dimensions. - const auto& first_input_array = model->GetArray(op->inputs[0]); - output_array.copy_shape(first_input_array.shape()); - // Negative axis means the count starts at the back of the dims(). - if (op->axis < 0) op->axis += first_input_array.shape().dims().size(); + // Use first non-empty input as basis for output dimensions. + for (const auto& input_name : op->inputs) { + const auto& input_array = model->GetArray(input_name); + if (input_array.shape().dimensions_count() > 0) { + output_array.copy_shape(input_array.shape()); + // Negative axis means the count starts at the back of the dims(). + if (op->axis < 0) op->axis += input_array.shape().dims().size(); + break; + } + } // Determine the concat size, and enfore that all inputs have // the same dimensions count. int concat_size = 0; @@ -1655,6 +1661,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kLogicalAnd: case OperatorType::kLogicalNot: case OperatorType::kLogicalOr: + case OperatorType::kZerosLike: ProcessSimpleOperator(model, op, 0); break; case OperatorType::kGather: |