aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools/optimize_for_inference_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/tools/optimize_for_inference_test.py')
-rw-r--r--tensorflow/python/tools/optimize_for_inference_test.py33
1 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index d92d7ab8c7..57a90fbfe0 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -54,6 +54,7 @@ class OptimizeForInferenceTest(tf.test.TestCase):
shape=shape)))
def testOptimizeForInference(self):
+ self.maxDiff = 1000
unused_constant_name = "unused_constant"
unconnected_add_name = "unconnected_add"
a_constant_name = "a_constant"
@@ -183,7 +184,7 @@ class OptimizeForInferenceTest(tf.test.TestCase):
original_graph_def = sess.graph_def
original_result = sess.run(["output:0"])
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
- original_graph_def)
+ original_graph_def, ["output"])
with self.test_session() as sess:
_ = tf.import_graph_def(optimized_graph_def, input_map={},
@@ -212,7 +213,7 @@ class OptimizeForInferenceTest(tf.test.TestCase):
original_graph_def = sess.graph_def
original_result = sess.run(["output:0"])
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
- original_graph_def)
+ original_graph_def, ["output"])
with self.test_session() as sess:
_ = tf.import_graph_def(optimized_graph_def, input_map={},
@@ -225,6 +226,34 @@ class OptimizeForInferenceTest(tf.test.TestCase):
self.assertNotEqual("Conv2D", node.op)
self.assertNotEqual("ResizeBilinear", node.op)
+ def testFusePadAndConv(self):
+ with self.test_session() as sess:
+ inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+ input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
+ dtype=tf.float32)
+ pad_op = tf.pad(input_op, [[0, 0], [1, 1], [2, 2], [0, 0]],
+ mode="REFLECT")
+ weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+ weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
+ dtype=tf.float32)
+ tf.nn.conv2d(pad_op, weights_op, [1, 1, 1, 1],
+ padding="VALID", name="output")
+ original_graph_def = sess.graph_def
+ original_result = sess.run(["output:0"])
+ optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
+ original_graph_def, ["output"])
+
+ with self.test_session() as sess:
+ _ = tf.import_graph_def(optimized_graph_def, input_map={},
+ name="optimized")
+ optimized_result = sess.run(["optimized/output:0"])
+
+ self.assertAllClose(original_result, optimized_result)
+
+ for node in optimized_graph_def.node:
+ self.assertNotEqual("Conv2D", node.op)
+ self.assertNotEqual("MirrorPad", node.op)
+
if __name__ == "__main__":
tf.test.main()