aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/optimize_for_inference_lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/optimize_for_inference_lib.py')
-rw-r--r--tensorflow/python/tools/optimize_for_inference_lib.py52
1 files changed, 38 insertions, 14 deletions
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index 1cb5ba1625..8e040dcef7 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -48,6 +48,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import math
import re
import numpy as np
@@ -84,7 +85,8 @@ def optimize_for_inference(input_graph_def, input_node_names,
placeholder_type_enum)
optimized_graph_def = graph_util.remove_training_nodes(optimized_graph_def)
optimized_graph_def = fold_batch_norms(optimized_graph_def)
- optimized_graph_def = fuse_resize_and_conv(optimized_graph_def)
+ optimized_graph_def = fuse_resize_and_conv(optimized_graph_def,
+ output_node_names)
ensure_graph_is_valid(optimized_graph_def)
return optimized_graph_def
@@ -336,7 +338,7 @@ def fold_batch_norms(input_graph_def):
return result_graph_def
-def fuse_resize_and_conv(input_graph_def):
+def fuse_resize_and_conv(input_graph_def, output_node_names):
"""Merges preceding resize and mirror pad ops into a specialized convolution.
There's a common pattern of enlarging the input to a convolution using a
@@ -361,7 +363,14 @@ def fuse_resize_and_conv(input_graph_def):
else:
raise ValueError("Duplicate node names detected for ", node.name)
- nodes_to_skip = {}
+ node_reference_count = collections.defaultdict(int)
+ for node in input_graph_def.node:
+ for input_name in node.input:
+ stripped_name = node_name_from_input(input_name)
+ node_reference_count[stripped_name] += 1
+ for output_name in output_node_names:
+ node_reference_count[output_name] += 1
+
new_ops = []
for node in input_graph_def.node:
@@ -373,20 +382,31 @@ def fuse_resize_and_conv(input_graph_def):
if input_op.op == "MirrorPad":
mirror_pad_op = input_op
resize_op = node_from_map(input_node_map, mirror_pad_op.input[0])
+ if resize_op.op != "ResizeBilinear":
+ resize_op = None
else:
mirror_pad_op = None
- resize_op = input_op
+ if input_op.op == "ResizeBilinear":
+ resize_op = input_op
+ else:
+ resize_op = None
- if resize_op.op != "ResizeBilinear":
+ # There are no ops to be fused into the conv, so skip replacing this one.
+ if not mirror_pad_op and not resize_op:
continue
- nodes_to_skip[conv_op.name] = True
+ # We're replacing this node, so make sure the old one is removed.
+ node_reference_count[conv_op.name] = 0
if mirror_pad_op:
- nodes_to_skip[mirror_pad_op.name] = True
- nodes_to_skip[resize_op.name] = True
+ node_reference_count[mirror_pad_op.name] -= 1
+ if resize_op:
+ node_reference_count[resize_op.name] -= 1
fused_conv_op = tf.NodeDef()
- fused_conv_op.op = "FusedResizeAndPadConv2D"
+ if resize_op:
+ fused_conv_op.op = "FusedResizeAndPadConv2D"
+ else:
+ fused_conv_op.op = "FusedPadConv2D"
fused_conv_op.name = conv_op.name
if mirror_pad_op:
mirror_paddings_name = mirror_pad_op.input[1]
@@ -405,11 +425,15 @@ def fuse_resize_and_conv(input_graph_def):
new_ops.extend([paddings_op])
mirror_paddings_name = paddings_op.name
mirror_paddings_mode = tf.AttrValue(s=b"REFLECT")
- fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1],
- mirror_paddings_name, conv_op.input[1]])
+ if resize_op:
+ fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1],
+ mirror_paddings_name, conv_op.input[1]])
+ fused_conv_op.attr["resize_align_corners"].CopyFrom(
+ resize_op.attr["align_corners"])
+ else:
+ fused_conv_op.input.extend([mirror_pad_op.input[0], mirror_paddings_name,
+ conv_op.input[1]])
fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"])
- fused_conv_op.attr["resize_align_corners"].CopyFrom(
- resize_op.attr["align_corners"])
fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode)
fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"])
fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"])
@@ -417,7 +441,7 @@ def fuse_resize_and_conv(input_graph_def):
result_graph_def = tf.GraphDef()
for node in input_graph_def.node:
- if node.name in nodes_to_skip:
+ if node_reference_count[node.name] < 1:
continue
new_node = tf.NodeDef()
new_node.CopyFrom(node)