diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-03-28 19:21:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-28 19:23:05 -0700 |
commit | a5a90e6b55c19bd14d5effa5cb1695ddbe31026f (patch) | |
tree | 507a8daf3a729b442f4e666ad731826655c8b3a4 /tensorflow/contrib/quantize | |
parent | 2b41d75654012f917cda1b54aee090d73086ab84 (diff) |
Relax limitations on rerouting graph outputs.
- Allow multiple outputs of output_tensors in fold_batch_norms.
- Allow duplicate consumers in quantize.
- I also quick a fix issue for matching final layers that have batch norm.
PiperOrigin-RevId: 190873003
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/python/fold_batch_norms.py | 6 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 18 |
2 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 5750be6f4c..4a8f8a04cc 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -134,9 +134,9 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, match.output_tensor) - if nodes_modified_count != 1: - raise ValueError( - 'Unexpected inputs to op: %s' % match.output_tensor.name) + if nodes_modified_count == 0: + raise ValueError('Folding batch norms failed, %s had no outputs.' % + match.output_tensor.name) def _FindFusedBatchNorms(graph): diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 019d123a68..2889016a84 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -305,7 +305,8 @@ def _FindLayersToQuantize(graph): # the output of the final BiasAdd must be quantized. So we treat the BiasAdd # as the 'activation_op' in the _LayerMatch, to ensure that it's output is # quantized. - final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern) + final_layer_matcher = graph_matcher.GraphMatcher( + graph_matcher.OneofPattern([bias_add_pattern, folded_bias_add_pattern])) for match_result in final_layer_matcher.match_graph(graph): layer_op = match_result.get_op(layer_pattern) weight_tensor = match_result.get_tensor(weight_identity_pattern) @@ -463,11 +464,16 @@ def _InsertQuantOp(context, lambda: inputs, name=name_prefix + '/delayed_quant') - nodes_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) - if nodes_modified_count != len(consumers): - raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join( - [consumer.name for consumer in consumers])) + if consumers: + tensors_modified_count = graph_editor.reroute_ts( + [quant], [inputs], can_modify=consumers) + # Some operations can have multiple output tensors going to the same + # consumer. Since consumers is a set, we need to ensure that + # tensors_modified_count is greater than or equal to the length of the set + # of consumers. + if tensors_modified_count < len(consumers): + raise ValueError('No inputs quantized for ops: [%s]' % ', '.join( + [consumer.name for consumer in consumers])) def _GetContextFromOp(op): |