aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-02-01 10:14:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-01 10:28:01 -0800
commiteef537ce8dbbca0b005565146f26873d75c6e234 (patch)
treece68b75ffcc85e9cd1a05d2e2e9be9f95997e67e
parentcfdd541f3f4bae0d50dad528aebd406333d5af34 (diff)
Allow exporting and loading graphs with no variables.
Change: 146257820
-rw-r--r--tensorflow/python/saved_model/builder_impl.py13
-rw-r--r--tensorflow/python/saved_model/loader_impl.py19
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py35
3 files changed, 55 insertions, 12 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index f80bba8562..4fd87c04ec 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -336,7 +336,7 @@ class SavedModelBuilder(object):
"""
if not self._has_saved_variables:
raise AssertionError(
- "Variables and assets have not been saved yet. "
+ "Graph state including variables and assets has not been saved yet. "
"Please invoke `add_meta_graph_and_variables()` first.")
# Validate the signature def map to ensure all included TensorInfos are
@@ -357,7 +357,8 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver(
variables.global_variables(),
sharded=True,
- write_version=saver_pb2.SaverDef.V2)
+ write_version=saver_pb2.SaverDef.V2,
+ allow_empty=True)
meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
@@ -394,8 +395,9 @@ class SavedModelBuilder(object):
main_op: Op or group of ops to execute when the graph is loaded.
"""
if self._has_saved_variables:
- raise AssertionError("Variables and assets have already been saved. "
- "Please invoke `add_meta_graph()` instead.")
+ raise AssertionError("Graph state including variables and assets has "
+ "already been saved. Please invoke "
+ "`add_meta_graph()` instead.")
# Validate the signature def map to ensure all included TensorInfos are
# properly populated.
@@ -426,7 +428,8 @@ class SavedModelBuilder(object):
saver = tf_saver.Saver(
variables.global_variables(),
sharded=True,
- write_version=saver_pb2.SaverDef.V2)
+ write_version=saver_pb2.SaverDef.V2,
+ allow_empty=True)
# Save the variables. Also, disable writing the checkpoint state proto. The
# file is not used during SavedModel loading. In addition, since a
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index 86f59d6805..5c0c016536 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -27,6 +27,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
+from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
@@ -210,14 +211,18 @@ def load(sess, tags, export_dir, **saver_kwargs):
# Build a saver by importing the meta graph def to load.
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
- # Build the checkpoint path where the variables are located.
- variables_path = os.path.join(
- compat.as_bytes(export_dir),
- compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
+ if saver:
+ # Build the checkpoint path where the variables are located.
+ variables_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.VARIABLES_DIRECTORY),
+ compat.as_bytes(constants.VARIABLES_FILENAME))
- # Restore the variables using the built saver in the provided session.
- saver.restore(sess, variables_path)
+ # Restore the variables using the built saver in the provided session.
+ saver.restore(sess, variables_path)
+ else:
+ tf_logging.info("The specified SavedModel has no variables; no "
+ "checkpoints were restored.")
# Get asset tensors, if any.
asset_tensors_dictionary = _get_asset_tensors(export_dir,
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 4ce19322af..024b52fccb 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -246,6 +246,41 @@ class SavedModelTest(test.TestCase):
self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
export_dir)
+ def testGraphWithoutVariables(self):
+ export_dir = os.path.join(test.get_temp_dir(), "test_graph_has_variables")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Graph with no variables.
+ with self.test_session(graph=ops.Graph()) as sess:
+ constant_5_name = constant_op.constant(5.0).name
+ builder.add_meta_graph_and_variables(sess, ["foo"])
+
+ # Second graph with no variables
+ with self.test_session(graph=ops.Graph()) as sess:
+ constant_6_name = constant_op.constant(6.0).name
+ builder.add_meta_graph(["bar"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ # Restore the graph with tag "foo".
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ # Read the constant a from the graph.
+ a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
+ b = constant_op.constant(6.0)
+ c = a * b
+ self.assertEqual(30.0, sess.run(c))
+
+ # Restore the graph with tag "bar".
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["bar"], export_dir)
+ # Read the constant a from the graph.
+ a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
+ b = constant_op.constant(5.0)
+ c = a * b
+ self.assertEqual(30.0, sess.run(c))
+
def testNoOverwrite(self):
export_dir = os.path.join(test.get_temp_dir(), "test_no_overwrite")
builder = saved_model_builder.SavedModelBuilder(export_dir)