author    yegord <yegor.derevenets@gmail.com>       2018-02-01 00:02:25 +0100
committer Rasmus Munk Larsen <rmlarsen@google.com>  2018-01-31 15:02:25 -0800
commit    6afe900f543e0005ce69b3152330f1b7b16cb286 (patch)
tree      2e7e4d67e1ec99cbc9238f6cb2a4cd838010f683
parent    c24e3dd451aca36504fc9a69e2c1a01b6af2e854 (diff)
optimize_for_inference_lib.fold_batch_norms() preserves data_format (#16075)
-rw-r--r--  tensorflow/python/tools/optimize_for_inference_lib.py  |  1 +
-rw-r--r--  tensorflow/python/tools/optimize_for_inference_test.py | 89 ++++++++++----------
2 files changed, 48 insertions(+), 42 deletions(-)
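
The change itself is one attribute copy in the library, plus a test that now
exercises both layouts. fold_batch_norms() rewrites each Conv2D + FusedBatchNorm
pair into a Conv2D with rescaled weights followed by a BiasAdd; before this
patch the synthesized BiasAdd node carried no "data_format" attribute, so it
defaulted to NHWC and applied the per-channel offset along the wrong axis of
NCHW graphs. A minimal driver sketch (TF 1.x-era API, mirroring the test below;
the driver code is illustrative, not part of the patch):

    # Sketch: fold an NCHW Conv2D + FusedBatchNorm graph and inspect the
    # BiasAdd that fold_batch_norms() synthesizes. Assumes a TF 1.x install.
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.tools import optimize_for_inference_lib

    graph = tf.Graph()
    with graph.as_default():
      # NCHW input: batch 1, 2 channels, 1x6 spatial.
      inp = tf.constant(np.ones([1, 2, 1, 6]), dtype=tf.float32)
      weights = tf.constant(np.ones([1, 2, 2, 2]), dtype=tf.float32)
      conv = tf.nn.conv2d(inp, weights, [1, 1, 1, 1], padding="SAME",
                          data_format="NCHW")
      tf.nn.fused_batch_norm(
          conv,
          scale=tf.constant([1.0, 2.0]),
          offset=tf.constant([0.1, 0.6]),
          mean=tf.constant([10.0, 20.0]),
          variance=tf.constant([0.25, 0.5]),
          data_format="NCHW",
          is_training=False)

    folded = optimize_for_inference_lib.fold_batch_norms(graph.as_graph_def())
    for node in folded.node:
      if node.op == "BiasAdd":
        # With this patch the attr is copied from the Conv2D: b"NCHW".
        # Without it, the attr is absent and BiasAdd defaults to NHWC.
        print(node.name, node.attr["data_format"].s)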
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index c2687bf557..9c19271222 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def):
     bias_add_op.op = "BiasAdd"
     bias_add_op.name = node.name
     bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
+    bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
     bias_add_op.input.extend([new_conv_op.name, offset_op.name])
     new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])
 
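The hunk above is the entire library change: the synthesized BiasAdd now
inherits the convolution's layout. The surrounding names hint at what the fold
computes, and the arithmetic is easy to verify in plain NumPy using the
constants from the test (a sketch of the math only, not the library code):

    # Batch norm y = gamma * (x - mean) / sqrt(variance + eps) + beta
    # collapses into y = scale * x + offset, which is why a Conv2D with
    # rescaled weights plus a BiasAdd can replace FusedBatchNorm.
    import numpy as np

    gamma = np.array([1.0, 2.0])
    beta = np.array([0.1, 0.6])
    mean = np.array([10.0, 20.0])
    variance = np.array([0.25, 0.5])
    eps = 0.00001

    scale = gamma / np.sqrt(variance + eps)  # folded into the conv weights
    offset = beta - mean * scale             # becomes the BiasAdd bias

    x = np.array([5.0, 7.0])                 # an arbitrary conv output
    bn = gamma * (x - mean) / np.sqrt(variance + eps) + beta
    print(np.allclose(bn, scale * x + offset))  # True

Because the offset is per-channel, the BiasAdd must know which axis holds the
channels, which is exactly what the copied "data_format" attribute encodes.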
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index 7686bb0f14..2ef612473b 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -173,48 +173,53 @@ class OptimizeForInferenceTest(test.TestCase):
       self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
 
   def testFoldFusedBatchNorms(self):
-    with self.test_session() as sess:
-      inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
-      input_op = constant_op.constant(
-          np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
-      weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
-      weights_op = constant_op.constant(
-          np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
-      conv_op = nn_ops.conv2d(
-          input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
-      mean_op = constant_op.constant(
-          np.array([10, 20]), shape=[2], dtype=dtypes.float32)
-      variance_op = constant_op.constant(
-          np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
-      beta_op = constant_op.constant(
-          np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
-      gamma_op = constant_op.constant(
-          np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
-      ops.get_default_graph().graph_def_versions.producer = 9
-      gen_nn_ops._fused_batch_norm(
-          conv_op,
-          gamma_op,
-          beta_op,
-          mean_op,
-          variance_op,
-          0.00001,
-          is_training=False,
-          name="output")
-      original_graph_def = sess.graph_def
-      original_result = sess.run(["output:0"])
-    optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
-        original_graph_def)
-
-    with self.test_session() as sess:
-      _ = importer.import_graph_def(
-          optimized_graph_def, input_map={}, name="optimized")
-      optimized_result = sess.run(["optimized/output:0"])
-
-    self.assertAllClose(
-        original_result, optimized_result, rtol=1e-04, atol=1e-06)
-
-    for node in optimized_graph_def.node:
-      self.assertNotEqual("FusedBatchNorm", node.op)
+    for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
+      with self.test_session(use_gpu=use_gpu) as sess:
+        inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+        input_op = constant_op.constant(
+            np.array(inputs),
+            shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+            dtype=dtypes.float32)
+        weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+        weights_op = constant_op.constant(
+            np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+        conv_op = nn_ops.conv2d(
+            input_op, weights_op, [1, 1, 1, 1], padding="SAME",
+            data_format=data_format, name="conv_op")
+        mean_op = constant_op.constant(
+            np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+        variance_op = constant_op.constant(
+            np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+        beta_op = constant_op.constant(
+            np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+        gamma_op = constant_op.constant(
+            np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+        ops.get_default_graph().graph_def_versions.producer = 9
+        gen_nn_ops._fused_batch_norm(
+            conv_op,
+            gamma_op,
+            beta_op,
+            mean_op,
+            variance_op,
+            0.00001,
+            is_training=False,
+            data_format=data_format,
+            name="output")
+        original_graph_def = sess.graph_def
+        original_result = sess.run(["output:0"])
+      optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+          original_graph_def)
+
+      with self.test_session(use_gpu=use_gpu) as sess:
+        _ = importer.import_graph_def(
+            optimized_graph_def, input_map={}, name="optimized")
+        optimized_result = sess.run(["optimized/output:0"])
+
+      self.assertAllClose(
+          original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+      for node in optimized_graph_def.node:
+        self.assertNotEqual("FusedBatchNorm", node.op)
 
   def testFuseResizePadAndConv(self):
     with self.test_session() as sess:
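
The updated test runs the same fold in both layouts (note the input shape flips
from [1, 1, 6, 2] to [1, 2, 1, 6]) and compares results before and after the
transform. A standalone NumPy analogue of what BiasAdd's data_format controls
(illustration only, not the TF kernel itself):

    import numpy as np

    bias = np.array([0.1, 0.6], dtype=np.float32)      # one value per channel
    x_nhwc = np.zeros([1, 1, 6, 2], dtype=np.float32)  # channels on axis 3
    x_nchw = np.zeros([1, 2, 1, 6], dtype=np.float32)  # channels on axis 1

    out_nhwc = x_nhwc + bias.reshape(1, 1, 1, 2)  # BiasAdd with NHWC
    out_nchw = x_nchw + bias.reshape(1, 2, 1, 1)  # BiasAdd with NCHW

    # Applying the NHWC rule to the NCHW tensor is the pre-patch bug: here
    # it would not even broadcast (last axis is 6, not 2), and in shapes
    # where it does broadcast it silently adds the bias on the wrong axis.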