aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-04 12:28:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-04 13:38:13 -0700
commit5ca373b4b64167f8b0fcab96d7d2e7886ea31b6a (patch)
treef82cba06cb52c035f834f8fb8c5daa9bee3ed9bb /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parentbe9b87375adecad9bd8bb12c81b2566c77a68ad7 (diff)
Some fixes to support another TF graph:
1. Fix ResolveBatchNormalization to avoid deleting arrays that may still be used. 2. Correctly count the number of ops using a given array, even when some ops use the same array as more than one of their inputs. 3. In PropagateFixedSizes for Concatenation ops, when resolving a -1 wildcard to a fixed value, we were doing so in a local 'axis' variable without actually updating op->axis! The resulting -1 value still in op->axis tripped runtime code, causing the concatenation to misbehave during inference. PiperOrigin-RevId: 195454037
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 4923f83d91..b02b02c5be 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -670,8 +670,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
const auto& first_input_array = model->GetArray(op->inputs[0]);
output_array.copy_shape(first_input_array.shape());
// Negative axis means the count starts at the back of the dims().
- int axis = op->axis;
- if (axis < 0) axis += first_input_array.shape().dims().size();
+ if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
@@ -684,14 +683,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
CHECK_EQ(input_array.shape().dimensions_count(),
output_array.shape().dimensions_count());
const std::vector<int>& input_dims = input_array.shape().dims();
- CHECK_LT(axis, input_dims.size());
- concat_size += input_dims[axis];
+ CHECK_LT(op->axis, input_dims.size());
+ concat_size += input_dims[op->axis];
}
// Write out the concat_size on the output array shape.
auto& output_shape = *output_array.mutable_shape();
auto& output_dims = *output_shape.mutable_dims();
- CHECK_LT(axis, output_shape.dimensions_count());
- output_dims[axis] = concat_size;
+ CHECK_LT(op->axis, output_shape.dimensions_count());
+ output_dims[op->axis] = concat_size;
}
void ProcessRangeOperator(Model* model, RangeOperator* op) {