diff options
author | Sukriti Ramesh <sukritiramesh@google.com> | 2017-03-20 21:07:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-20 22:28:38 -0700 |
commit | 9fe3c2fc2bc820b2426fa81d7baac5765a09eacd (patch) | |
tree | 9c3ee9ed604cf664f0e8698db08868b9ea7ff227 /tensorflow/contrib/saved_model | |
parent | 0b939e44258e1211bd0f95318641a6f853f09ce9 (diff) |
Add a Reader API for SavedModel in contrib.
Change: 150717037
Diffstat (limited to 'tensorflow/contrib/saved_model')
-rw-r--r-- | tensorflow/contrib/saved_model/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/contrib/saved_model/python/saved_model/reader.py | 92 | ||||
-rw-r--r-- | tensorflow/contrib/saved_model/python/saved_model/reader_test.py | 98 |
3 files changed, 207 insertions, 0 deletions
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index 42514c8379..6ab9631d29 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -41,6 +41,23 @@ py_library( ) py_test( + name = "reader_test", + size = "small", + srcs = ["python/saved_model/reader_test.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":saved_model_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lib", + "//tensorflow/python:variables", + "//tensorflow/python/saved_model:builder", + "//tensorflow/python/saved_model:tag_constants", + ], +) + +py_test( name = "signature_def_utils_test", size = "small", srcs = ["python/saved_model/signature_def_utils_test.py"], diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader.py b/tensorflow/contrib/saved_model/python/saved_model/reader.py new file mode 100644 index 0000000000..b9e5319181 --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/reader.py @@ -0,0 +1,92 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel functionality to read a SavedModel from disk.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from google.protobuf import message +from google.protobuf import text_format +from tensorflow.core.protobuf import saved_model_pb2 +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat + + +def read_saved_model(saved_model_dir): + """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`. + + Args: + saved_model_dir: Directory containing the SavedModel file. + + Returns: + A `SavedModel` protocol buffer. + + Raises: + IOError: If the file does not exist, or cannot be successfully parsed. + """ + # Build the path to the SavedModel in pbtxt format. + path_to_pbtxt = os.path.join( + compat.as_bytes(saved_model_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) + # Build the path to the SavedModel in pb format. + path_to_pb = os.path.join( + compat.as_bytes(saved_model_dir), + compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) + + # Ensure that the SavedModel exists at either path. + if not file_io.file_exists(path_to_pbtxt) and not file_io.file_exists( + path_to_pb): + raise IOError("SavedModel file does not exist at: %s" % saved_model_dir) + + # Parse the SavedModel protocol buffer. + saved_model = saved_model_pb2.SavedModel() + if file_io.file_exists(path_to_pb): + try: + file_content = file_io.FileIO(path_to_pb, "rb").read() + saved_model.ParseFromString(file_content) + return saved_model + except message.DecodeError as e: + raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) + elif file_io.file_exists(path_to_pbtxt): + try: + file_content = file_io.FileIO(path_to_pbtxt, "rb").read() + text_format.Merge(file_content.decode("utf-8"), saved_model) + return saved_model + except text_format.ParseError as e: + raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e))) + else: + raise IOError("SavedModel file does not exist at: %s/{%s|%s}" % + (saved_model_dir, constants.SAVED_MODEL_FILENAME_PBTXT, + constants.SAVED_MODEL_FILENAME_PB)) + + +def get_saved_model_tag_sets(saved_model_dir): + """Retrieves all the tag-sets available in the SavedModel. + + Args: + saved_model_dir: Directory containing the SavedModel. + + Returns: + String representation of all tag-sets in the SavedModel. + """ + saved_model = read_saved_model(saved_model_dir) + all_tags = [] + for meta_graph_def in saved_model.meta_graphs: + all_tags.append(list(meta_graph_def.meta_info_def.tags)) + return all_tags diff --git a/tensorflow/contrib/saved_model/python/saved_model/reader_test.py b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py new file mode 100644 index 0000000000..76d5a3e96d --- /dev/null +++ b/tensorflow/contrib/saved_model/python/saved_model/reader_test.py @@ -0,0 +1,98 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SavedModel Reader.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.saved_model.python.saved_model import reader +from tensorflow.python.framework import ops +from tensorflow.python.lib.io import file_io +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import tag_constants + + +def tearDownModule(): + file_io.delete_recursively(test.get_temp_dir()) + + +class ReaderTest(test.TestCase): + + def _init_and_validate_variable(self, sess, variable_name, variable_value): + v = variables.Variable(variable_value, name=variable_name) + sess.run(variables.global_variables_initializer()) + self.assertEqual(variable_value, v.eval()) + + def testReadSavedModelValid(self): + saved_model_dir = os.path.join(test.get_temp_dir(), "valid_saved_model") + builder = saved_model_builder.SavedModelBuilder(saved_model_dir) + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) + builder.save() + + actual_saved_model_pb = reader.read_saved_model(saved_model_dir) + self.assertEqual(len(actual_saved_model_pb.meta_graphs), 1) + self.assertEqual( + len(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags), 1) + self.assertEqual(actual_saved_model_pb.meta_graphs[0].meta_info_def.tags[0], + tag_constants.TRAINING) + + def testReadSavedModelInvalid(self): + saved_model_dir = os.path.join(test.get_temp_dir(), "invalid_saved_model") + with self.assertRaisesRegexp( + IOError, "SavedModel file does not exist at: %s" % saved_model_dir): + reader.read_saved_model(saved_model_dir) + + def testGetSavedModelTagSets(self): + saved_model_dir = os.path.join(test.get_temp_dir(), "test_tags") + builder = saved_model_builder.SavedModelBuilder(saved_model_dir) + + # Graph with a single variable. SavedModel invoked to: + # - add with weights. + # - a single tag (from predefined constants). + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) + + # Graph that updates the single variable. SavedModel invoked to: + # - simply add the model (weights are not updated). + # - a single tag (from predefined constants). + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 43) + builder.add_meta_graph([tag_constants.SERVING]) + + # Graph that updates the single variable. SavedModel is invoked: + # - to add the model (weights are not updated). + # - multiple custom tags. + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 44) + builder.add_meta_graph(["foo", "bar"]) + + # Save the SavedModel to disk. + builder.save() + + actual_tags = reader.get_saved_model_tag_sets(saved_model_dir) + expected_tags = [["train"], ["serve"], ["foo", "bar"]] + self.assertEqual(expected_tags, actual_tags) + + +if __name__ == "__main__": + test.main() |