diff options
Diffstat (limited to 'tensorflow/python/tools/optimize_for_inference_lib.py')
-rw-r--r-- | tensorflow/python/tools/optimize_for_inference_lib.py | 52 |
1 file changed, 38 insertions, 14 deletions
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py index 1cb5ba1625..8e040dcef7 100644 --- a/tensorflow/python/tools/optimize_for_inference_lib.py +++ b/tensorflow/python/tools/optimize_for_inference_lib.py @@ -48,6 +48,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import math import re import numpy as np @@ -84,7 +85,8 @@ def optimize_for_inference(input_graph_def, input_node_names, placeholder_type_enum) optimized_graph_def = graph_util.remove_training_nodes(optimized_graph_def) optimized_graph_def = fold_batch_norms(optimized_graph_def) - optimized_graph_def = fuse_resize_and_conv(optimized_graph_def) + optimized_graph_def = fuse_resize_and_conv(optimized_graph_def, + output_node_names) ensure_graph_is_valid(optimized_graph_def) return optimized_graph_def @@ -336,7 +338,7 @@ def fold_batch_norms(input_graph_def): return result_graph_def -def fuse_resize_and_conv(input_graph_def): +def fuse_resize_and_conv(input_graph_def, output_node_names): """Merges preceding resize and mirror pad ops into a specialized convolution. 
There's a common pattern of enlarging the input to a convolution using a @@ -361,7 +363,14 @@ def fuse_resize_and_conv(input_graph_def): else: raise ValueError("Duplicate node names detected for ", node.name) - nodes_to_skip = {} + node_reference_count = collections.defaultdict(int) + for node in input_graph_def.node: + for input_name in node.input: + stripped_name = node_name_from_input(input_name) + node_reference_count[stripped_name] += 1 + for output_name in output_node_names: + node_reference_count[output_name] += 1 + new_ops = [] for node in input_graph_def.node: @@ -373,20 +382,31 @@ def fuse_resize_and_conv(input_graph_def): if input_op.op == "MirrorPad": mirror_pad_op = input_op resize_op = node_from_map(input_node_map, mirror_pad_op.input[0]) + if resize_op.op != "ResizeBilinear": + resize_op = None else: mirror_pad_op = None - resize_op = input_op + if input_op.op == "ResizeBilinear": + resize_op = input_op + else: + resize_op = None - if resize_op.op != "ResizeBilinear": + # There are no ops to be fused into the conv, so skip replacing this one. + if not mirror_pad_op and not resize_op: continue - nodes_to_skip[conv_op.name] = True + # We're replacing this node, so make sure the old one is removed. 
+ node_reference_count[conv_op.name] = 0 if mirror_pad_op: - nodes_to_skip[mirror_pad_op.name] = True - nodes_to_skip[resize_op.name] = True + node_reference_count[mirror_pad_op.name] -= 1 + if resize_op: + node_reference_count[resize_op.name] -= 1 fused_conv_op = tf.NodeDef() - fused_conv_op.op = "FusedResizeAndPadConv2D" + if resize_op: + fused_conv_op.op = "FusedResizeAndPadConv2D" + else: + fused_conv_op.op = "FusedPadConv2D" fused_conv_op.name = conv_op.name if mirror_pad_op: mirror_paddings_name = mirror_pad_op.input[1] @@ -405,11 +425,15 @@ def fuse_resize_and_conv(input_graph_def): new_ops.extend([paddings_op]) mirror_paddings_name = paddings_op.name mirror_paddings_mode = tf.AttrValue(s=b"REFLECT") - fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1], - mirror_paddings_name, conv_op.input[1]]) + if resize_op: + fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1], + mirror_paddings_name, conv_op.input[1]]) + fused_conv_op.attr["resize_align_corners"].CopyFrom( + resize_op.attr["align_corners"]) + else: + fused_conv_op.input.extend([mirror_pad_op.input[0], mirror_paddings_name, + conv_op.input[1]]) fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"]) - fused_conv_op.attr["resize_align_corners"].CopyFrom( - resize_op.attr["align_corners"]) fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode) fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"]) fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"]) @@ -417,7 +441,7 @@ def fuse_resize_and_conv(input_graph_def): result_graph_def = tf.GraphDef() for node in input_graph_def.node: - if node.name in nodes_to_skip: + if node_reference_count[node.name] < 1: continue new_node = tf.NodeDef() new_node.CopyFrom(node) |