diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-07-30 13:48:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 13:51:34 -0700 |
commit | 37594fd0945061aee9f4d5f6ba2aa8c4b360697c (patch) | |
tree | f9c0a9bd9aca7f6c654be1b473143dfbbbe4af1d /tensorflow/python/tools | |
parent | a896aa5ac1e78fd9f71735769435bf8591c89bdd (diff) |
Adds error message in freeze_graph.py.
PiperOrigin-RevId: 206639908
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r-- | tensorflow/python/tools/freeze_graph.py | 32 | ||||
-rw-r--r-- | tensorflow/python/tools/freeze_graph_test.py | 67 |
2 files changed, 97 insertions, 2 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index e9f1def48c..4349699a94 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import argparse +import re import sys from google.protobuf import text_format @@ -116,16 +117,43 @@ def freeze_graph_with_def_protos(input_graph_def, var_list = {} reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint) var_to_shape_map = reader.get_variable_to_shape_map() + + # List of all partition variables. Because the condition is heuristic + # based, the list could include false positives. + all_parition_variable_names = [ + tensor.name.split(":")[0] + for op in sess.graph.get_operations() + for tensor in op.values() + if re.search(r"/part_\d+/", tensor.name) + ] + has_partition_var = False + for key in var_to_shape_map: try: tensor = sess.graph.get_tensor_by_name(key + ":0") + if any(key in name for name in all_parition_variable_names): + has_partition_var = True except KeyError: # This tensor doesn't exist in the graph (for example it's # 'global_step' or a similar housekeeping element) so skip it. continue var_list[key] = tensor - saver = saver_lib.Saver( - var_list=var_list, write_version=checkpoint_version) + + try: + saver = saver_lib.Saver( + var_list=var_list, write_version=checkpoint_version) + except TypeError as e: + # `var_list` is required to be a map of variable names to Variable + # tensors. Partition variables are Identity tensors that cannot be + # handled by Saver. + if has_partition_var: + print("Models containing partition variables cannot be converted " + "from checkpoint files. Please pass in a SavedModel using " + "the flag --input_saved_model_dir.") + return -1 + else: + raise e + saver.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.replace(" ", "").split(",")) diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index 91f0061ebc..e38945fabc 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os +import re from tensorflow.core.example import example_pb2 from tensorflow.core.framework import graph_pb2 @@ -31,7 +32,10 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.saved_model import builder as saved_model_builder @@ -262,6 +266,69 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): output = sess.run(output_node, feed_dict={input_node: [example]}) self.assertNear(feature_value, output, 0.00001) + def testSinglePartitionedVariable(self): + """Ensures partitioned variables fail cleanly with freeze graph.""" + checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") + checkpoint_state_name = "checkpoint_state" + input_graph_name = "input_graph.pb" + output_graph_name = "output_graph.pb" + + # Create a graph with partition variables. When weights are partitioned into + # a single partition, the weights variable is followed by a identity -> + # identity (an additional identity node). + partitioner = partitioned_variables.fixed_size_partitioner(1) + with ops.Graph().as_default(): + with variable_scope.variable_scope("part", partitioner=partitioner): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros( + (batch_size, height, width, depth), name="input1") + input2 = array_ops.zeros( + (batch_size, height, width, depth), name="input2") + + num_nodes = depth + filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes]) + filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes]) + conv = nn.conv2d( + input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME") + node = math_ops.add(conv, input2, name="test/add") + node = nn.relu6(node, name="test/relu6") + + # Save graph and checkpoints. + sess = session.Session() + sess.run(variables.global_variables_initializer()) + + saver = saver_lib.Saver() + checkpoint_path = saver.save( + sess, + checkpoint_prefix, + global_step=0, + latest_filename=checkpoint_state_name) + graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) + + # Ensure this graph has partition variables. + self.assertTrue([ + tensor.name.split(":")[0] + for op in sess.graph.get_operations() + for tensor in op.values() + if re.search(r"/part_\d+/", tensor.name) + ]) + + # Test freezing graph doesn't make it crash. + output_node_names = "save/restore_all" + output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) + + return_value = freeze_graph.freeze_graph_with_def_protos( + input_graph_def=sess.graph_def, + input_saver_def=None, + input_checkpoint=checkpoint_path, + output_node_names=output_node_names, + restore_op_name="save/restore_all", # default value + filename_tensor_name="save/Const:0", # default value + output_graph=output_graph_path, + clear_devices=False, + initializer_nodes="") + self.assertTrue(return_value, -1) + if __name__ == "__main__": test.main() |