aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-04 12:28:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-04 13:38:13 -0700
commit5ca373b4b64167f8b0fcab96d7d2e7886ea31b6a (patch)
treef82cba06cb52c035f834f8fb8c5daa9bee3ed9bb
parentbe9b87375adecad9bd8bb12c81b2566c77a68ad7 (diff)
Some fixes to support another TF graph:
1. Fix ResolveBatchNormalization to avoid deleting arrays that may still be used. 2. Correctly count the number of ops using a given array, even when some ops use the same array as more than one of their inputs. 3. In PropagateFixedSizes for Concatenation ops, when resolving a -1 wildcard to a fixed value, we were doing so in a local 'axis' variable without actually updating op->axis! The resulting -1 value still in op->axis tripped runtime code, causing the concatenation to misbehave during inference. PiperOrigin-RevId: 195454037
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc11
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc6
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc4
3 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 4923f83d91..b02b02c5be 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -670,8 +670,7 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
const auto& first_input_array = model->GetArray(op->inputs[0]);
output_array.copy_shape(first_input_array.shape());
// Negative axis means the count starts at the back of the dims().
- int axis = op->axis;
- if (axis < 0) axis += first_input_array.shape().dims().size();
+ if (op->axis < 0) op->axis += first_input_array.shape().dims().size();
// Determine the concat size, and enfore that all inputs have
// the same dimensions count.
int concat_size = 0;
@@ -684,14 +683,14 @@ void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
CHECK_EQ(input_array.shape().dimensions_count(),
output_array.shape().dimensions_count());
const std::vector<int>& input_dims = input_array.shape().dims();
- CHECK_LT(axis, input_dims.size());
- concat_size += input_dims[axis];
+ CHECK_LT(op->axis, input_dims.size());
+ concat_size += input_dims[op->axis];
}
// Write out the concat_size on the output array shape.
auto& output_shape = *output_array.mutable_shape();
auto& output_dims = *output_shape.mutable_dims();
- CHECK_LT(axis, output_shape.dimensions_count());
- output_dims[axis] = concat_size;
+ CHECK_LT(op->axis, output_shape.dimensions_count());
+ output_dims[op->axis] = concat_size;
}
void ProcessRangeOperator(Model* model, RangeOperator* op) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
index 2b3ee36ad1..8f2c1f8162 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -134,9 +134,9 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
}
// Remove the old param arrays
- model->EraseArray(bn_op->inputs[1]);
- model->EraseArray(bn_op->inputs[2]);
- model->EraseArray(bn_op->inputs[3]);
+ DeleteArrayIfUsedOnce(bn_op->inputs[1], model);
+ DeleteArrayIfUsedOnce(bn_op->inputs[2], model);
+ DeleteArrayIfUsedOnce(bn_op->inputs[3], model);
// Remove the old operator
DCHECK_EQ(bn_it->get(), bn_op);
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 86ee1f3761..341d45e753 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -143,6 +143,10 @@ int CountOpsWithInput(const Model& model, const string& array_name) {
for (auto& input : op->inputs) {
if (input == array_name) {
count++;
+ // Breaking here is important: some graphs have ops that use the
+ // same array as more than one of their inputs, and in that case
+ // we want it counted only once.
+ break;
}
}
}