From 5cff7dc88dd8275c029e70b23c0e55b0775f9974 Mon Sep 17 00:00:00 2001 From: Sukriti Ramesh Date: Tue, 31 Jan 2017 12:18:43 -0800 Subject: SavedModel updates for TensorInfo validation and exposing existing APIs as tf.saved_model. Change: 146150019 --- tensorflow/python/saved_model/BUILD | 15 ++++-- tensorflow/python/saved_model/builder_impl.py | 54 +++++++++++++++++++++ tensorflow/python/saved_model/constants.py | 30 ++++++++++++ tensorflow/python/saved_model/main_op.py | 48 +++++-------------- tensorflow/python/saved_model/main_op_impl.py | 57 +++++++++++++++++++++++ tensorflow/python/saved_model/saved_model.py | 6 +++ tensorflow/python/saved_model/saved_model_test.py | 45 ++++++++++++++++++ tensorflow/python/saved_model/utils.py | 26 +++-------- tensorflow/python/saved_model/utils_impl.py | 41 ++++++++++++++++ 9 files changed, 262 insertions(+), 60 deletions(-) create mode 100644 tensorflow/python/saved_model/main_op_impl.py create mode 100644 tensorflow/python/saved_model/utils_impl.py diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 47980e6ff3..79399c11c4 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -20,10 +20,13 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":builder", + ":constants", ":loader", + ":main_op", ":signature_constants", ":signature_def_utils", ":tag_constants", + ":utils", ], ) @@ -83,8 +86,10 @@ py_library( py_library( name = "main_op", - testonly = 1, - srcs = ["main_op.py"], + srcs = [ + "main_op.py", + "main_op_impl.py", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:control_flow_ops", @@ -108,6 +113,7 @@ py_test( ":main_op", ":signature_def_utils", ":tag_constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -124,7 +130,10 @@ py_test( py_library( name = "utils", - srcs = ["utils.py"], + srcs = [ + "utils.py", + "utils_impl.py", + ], srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 6b9e3c4693..f80bba8562 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -22,6 +22,7 @@ import os from google.protobuf.any_pb2 import Any +from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.core.protobuf import saver_pb2 @@ -258,6 +259,51 @@ class SavedModelBuilder(object): proto_meta_graph_def = self._saved_model.meta_graphs.add() proto_meta_graph_def.CopyFrom(meta_graph_def) + def _validate_tensor_info(self, tensor_info): + """Validates the `TensorInfo` proto. + + Checks if the `name` and `dtype` fields exist and are non-empty. + + Args: + tensor_info: `TensorInfo` protocol buffer to validate. + + Raises: + AssertionError: If the `name` or `dtype` fields of the supplied + `TensorInfo` proto are not populated. + """ + if tensor_info is None: + raise AssertionError( + "All TensorInfo protos used in the SignatureDefs must have the name " + "and dtype fields set.") + if not tensor_info.name: + raise AssertionError( + "All TensorInfo protos used in the SignatureDefs must have the name " + "field set: %s" % tensor_info) + if tensor_info.dtype is types_pb2.DT_INVALID: + raise AssertionError( + "All TensorInfo protos used in the SignatureDefs must have the dtype " + "field set: %s" % tensor_info) + + def _validate_signature_def_map(self, signature_def_map): + """Validates the `SignatureDef` entries in the signature def map. + + Validation of entries in the signature def map includes ensuring that the + `name` and `dtype` fields of the TensorInfo protos of the `inputs` and + `outputs` of each `SignatureDef` are populated. + + Args: + signature_def_map: The map of signature defs to be validated. + """ + if signature_def_map is not None: + for signature_def_key in signature_def_map: + signature_def = signature_def_map[signature_def_key] + inputs = signature_def.inputs + outputs = signature_def.outputs + for inputs_key in inputs: + self._validate_tensor_info(inputs[inputs_key]) + for outputs_key in outputs: + self._validate_tensor_info(outputs[outputs_key]) + def add_meta_graph(self, tags, signature_def_map=None, @@ -293,6 +339,10 @@ class SavedModelBuilder(object): "Variables and assets have not been saved yet. " "Please invoke `add_meta_graph_and_variables()` first.") + # Validate the signature def map to ensure all included TensorInfos are + # properly populated. + self._validate_signature_def_map(signature_def_map) + # Save asset files and write them to disk, if any. self._save_and_write_assets(assets_collection) @@ -347,6 +397,10 @@ class SavedModelBuilder(object): raise AssertionError("Variables and assets have already been saved. " "Please invoke `add_meta_graph()` instead.") + # Validate the signature def map to ensure all included TensorInfos are + # properly populated. + self._validate_signature_def_map(signature_def_map) + # Save asset files and write them to disk, if any. self._save_and_write_assets(assets_collection) diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py index b94109e4ac..7e3e8df47f 100644 --- a/tensorflow/python/saved_model/constants.py +++ b/tensorflow/python/saved_model/constants.py @@ -19,15 +19,45 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.util.all_util import remove_undocumented + +# Subdirectory name containing the asset files. ASSETS_DIRECTORY = "assets" + +# CollectionDef key containing SavedModel assets. ASSETS_KEY = "saved_model_assets" +# CollectionDef key for the legacy init op. LEGACY_INIT_OP_KEY = "legacy_init_op" + +# CollectionDef key for the SavedModel main op. MAIN_OP_KEY = "saved_model_main_op" +# Schema version for SavedModel. SAVED_MODEL_SCHEMA_VERSION = 1 + +# File name for SavedModel protocol buffer. SAVED_MODEL_FILENAME_PB = "saved_model.pb" + +# File name for text version of SavedModel protocol buffer. SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt" +# Subdirectory name containing the variables/checkpoint files. VARIABLES_DIRECTORY = "variables" + +# File name used for variables. VARIABLES_FILENAME = "variables" + + +_allowed_symbols = [ + "ASSETS_DIRECTORY", + "ASSETS_KEY", + "LEGACY_INIT_OP_KEY", + "MAIN_OP_KEY", + "SAVED_MODEL_SCHEMA_VERSION", + "SAVED_MODEL_FILENAME_PB", + "SAVED_MODEL_FILENAME_PBTXT", + "VARIABLES_DIRECTORY", + "VARIABLES_FILENAME", +] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/saved_model/main_op.py b/tensorflow/python/saved_model/main_op.py index 3f25dc137e..04cadeab66 100644 --- a/tensorflow/python/saved_model/main_op.py +++ b/tensorflow/python/saved_model/main_op.py @@ -22,40 +22,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops -from tensorflow.python.ops import variables - - -def main_op(): - """Returns a main op to init variables and tables. - - Returns the main op including the group of ops that initializes all - variables, initializes local variables and initialize all tables. - - Returns: - The set of ops to be run as part of the main op upon the load operation. - """ - init = variables.global_variables_initializer() - init_local = variables.local_variables_initializer() - init_tables = tf_data_flow_ops.tables_initializer() - return control_flow_ops.group(init, init_local, init_tables) - - -def main_op_with_restore(restore_op_name): - """Returns a main op to init variables, tables and restore the graph. - - Returns the main op including the group of ops that initializes all - variables, initialize local variables, initialize all tables and the restore - op name. - - Args: - restore_op_name: Name of the op to use to restore the graph. - - Returns: - The set of ops to be run as part of the main op upon the load operation. - """ - with ops.control_dependencies([main_op()]): - main_op_with_restore = control_flow_ops.group(restore_op_name) - return main_op_with_restore +# pylint: disable=unused-import +from tensorflow.python.saved_model.main_op_impl import main_op +from tensorflow.python.saved_model.main_op_impl import main_op_with_restore +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "main_op", + "main_op_with_restore", +] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py new file mode 100644 index 0000000000..51462310a6 --- /dev/null +++ b/tensorflow/python/saved_model/main_op_impl.py @@ -0,0 +1,57 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel main op implementation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops +from tensorflow.python.ops import variables + + +def main_op(): + """Returns a main op to init variables and tables. + + Returns the main op including the group of ops that initializes all + variables, initializes local variables and initialize all tables. + + Returns: + The set of ops to be run as part of the main op upon the load operation. + """ + init = variables.global_variables_initializer() + init_local = variables.local_variables_initializer() + init_tables = tf_data_flow_ops.tables_initializer() + return control_flow_ops.group(init, init_local, init_tables) + + +def main_op_with_restore(restore_op_name): + """Returns a main op to init variables, tables and restore the graph. + + Returns the main op including the group of ops that initializes all + variables, initialize local variables, initialize all tables and the restore + op name. + + Args: + restore_op_name: Name of the op to use to restore the graph. + + Returns: + The set of ops to be run as part of the main op upon the load operation. + """ + with ops.control_dependencies([main_op()]): + main_op_with_restore = control_flow_ops.group(restore_op_name) + return main_op_with_restore diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py index 1de2617eef..8c59f7afe7 100644 --- a/tensorflow/python/saved_model/saved_model.py +++ b/tensorflow/python/saved_model/saved_model.py @@ -22,10 +22,13 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.python.saved_model import builder +from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants +from tensorflow.python.saved_model import utils # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented @@ -33,9 +36,12 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "builder", + "constants", "loader", + "main_op", "signature_constants", "signature_def_utils", "tag_constants", + "utils", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 286a18fb10..4ce19322af 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os +from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session @@ -82,6 +83,31 @@ class SavedModelTest(test.TestCase): self.assertEqual(expected_asset_file_name, asset.filename) self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name) + def _validate_inputs_tensor_info(self, builder, tensor_info): + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + foo_signature = signature_def_utils.build_signature_def({ + "foo_inputs": tensor_info + }, dict(), "foo") + self.assertRaises( + AssertionError, + builder.add_meta_graph_and_variables, + sess, ["foo"], + signature_def_map={"foo_key": foo_signature}) + + def _validate_outputs_tensor_info(self, builder, tensor_info): + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + foo_signature = signature_def_utils.build_signature_def( + dict(), {"foo_outputs": tensor_info}, "foo") + self.assertRaises( + AssertionError, + builder.add_meta_graph_and_variables, + sess, ["foo"], + signature_def_map={"foo_key": foo_signature}) + def testMaybeSavedModelDir(self): base_path = test.test_src_dir_path("/python/saved_model") self.assertFalse(loader.maybe_saved_model_directory(base_path)) @@ -385,6 +411,25 @@ class SavedModelTest(test.TestCase): self.assertEqual("bar", bar_signature["bar_key"].method_name) self.assertEqual("foo_new", bar_signature["foo_key"].method_name) + def testSignatureDefValidation(self): + export_dir = os.path.join(test.get_temp_dir(), + "test_signature_def_validation") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + tensor_without_name = meta_graph_pb2.TensorInfo() + tensor_without_name.dtype = types_pb2.DT_FLOAT + self._validate_inputs_tensor_info(builder, tensor_without_name) + self._validate_outputs_tensor_info(builder, tensor_without_name) + + tensor_without_dtype = meta_graph_pb2.TensorInfo() + tensor_without_dtype.name = "x" + self._validate_inputs_tensor_info(builder, tensor_without_dtype) + self._validate_outputs_tensor_info(builder, tensor_without_dtype) + + tensor_empty = meta_graph_pb2.TensorInfo() + self._validate_inputs_tensor_info(builder, tensor_empty) + self._validate_outputs_tensor_info(builder, tensor_empty) + def testAssets(self): export_dir = os.path.join(test.get_temp_dir(), "test_assets") builder = saved_model_builder.SavedModelBuilder(export_dir) diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py index ecc58fbc7a..8e970e96ae 100644 --- a/tensorflow/python/saved_model/utils.py +++ b/tensorflow/python/saved_model/utils.py @@ -20,24 +20,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.framework import dtypes +# pylint: disable=unused-import +from tensorflow.python.saved_model.utils_impl import build_tensor_info +# pylint: enable=unused-import +from tensorflow.python.util.all_util import remove_undocumented - -# TensorInfo helpers. - - -def build_tensor_info(tensor): - """Utility function to build TensorInfo proto. - - Args: - tensor: Tensor whose name, dtype and shape are used to build the TensorInfo. - - Returns: - A TensorInfo protocol buffer constructed based on the supplied argument. - """ - dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum - return meta_graph_pb2.TensorInfo( - name=tensor.name, - dtype=dtype_enum, - tensor_shape=tensor.get_shape().as_proto()) +_allowed_symbols = ["build_tensor_info",] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py new file mode 100644 index 0000000000..fcb6fc91b6 --- /dev/null +++ b/tensorflow/python/saved_model/utils_impl.py @@ -0,0 +1,41 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SavedModel utility functions implementation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.framework import dtypes + + +# TensorInfo helpers. + + +def build_tensor_info(tensor): + """Utility function to build TensorInfo proto. + + Args: + tensor: Tensor whose name, dtype and shape are used to build the TensorInfo. + + Returns: + A TensorInfo protocol buffer constructed based on the supplied argument. + """ + dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum + return meta_graph_pb2.TensorInfo( + name=tensor.name, + dtype=dtype_enum, + tensor_shape=tensor.get_shape().as_proto()) -- cgit v1.2.3