aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2017-01-31 12:18:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-31 12:28:51 -0800
commit5cff7dc88dd8275c029e70b23c0e55b0775f9974 (patch)
treeddd11b41dc83ecccabaff6e112f86c14bcf48639
parentd1d4c02c3d7466b3ac54dd329f26b6fe4ce82469 (diff)
SavedModel updates for TensorInfo validation and exposing existing APIs as tf.saved_model.
Change: 146150019
-rw-r--r--tensorflow/python/saved_model/BUILD15
-rw-r--r--tensorflow/python/saved_model/builder_impl.py54
-rw-r--r--tensorflow/python/saved_model/constants.py30
-rw-r--r--tensorflow/python/saved_model/main_op.py48
-rw-r--r--tensorflow/python/saved_model/main_op_impl.py57
-rw-r--r--tensorflow/python/saved_model/saved_model.py6
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py45
-rw-r--r--tensorflow/python/saved_model/utils.py26
-rw-r--r--tensorflow/python/saved_model/utils_impl.py41
9 files changed, 262 insertions, 60 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 47980e6ff3..79399c11c4 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -20,10 +20,13 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":builder",
+ ":constants",
":loader",
+ ":main_op",
":signature_constants",
":signature_def_utils",
":tag_constants",
+ ":utils",
],
)
@@ -83,8 +86,10 @@ py_library(
py_library(
name = "main_op",
- testonly = 1,
- srcs = ["main_op.py"],
+ srcs = [
+ "main_op.py",
+ "main_op_impl.py",
+ ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:control_flow_ops",
@@ -108,6 +113,7 @@ py_test(
":main_op",
":signature_def_utils",
":tag_constants",
+ ":utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
@@ -124,7 +130,10 @@ py_test(
py_library(
name = "utils",
- srcs = ["utils.py"],
+ srcs = [
+ "utils.py",
+ "utils_impl.py",
+ ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 6b9e3c4693..f80bba8562 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -22,6 +22,7 @@ import os
from google.protobuf.any_pb2 import Any
+from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saver_pb2
@@ -258,6 +259,51 @@ class SavedModelBuilder(object):
proto_meta_graph_def = self._saved_model.meta_graphs.add()
proto_meta_graph_def.CopyFrom(meta_graph_def)
+ def _validate_tensor_info(self, tensor_info):
+ """Validates the `TensorInfo` proto.
+
+ Checks if the `name` and `dtype` fields exist and are non-empty.
+
+ Args:
+ tensor_info: `TensorInfo` protocol buffer to validate.
+
+ Raises:
+ AssertionError: If the `name` or `dtype` fields of the supplied
+ `TensorInfo` proto are not populated.
+ """
+ if tensor_info is None:
+ raise AssertionError(
+ "All TensorInfo protos used in the SignatureDefs must have the name "
+ "and dtype fields set.")
+ if not tensor_info.name:
+ raise AssertionError(
+ "All TensorInfo protos used in the SignatureDefs must have the name "
+ "field set: %s" % tensor_info)
+ if tensor_info.dtype is types_pb2.DT_INVALID:
+ raise AssertionError(
+ "All TensorInfo protos used in the SignatureDefs must have the dtype "
+ "field set: %s" % tensor_info)
+
+ def _validate_signature_def_map(self, signature_def_map):
+ """Validates the `SignatureDef` entries in the signature def map.
+
+ Validation of entries in the signature def map includes ensuring that the
+ `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
+ `outputs` of each `SignatureDef` are populated.
+
+ Args:
+ signature_def_map: The map of signature defs to be validated.
+ """
+ if signature_def_map is not None:
+ for signature_def_key in signature_def_map:
+ signature_def = signature_def_map[signature_def_key]
+ inputs = signature_def.inputs
+ outputs = signature_def.outputs
+ for inputs_key in inputs:
+ self._validate_tensor_info(inputs[inputs_key])
+ for outputs_key in outputs:
+ self._validate_tensor_info(outputs[outputs_key])
+
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -293,6 +339,10 @@ class SavedModelBuilder(object):
"Variables and assets have not been saved yet. "
"Please invoke `add_meta_graph_and_variables()` first.")
+ # Validate the signature def map to ensure all included TensorInfos are
+ # properly populated.
+ self._validate_signature_def_map(signature_def_map)
+
# Save asset files and write them to disk, if any.
self._save_and_write_assets(assets_collection)
@@ -347,6 +397,10 @@ class SavedModelBuilder(object):
raise AssertionError("Variables and assets have already been saved. "
"Please invoke `add_meta_graph()` instead.")
+ # Validate the signature def map to ensure all included TensorInfos are
+ # properly populated.
+ self._validate_signature_def_map(signature_def_map)
+
# Save asset files and write them to disk, if any.
self._save_and_write_assets(assets_collection)
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index b94109e4ac..7e3e8df47f 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -19,15 +19,45 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.util.all_util import remove_undocumented
+
+# Subdirectory name containing the asset files.
ASSETS_DIRECTORY = "assets"
+
+# CollectionDef key containing SavedModel assets.
ASSETS_KEY = "saved_model_assets"
+# CollectionDef key for the legacy init op.
LEGACY_INIT_OP_KEY = "legacy_init_op"
+
+# CollectionDef key for the SavedModel main op.
MAIN_OP_KEY = "saved_model_main_op"
+# Schema version for SavedModel.
SAVED_MODEL_SCHEMA_VERSION = 1
+
+# File name for SavedModel protocol buffer.
SAVED_MODEL_FILENAME_PB = "saved_model.pb"
+
+# File name for text version of SavedModel protocol buffer.
SAVED_MODEL_FILENAME_PBTXT = "saved_model.pbtxt"
+# Subdirectory name containing the variables/checkpoint files.
VARIABLES_DIRECTORY = "variables"
+
+# File name used for variables.
VARIABLES_FILENAME = "variables"
+
+
+_allowed_symbols = [
+ "ASSETS_DIRECTORY",
+ "ASSETS_KEY",
+ "LEGACY_INIT_OP_KEY",
+ "MAIN_OP_KEY",
+ "SAVED_MODEL_SCHEMA_VERSION",
+ "SAVED_MODEL_FILENAME_PB",
+ "SAVED_MODEL_FILENAME_PBTXT",
+ "VARIABLES_DIRECTORY",
+ "VARIABLES_FILENAME",
+]
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/main_op.py b/tensorflow/python/saved_model/main_op.py
index 3f25dc137e..04cadeab66 100644
--- a/tensorflow/python/saved_model/main_op.py
+++ b/tensorflow/python/saved_model/main_op.py
@@ -22,40 +22,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops
-from tensorflow.python.ops import variables
-
-
-def main_op():
- """Returns a main op to init variables and tables.
-
- Returns the main op including the group of ops that initializes all
- variables, initializes local variables and initialize all tables.
-
- Returns:
- The set of ops to be run as part of the main op upon the load operation.
- """
- init = variables.global_variables_initializer()
- init_local = variables.local_variables_initializer()
- init_tables = tf_data_flow_ops.tables_initializer()
- return control_flow_ops.group(init, init_local, init_tables)
-
-
-def main_op_with_restore(restore_op_name):
- """Returns a main op to init variables, tables and restore the graph.
-
- Returns the main op including the group of ops that initializes all
- variables, initialize local variables, initialize all tables and the restore
- op name.
-
- Args:
- restore_op_name: Name of the op to use to restore the graph.
-
- Returns:
- The set of ops to be run as part of the main op upon the load operation.
- """
- with ops.control_dependencies([main_op()]):
- main_op_with_restore = control_flow_ops.group(restore_op_name)
- return main_op_with_restore
+# pylint: disable=unused-import
+from tensorflow.python.saved_model.main_op_impl import main_op
+from tensorflow.python.saved_model.main_op_impl import main_op_with_restore
+# pylint: enable=unused-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "main_op",
+ "main_op_with_restore",
+]
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py
new file mode 100644
index 0000000000..51462310a6
--- /dev/null
+++ b/tensorflow/python/saved_model/main_op_impl.py
@@ -0,0 +1,57 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel main op implementation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops as tf_data_flow_ops
+from tensorflow.python.ops import variables
+
+
+def main_op():
+ """Returns a main op to init variables and tables.
+
+ Returns the main op including the group of ops that initializes all
+ variables, initializes local variables and initialize all tables.
+
+ Returns:
+ The set of ops to be run as part of the main op upon the load operation.
+ """
+ init = variables.global_variables_initializer()
+ init_local = variables.local_variables_initializer()
+ init_tables = tf_data_flow_ops.tables_initializer()
+ return control_flow_ops.group(init, init_local, init_tables)
+
+
+def main_op_with_restore(restore_op_name):
+ """Returns a main op to init variables, tables and restore the graph.
+
+ Returns the main op including the group of ops that initializes all
+ variables, initialize local variables, initialize all tables and the restore
+ op name.
+
+ Args:
+ restore_op_name: Name of the op to use to restore the graph.
+
+ Returns:
+ The set of ops to be run as part of the main op upon the load operation.
+ """
+ with ops.control_dependencies([main_op()]):
+ main_op_with_restore = control_flow_ops.group(restore_op_name)
+ return main_op_with_restore
diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py
index 1de2617eef..8c59f7afe7 100644
--- a/tensorflow/python/saved_model/saved_model.py
+++ b/tensorflow/python/saved_model/saved_model.py
@@ -22,10 +22,13 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
@@ -33,9 +36,12 @@ from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
"builder",
+ "constants",
"loader",
+ "main_op",
"signature_constants",
"signature_def_utils",
"tag_constants",
+ "utils",
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 286a18fb10..4ce19322af 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
+from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
@@ -82,6 +83,31 @@ class SavedModelTest(test.TestCase):
self.assertEqual(expected_asset_file_name, asset.filename)
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
+ def _validate_inputs_tensor_info(self, builder, tensor_info):
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ foo_signature = signature_def_utils.build_signature_def({
+ "foo_inputs": tensor_info
+ }, dict(), "foo")
+ self.assertRaises(
+ AssertionError,
+ builder.add_meta_graph_and_variables,
+ sess, ["foo"],
+ signature_def_map={"foo_key": foo_signature})
+
+ def _validate_outputs_tensor_info(self, builder, tensor_info):
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ foo_signature = signature_def_utils.build_signature_def(
+ dict(), {"foo_outputs": tensor_info}, "foo")
+ self.assertRaises(
+ AssertionError,
+ builder.add_meta_graph_and_variables,
+ sess, ["foo"],
+ signature_def_map={"foo_key": foo_signature})
+
def testMaybeSavedModelDir(self):
base_path = test.test_src_dir_path("/python/saved_model")
self.assertFalse(loader.maybe_saved_model_directory(base_path))
@@ -385,6 +411,25 @@ class SavedModelTest(test.TestCase):
self.assertEqual("bar", bar_signature["bar_key"].method_name)
self.assertEqual("foo_new", bar_signature["foo_key"].method_name)
+ def testSignatureDefValidation(self):
+ export_dir = os.path.join(test.get_temp_dir(),
+ "test_signature_def_validation")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ tensor_without_name = meta_graph_pb2.TensorInfo()
+ tensor_without_name.dtype = types_pb2.DT_FLOAT
+ self._validate_inputs_tensor_info(builder, tensor_without_name)
+ self._validate_outputs_tensor_info(builder, tensor_without_name)
+
+ tensor_without_dtype = meta_graph_pb2.TensorInfo()
+ tensor_without_dtype.name = "x"
+ self._validate_inputs_tensor_info(builder, tensor_without_dtype)
+ self._validate_outputs_tensor_info(builder, tensor_without_dtype)
+
+ tensor_empty = meta_graph_pb2.TensorInfo()
+ self._validate_inputs_tensor_info(builder, tensor_empty)
+ self._validate_outputs_tensor_info(builder, tensor_empty)
+
def testAssets(self):
export_dir = os.path.join(test.get_temp_dir(), "test_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py
index ecc58fbc7a..8e970e96ae 100644
--- a/tensorflow/python/saved_model/utils.py
+++ b/tensorflow/python/saved_model/utils.py
@@ -20,24 +20,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.python.framework import dtypes
+# pylint: disable=unused-import
+from tensorflow.python.saved_model.utils_impl import build_tensor_info
+# pylint: enable=unused-import
+from tensorflow.python.util.all_util import remove_undocumented
-
-# TensorInfo helpers.
-
-
-def build_tensor_info(tensor):
- """Utility function to build TensorInfo proto.
-
- Args:
- tensor: Tensor whose name, dtype and shape are used to build the TensorInfo.
-
- Returns:
- A TensorInfo protocol buffer constructed based on the supplied argument.
- """
- dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
- return meta_graph_pb2.TensorInfo(
- name=tensor.name,
- dtype=dtype_enum,
- tensor_shape=tensor.get_shape().as_proto())
+_allowed_symbols = ["build_tensor_info",]
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
new file mode 100644
index 0000000000..fcb6fc91b6
--- /dev/null
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -0,0 +1,41 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SavedModel utility functions implementation."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.framework import dtypes
+
+
+# TensorInfo helpers.
+
+
+def build_tensor_info(tensor):
+ """Utility function to build TensorInfo proto.
+
+ Args:
+ tensor: Tensor whose name, dtype and shape are used to build the TensorInfo.
+
+ Returns:
+ A TensorInfo protocol buffer constructed based on the supplied argument.
+ """
+ dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
+ return meta_graph_pb2.TensorInfo(
+ name=tensor.name,
+ dtype=dtype_enum,
+ tensor_shape=tensor.get_shape().as_proto())