aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/tools
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-07-30 13:48:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 13:51:34 -0700
commit37594fd0945061aee9f4d5f6ba2aa8c4b360697c (patch)
treef9c0a9bd9aca7f6c654be1b473143dfbbbe4af1d /tensorflow/python/tools
parenta896aa5ac1e78fd9f71735769435bf8591c89bdd (diff)
Adds error message in freeze_graph.py.
PiperOrigin-RevId: 206639908
Diffstat (limited to 'tensorflow/python/tools')
-rw-r--r--tensorflow/python/tools/freeze_graph.py32
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py67
2 files changed, 97 insertions, 2 deletions
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index e9f1def48c..4349699a94 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -38,6 +38,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import re
import sys
from google.protobuf import text_format
@@ -116,16 +117,43 @@ def freeze_graph_with_def_protos(input_graph_def,
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
+
+ # List of all partition variables. Because the condition is heuristic
+ # based, the list could include false positives.
+ all_parition_variable_names = [
+ tensor.name.split(":")[0]
+ for op in sess.graph.get_operations()
+ for tensor in op.values()
+ if re.search(r"/part_\d+/", tensor.name)
+ ]
+ has_partition_var = False
+
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
+ if any(key in name for name in all_parition_variable_names):
+ has_partition_var = True
except KeyError:
# This tensor doesn't exist in the graph (for example it's
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
- saver = saver_lib.Saver(
- var_list=var_list, write_version=checkpoint_version)
+
+ try:
+ saver = saver_lib.Saver(
+ var_list=var_list, write_version=checkpoint_version)
+ except TypeError as e:
+ # `var_list` is required to be a map of variable names to Variable
+ # tensors. Partition variables are Identity tensors that cannot be
+ # handled by Saver.
+ if has_partition_var:
+ print("Models containing partition variables cannot be converted "
+ "from checkpoint files. Please pass in a SavedModel using "
+ "the flag --input_saved_model_dir.")
+ return -1
+ else:
+ raise e
+
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.replace(" ", "").split(","))
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index 91f0061ebc..e38945fabc 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+import re
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import graph_pb2
@@ -31,7 +32,10 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
@@ -262,6 +266,69 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output = sess.run(output_node, feed_dict={input_node: [example]})
self.assertNear(feature_value, output, 0.00001)
+ def testSinglePartitionedVariable(self):
+ """Ensures partitioned variables fail cleanly with freeze graph."""
+ checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
+ checkpoint_state_name = "checkpoint_state"
+ input_graph_name = "input_graph.pb"
+ output_graph_name = "output_graph.pb"
+
+ # Create a graph with partition variables. When weights are partitioned into
+ # a single partition, the weights variable is followed by a identity ->
+ # identity (an additional identity node).
+ partitioner = partitioned_variables.fixed_size_partitioner(1)
+ with ops.Graph().as_default():
+ with variable_scope.variable_scope("part", partitioner=partitioner):
+ batch_size, height, width, depth = 5, 128, 128, 3
+ input1 = array_ops.zeros(
+ (batch_size, height, width, depth), name="input1")
+ input2 = array_ops.zeros(
+ (batch_size, height, width, depth), name="input2")
+
+ num_nodes = depth
+ filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
+ filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
+ conv = nn.conv2d(
+ input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
+ node = math_ops.add(conv, input2, name="test/add")
+ node = nn.relu6(node, name="test/relu6")
+
+ # Save graph and checkpoints.
+ sess = session.Session()
+ sess.run(variables.global_variables_initializer())
+
+ saver = saver_lib.Saver()
+ checkpoint_path = saver.save(
+ sess,
+ checkpoint_prefix,
+ global_step=0,
+ latest_filename=checkpoint_state_name)
+ graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)
+
+ # Ensure this graph has partition variables.
+ self.assertTrue([
+ tensor.name.split(":")[0]
+ for op in sess.graph.get_operations()
+ for tensor in op.values()
+ if re.search(r"/part_\d+/", tensor.name)
+ ])
+
+ # Test freezing graph doesn't make it crash.
+ output_node_names = "save/restore_all"
+ output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
+
+ return_value = freeze_graph.freeze_graph_with_def_protos(
+ input_graph_def=sess.graph_def,
+ input_saver_def=None,
+ input_checkpoint=checkpoint_path,
+ output_node_names=output_node_names,
+ restore_op_name="save/restore_all", # default value
+ filename_tensor_name="save/Const:0", # default value
+ output_graph=output_graph_path,
+ clear_devices=False,
+ initializer_nodes="")
+ self.assertTrue(return_value, -1)
+
if __name__ == "__main__":
test.main()