about summary refs log tree commit diff homepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
author Michael Case <mikecase@google.com> 2018-02-07 14:36:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/python/tools
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/freeze_graph.py38
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py16
-rw-r--r--tensorflow/python/tools/optimize_for_inference_lib.py1
-rw-r--r--tensorflow/python/tools/optimize_for_inference_test.py92
-rw-r--r--tensorflow/python/tools/saved_model_cli.py3
5 files changed, 95 insertions, 55 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 0ddf09260b..affa97062a 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -72,7 +72,8 @@ def freeze_graph_with_def_protos(input_graph_def,
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
- saved_model_tags=None):
+ saved_model_tags=None,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
@@ -100,7 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def,
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
- saver = saver_lib.Saver(saver_def=input_saver_def)
+ saver = saver_lib.Saver(
+ saver_def=input_saver_def, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
@@ -124,7 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
- saver = saver_lib.Saver(var_list=var_list)
+ saver = saver_lib.Saver(
+ var_list=var_list, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.split(","))
@@ -217,7 +220,8 @@ def freeze_graph(input_graph,
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
- saved_model_tags=tag_constants.SERVING):
+ saved_model_tags=tag_constants.SERVING,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
input_graph_def = None
if input_saved_model_dir:
@@ -233,10 +237,21 @@ def freeze_graph(input_graph,
if input_saver:
input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
freeze_graph_with_def_protos(
- input_graph_def, input_saver_def, input_checkpoint, output_node_names,
- restore_op_name, filename_tensor_name, output_graph, clear_devices,
- initializer_nodes, variable_names_whitelist, variable_names_blacklist,
- input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(","))
+ input_graph_def,
+ input_saver_def,
+ input_checkpoint,
+ output_node_names,
+ restore_op_name,
+ filename_tensor_name,
+ output_graph,
+ clear_devices,
+ initializer_nodes,
+ variable_names_whitelist,
+ variable_names_blacklist,
+ input_meta_graph_def,
+ input_saved_model_dir,
+ saved_model_tags.split(","),
+ checkpoint_version=checkpoint_version)
def main(unused_args):
@@ -246,7 +261,7 @@ def main(unused_args):
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags)
+ FLAGS.saved_model_tags, FLAGS.checkpoint_version)
if __name__ == "__main__":
@@ -268,6 +283,11 @@ if __name__ == "__main__":
default="",
help="TensorFlow variables file to load.")
parser.add_argument(
+ "--checkpoint_version",
+ type=int,
+ default=saver_pb2.SaverDef.V2,
+ help="Tensorflow variable file format")
+ parser.add_argument(
"--output_graph",
type=str,
default="",
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index feeed7102c..91f0061ebc 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -84,9 +84,19 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
input_meta_graph = checkpoint_meta_graph_file
freeze_graph.freeze_graph(
- input_graph_path, input_saver_def_path, input_binary, checkpoint_path,
- output_node_names, restore_op_name, filename_tensor_name,
- output_graph_path, clear_devices, "", "", input_meta_graph)
+ input_graph_path,
+ input_saver_def_path,
+ input_binary,
+ checkpoint_path,
+ output_node_names,
+ restore_op_name,
+ filename_tensor_name,
+ output_graph_path,
+ clear_devices,
+ "",
+ "",
+ input_meta_graph,
+ checkpoint_version=saver_write_version)
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index c2687bf557..9c19271222 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def):
bias_add_op.op = "BiasAdd"
bias_add_op.name = node.name
bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
+ bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
bias_add_op.input.extend([new_conv_op.name, offset_op.name])
new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index 7686bb0f14..084a4500f8 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -173,48 +173,56 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
def testFoldFusedBatchNorms(self):
- with self.test_session() as sess:
- inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
- input_op = constant_op.constant(
- np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
- weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
- weights_op = constant_op.constant(
- np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
- conv_op = nn_ops.conv2d(
- input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
- mean_op = constant_op.constant(
- np.array([10, 20]), shape=[2], dtype=dtypes.float32)
- variance_op = constant_op.constant(
- np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
- beta_op = constant_op.constant(
- np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
- gamma_op = constant_op.constant(
- np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
- ops.get_default_graph().graph_def_versions.producer = 9
- gen_nn_ops._fused_batch_norm(
- conv_op,
- gamma_op,
- beta_op,
- mean_op,
- variance_op,
- 0.00001,
- is_training=False,
- name="output")
- original_graph_def = sess.graph_def
- original_result = sess.run(["output:0"])
- optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
- original_graph_def)
-
- with self.test_session() as sess:
- _ = importer.import_graph_def(
- optimized_graph_def, input_map={}, name="optimized")
- optimized_result = sess.run(["optimized/output:0"])
-
- self.assertAllClose(
- original_result, optimized_result, rtol=1e-04, atol=1e-06)
-
- for node in optimized_graph_def.node:
- self.assertNotEqual("FusedBatchNorm", node.op)
+ for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+ input_op = constant_op.constant(
+ np.array(inputs),
+ shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+ dtype=dtypes.float32)
+ weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+ weights_op = constant_op.constant(
+ np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+ conv_op = nn_ops.conv2d(
+ input_op,
+ weights_op, [1, 1, 1, 1],
+ padding="SAME",
+ data_format=data_format,
+ name="conv_op")
+ mean_op = constant_op.constant(
+ np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+ variance_op = constant_op.constant(
+ np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+ beta_op = constant_op.constant(
+ np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+ gamma_op = constant_op.constant(
+ np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+ ops.get_default_graph().graph_def_versions.producer = 9
+ gen_nn_ops._fused_batch_norm(
+ conv_op,
+ gamma_op,
+ beta_op,
+ mean_op,
+ variance_op,
+ 0.00001,
+ is_training=False,
+ data_format=data_format,
+ name="output")
+ original_graph_def = sess.graph_def
+ original_result = sess.run(["output:0"])
+ optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+ original_graph_def)
+
+ with self.test_session(use_gpu=use_gpu) as sess:
+ _ = importer.import_graph_def(
+ optimized_graph_def, input_map={}, name="optimized")
+ optimized_result = sess.run(["optimized/output:0"])
+
+ self.assertAllClose(
+ original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+ for node in optimized_graph_def.node:
+ self.assertNotEqual("FusedBatchNorm", node.op)
def testFuseResizePadAndConv(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 667a4b1db8..33f6debbcb 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -31,6 +31,7 @@ import warnings
import numpy as np
+from six import integer_types
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.example import example_pb2
@@ -440,7 +441,7 @@ def _create_example_string(example_dict):
elif isinstance(feature_list[0], str):
example.features.feature[feature_name].bytes_list.value.extend(
feature_list)
- elif isinstance(feature_list[0], (int, long)):
+ elif isinstance(feature_list[0], integer_types):
example.features.feature[feature_name].int64_list.value.extend(
feature_list)
else: