aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/saved_model
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-03-20 21:07:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-20 22:28:38 -0700
commit9fe3c2fc2bc820b2426fa81d7baac5765a09eacd (patch)
tree9c3ee9ed604cf664f0e8698db08868b9ea7ff227 /tensorflow/contrib/saved_model
parent0b939e44258e1211bd0f95318641a6f853f09ce9 (diff)
Add a Reader API for SavedModel in contrib.
Change: 150717037
Diffstat (limited to 'tensorflow/contrib/saved_model')
-rw-r--r--tensorflow/contrib/saved_model/BUILD17
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/reader.py92
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/reader_test.py98
3 files changed, 207 insertions, 0 deletions
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index 42514c8379..6ab9631d29 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -41,6 +41,23 @@ py_library(
)
py_test(
+ name = "reader_test",
+ size = "small",
+ srcs = ["python/saved_model/reader_test.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":saved_model_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:tag_constants",
+ ],
+)
+
+py_test(
name = "signature_def_utils_test",
size = "small",
srcs = ["python/saved_model/signature_def_utils_test.py"],
diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader.py b/tensorflow/contrib/saved_model/python/saved_model/reader.py
new file mode 100644
index 0000000000..b9e5319181
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/reader.py
@@ -0,0 +1,92 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel functionality to read a SavedModel from disk."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from google.protobuf import message
+from google.protobuf import text_format
+from tensorflow.core.protobuf import saved_model_pb2
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.saved_model import constants
+from tensorflow.python.util import compat
+
+
+def read_saved_model(saved_model_dir):
+ """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.
+
+ Args:
+ saved_model_dir: Directory containing the SavedModel file.
+
+ Returns:
+ A `SavedModel` protocol buffer.
+
+ Raises:
+ IOError: If the file does not exist, or cannot be successfully parsed.
+ """
+ # Build the path to the SavedModel in pbtxt format.
+ path_to_pbtxt = os.path.join(
+ compat.as_bytes(saved_model_dir),
+ compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
+ # Build the path to the SavedModel in pb format.
+ path_to_pb = os.path.join(
+ compat.as_bytes(saved_model_dir),
+ compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
+
+ # Ensure that the SavedModel exists at either path.
+ if not file_io.file_exists(path_to_pbtxt) and not file_io.file_exists(
+ path_to_pb):
+ raise IOError("SavedModel file does not exist at: %s" % saved_model_dir)
+
+ # Parse the SavedModel protocol buffer.
+ saved_model = saved_model_pb2.SavedModel()
+ if file_io.file_exists(path_to_pb):
+ try:
+ file_content = file_io.FileIO(path_to_pb, "rb").read()
+ saved_model.ParseFromString(file_content)
+ return saved_model
+ except message.DecodeError as e:
+ raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
+ elif file_io.file_exists(path_to_pbtxt):
+ try:
+ file_content = file_io.FileIO(path_to_pbtxt, "rb").read()
+ text_format.Merge(file_content.decode("utf-8"), saved_model)
+ return saved_model
+ except text_format.ParseError as e:
+ raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
+ else:
+ raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
+ (saved_model_dir, constants.SAVED_MODEL_FILENAME_PBTXT,
+ constants.SAVED_MODEL_FILENAME_PB))
+
+
+def get_saved_model_tag_sets(saved_model_dir):
+ """Retrieves all the tag-sets available in the SavedModel.
+
+ Args:
+ saved_model_dir: Directory containing the SavedModel.
+
+ Returns:
+ String representation of all tag-sets in the SavedModel.
+ """
+ saved_model = read_saved_model(saved_model_dir)
+ all_tags = []
+ for meta_graph_def in saved_model.meta_graphs:
+ all_tags.append(list(meta_graph_def.meta_info_def.tags))
+ return all_tags
diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
new file mode 100644
index 0000000000..76d5a3e96d
--- /dev/null
+++ b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py
@@ -0,0 +1,98 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SavedModel Reader."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.saved_model.python.saved_model import reader
+from tensorflow.python.framework import ops
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder as saved_model_builder
+from tensorflow.python.saved_model import tag_constants
+
+
+def tearDownModule():
+ file_io.delete_recursively(test.get_temp_dir())
+
+
+class ReaderTest(test.TestCase):
+
+ def _init_and_validate_variable(self, sess, variable_name, variable_value):
+ v = variables.Variable(variable_value, name=variable_name)
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(variable_value, v.eval())
+
+ def testReadSavedModelValid(self):
+ saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model")
+ builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+ builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
+ builder.save()
+
+ actual_saved_model_pb = reader.read_saved_model(saved_model_dir)
+ self.assertEqual(len(actual_saved_model_pb.meta_graphs), 1)
+ self.assertEqual(
+ len(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags), 1)
+ self.assertEqual(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags[0],
+ tag_constants.TRAINING)
+
+ def testReadSavedModelInvalid(self):
+ saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model")
+ with self.assertRaisesRegexp(
+ IOError, "SavedModel file does not exist at: %s" % saved_model_dir):
+ reader.read_saved_model(saved_model_dir)
+
+ def testGetSavedModelTagSets(self):
+ saved_model_dir = os.path.join(test.get_temp_dir(), "test_tags")
+ builder = saved_model_builder.SavedModelBuilder(saved_model_dir)
+
+ # Graph with a single variable. SavedModel invoked to:
+ # - add with weights.
+ # - a single tag (from predefined constants).
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+ builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
+
+ # Graph that updates the single variable. SavedModel invoked to:
+ # - simply add the model (weights are not updated).
+ # - a single tag (from predefined constants).
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 43)
+ builder.add_meta_graph([tag_constants.SERVING])
+
+ # Graph that updates the single variable. SavedModel is invoked:
+ # - to add the model (weights are not updated).
+ # - multiple custom tags.
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 44)
+ builder.add_meta_graph(["foo", "bar"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ actual_tags = reader.get_saved_model_tag_sets(saved_model_dir)
+ expected_tags = [["train"], ["serve"], ["foo", "bar"]]
+ self.assertEqual(expected_tags, actual_tags)
+
+
+if __name__ == "__main__":
+ test.main()