diff options
Diffstat (limited to 'tensorflow/python/tools/optimize_for_inference_test.py')
-rw-r--r-- | tensorflow/python/tools/optimize_for_inference_test.py | 33 |
1 files changed, 31 insertions, 2 deletions
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py index d92d7ab8c7..57a90fbfe0 100644 --- a/tensorflow/python/tools/optimize_for_inference_test.py +++ b/tensorflow/python/tools/optimize_for_inference_test.py @@ -54,6 +54,7 @@ class OptimizeForInferenceTest(tf.test.TestCase): shape=shape))) def testOptimizeForInference(self): + self.maxDiff = 1000 unused_constant_name = "unused_constant" unconnected_add_name = "unconnected_add" a_constant_name = "a_constant" @@ -183,7 +184,7 @@ class OptimizeForInferenceTest(tf.test.TestCase): original_graph_def = sess.graph_def original_result = sess.run(["output:0"]) optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv( - original_graph_def) + original_graph_def, ["output"]) with self.test_session() as sess: _ = tf.import_graph_def(optimized_graph_def, input_map={}, @@ -212,7 +213,7 @@ class OptimizeForInferenceTest(tf.test.TestCase): original_graph_def = sess.graph_def original_result = sess.run(["output:0"]) optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv( - original_graph_def) + original_graph_def, ["output"]) with self.test_session() as sess: _ = tf.import_graph_def(optimized_graph_def, input_map={}, @@ -225,6 +226,34 @@ class OptimizeForInferenceTest(tf.test.TestCase): self.assertNotEqual("Conv2D", node.op) self.assertNotEqual("ResizeBilinear", node.op) + def testFusePadAndConv(self): + with self.test_session() as sess: + inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] + input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2], + dtype=tf.float32) + pad_op = tf.pad(input_op, [[0, 0], [1, 1], [2, 2], [0, 0]], + mode="REFLECT") + weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] + weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2], + dtype=tf.float32) + tf.nn.conv2d(pad_op, weights_op, [1, 1, 1, 1], + padding="VALID", name="output") + original_graph_def = sess.graph_def + original_result = sess.run(["output:0"]) + optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv( + original_graph_def, ["output"]) + + with self.test_session() as sess: + _ = tf.import_graph_def(optimized_graph_def, input_map={}, + name="optimized") + optimized_result = sess.run(["optimized/output:0"]) + + self.assertAllClose(original_result, optimized_result) + + for node in optimized_graph_def.node: + self.assertNotEqual("Conv2D", node.op) + self.assertNotEqual("MirrorPad", node.op) + if __name__ == "__main__": tf.test.main() |