aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc19
1 files changed, 13 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index f103bb94ae..d056a8add7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -285,7 +285,8 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
- op->stride_height, 1, 1, op->padding.type,
+ op->stride_height, op->dilation_width_factor,
+ op->dilation_height_factor, op->padding.type,
model->GetArray(output_name).mutable_shape(),
&op->padding.GetOrCreateFixedPadding());
}
@@ -658,11 +659,16 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
}
}
auto& output_array = model->GetArray(op->outputs[0]);
- // Use 0 input as basis for output dimensions.
- const auto& first_input_array = model->GetArray(op->inputs[0]);
- output_array.copy_shape(first_input_array.shape());
- // Negative axis means the count starts at the back of the dims().
- if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
+ // Use first non-empty input as basis for output dimensions.
+ for (const auto& input_name : op->inputs) {
+ const auto& input_array = model->GetArray(input_name);
+ if (input_array.shape().dimensions_count() > 0) {
+ output_array.copy_shape(input_array.shape());
+ // Negative axis means the count starts at the back of the dims().
+ if (op->axis < 0) op->axis += input_array.shape().dims().size();
+ break;
+ }
+ }
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
@@ -1655,6 +1661,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kLogicalAnd:
case OperatorType::kLogicalNot:
case OperatorType::kLogicalOr:
+ case OperatorType::kZerosLike:
ProcessSimpleOperator(model, op, 0);
break;
case OperatorType::kGather: