diff options
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 13 | ||||
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 19 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 35 |
3 files changed, 55 insertions, 12 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index f80bba8562..4fd87c04ec 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -336,7 +336,7 @@ class SavedModelBuilder(object): """ if not self._has_saved_variables: raise AssertionError( - "Variables and assets have not been saved yet. " + "Graph state including variables and assets has not been saved yet. " "Please invoke `add_meta_graph_and_variables()` first.") # Validate the signature def map to ensure all included TensorInfos are @@ -357,7 +357,8 @@ class SavedModelBuilder(object): saver = tf_saver.Saver( variables.global_variables(), sharded=True, - write_version=saver_pb2.SaverDef.V2) + write_version=saver_pb2.SaverDef.V2, + allow_empty=True) meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) @@ -394,8 +395,9 @@ class SavedModelBuilder(object): main_op: Op or group of ops to execute when the graph is loaded. """ if self._has_saved_variables: - raise AssertionError("Variables and assets have already been saved. " - "Please invoke `add_meta_graph()` instead.") + raise AssertionError("Graph state including variables and assets has " + "already been saved. Please invoke " + "`add_meta_graph()` instead.") # Validate the signature def map to ensure all included TensorInfos are # properly populated. @@ -426,7 +428,8 @@ class SavedModelBuilder(object): saver = tf_saver.Saver( variables.global_variables(), sharded=True, - write_version=saver_pb2.SaverDef.V2) + write_version=saver_pb2.SaverDef.V2, + allow_empty=True) # Save the variables. Also, disable writing the checkpoint state proto. The # file is not used during SavedModel loading. In addition, since a diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 86f59d6805..5c0c016536 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -27,6 +27,7 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.framework import ops from tensorflow.python.lib.io import file_io +from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat @@ -210,14 +211,18 @@ def load(sess, tags, export_dir, **saver_kwargs): # Build a saver by importing the meta graph def to load. saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs) - # Build the checkpoint path where the variables are located. - variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) + if saver: + # Build the checkpoint path where the variables are located. + variables_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.VARIABLES_DIRECTORY), + compat.as_bytes(constants.VARIABLES_FILENAME)) - # Restore the variables using the built saver in the provided session. - saver.restore(sess, variables_path) + # Restore the variables using the built saver in the provided session. + saver.restore(sess, variables_path) + else: + tf_logging.info("The specified SavedModel has no variables; no " + "checkpoints were restored.") # Get asset tensors, if any. asset_tensors_dictionary = _get_asset_tensors(export_dir, diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 4ce19322af..024b52fccb 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -246,6 +246,41 @@ class SavedModelTest(test.TestCase): self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"], export_dir) + def testGraphWithoutVariables(self): + export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Graph with no variables. + with self.test_session(graph=ops.Graph()) as sess: + constant_5_name = constant_op.constant(5.0).name + builder.add_meta_graph_and_variables(sess, ["foo"]) + + # Second graph with no variables + with self.test_session(graph=ops.Graph()) as sess: + constant_6_name = constant_op.constant(6.0).name + builder.add_meta_graph(["bar"]) + + # Save the SavedModel to disk. + builder.save() + + # Restore the graph with tag "foo". + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["foo"], export_dir) + # Read the constant a from the graph. + a = ops.get_default_graph().get_tensor_by_name(constant_5_name) + b = constant_op.constant(6.0) + c = a * b + self.assertEqual(30.0, sess.run(c)) + + # Restore the graph with tag "bar". + with self.test_session(graph=ops.Graph()) as sess: + loader.load(sess, ["bar"], export_dir) + # Read the constant a from the graph. + a = ops.get_default_graph().get_tensor_by_name(constant_6_name) + b = constant_op.constant(5.0) + c = a * b + self.assertEqual(30.0, sess.run(c)) + def testNoOverwrite(self): export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite") builder = saved_model_builder.SavedModelBuilder(export_dir) |