aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2018-03-05 10:11:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-05 10:16:25 -0800
commit602f54c065eb9513ef3bb8557887d106637f96e5 (patch)
treeb1af639a65e0330079607439ba7d30ad4c20ed28 /tensorflow/python/saved_model
parentf09e7f9ebad85b3395628381777cba3e71f768a5 (diff)
Make SavedModel builder validation accept signatures involving sparse tensors.
PiperOrigin-RevId: 187883080
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/builder_impl.py11
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py72
2 files changed, 67 insertions, 16 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 7347da7536..3447d917e9 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -193,7 +193,8 @@ class SavedModelBuilder(object):
def _validate_tensor_info(self, tensor_info):
"""Validates the `TensorInfo` proto.
- Checks if the `name` and `dtype` fields exist and are non-empty.
+ Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist
+ and are non-empty.
Args:
tensor_info: `TensorInfo` protocol buffer to validate.
@@ -206,10 +207,12 @@ class SavedModelBuilder(object):
raise AssertionError(
"All TensorInfo protos used in the SignatureDefs must have the name "
"and dtype fields set.")
- if not tensor_info.name:
+ if tensor_info.WhichOneof("encoding") is None:
+ # TODO(soergel) validate each of the fields of coo_sparse
raise AssertionError(
- "All TensorInfo protos used in the SignatureDefs must have the name "
- "field set: %s" % tensor_info)
+ "All TensorInfo protos used in the SignatureDefs must have one of "
+ "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
+ % tensor_info)
if tensor_info.dtype is types_pb2.DT_INVALID:
raise AssertionError(
"All TensorInfo protos used in the SignatureDefs must have the dtype "
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index d9d3168825..804255375e 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -94,7 +94,7 @@ class SavedModelTest(test.TestCase):
self.assertEqual(expected_asset_file_name, asset.filename)
self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
- def _validate_inputs_tensor_info(self, builder, tensor_info):
+ def _validate_inputs_tensor_info_fail(self, builder, tensor_info):
with self.test_session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
@@ -107,7 +107,18 @@ class SavedModelTest(test.TestCase):
sess, ["foo"],
signature_def_map={"foo_key": foo_signature})
- def _validate_outputs_tensor_info(self, builder, tensor_info):
+ def _validate_inputs_tensor_info_accept(self, builder, tensor_info):
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ foo_signature = signature_def_utils.build_signature_def({
+ "foo_inputs": tensor_info
+ }, dict(), "foo")
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"],
+ signature_def_map={"foo_key": foo_signature})
+
+ def _validate_outputs_tensor_info_fail(self, builder, tensor_info):
with self.test_session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
@@ -119,6 +130,16 @@ class SavedModelTest(test.TestCase):
sess, ["foo"],
signature_def_map={"foo_key": foo_signature})
+ def _validate_outputs_tensor_info_accept(self, builder, tensor_info):
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ foo_signature = signature_def_utils.build_signature_def(
+ dict(), {"foo_outputs": tensor_info}, "foo")
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"],
+ signature_def_map={"foo_key": foo_signature})
+
def testMaybeSavedModelDir(self):
base_path = test.test_src_dir_path("/python/saved_model")
self.assertFalse(loader.maybe_saved_model_directory(base_path))
@@ -538,23 +559,50 @@ class SavedModelTest(test.TestCase):
self.assertEqual("bar", bar_signature["bar_key"].method_name)
self.assertEqual("foo_new", bar_signature["foo_key"].method_name)
- def testSignatureDefValidation(self):
- export_dir = self._get_export_dir("test_signature_def_validation")
+ def testSignatureDefValidationFails(self):
+ export_dir = self._get_export_dir("test_signature_def_validation_fail")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- tensor_without_name = meta_graph_pb2.TensorInfo()
- tensor_without_name.dtype = types_pb2.DT_FLOAT
- self._validate_inputs_tensor_info(builder, tensor_without_name)
- self._validate_outputs_tensor_info(builder, tensor_without_name)
+ tensor_without_encoding = meta_graph_pb2.TensorInfo()
+ tensor_without_encoding.dtype = types_pb2.DT_FLOAT
+ self._validate_inputs_tensor_info_fail(builder, tensor_without_encoding)
+ self._validate_outputs_tensor_info_fail(builder, tensor_without_encoding)
tensor_without_dtype = meta_graph_pb2.TensorInfo()
tensor_without_dtype.name = "x"
- self._validate_inputs_tensor_info(builder, tensor_without_dtype)
- self._validate_outputs_tensor_info(builder, tensor_without_dtype)
+ self._validate_inputs_tensor_info_fail(builder, tensor_without_dtype)
+ self._validate_outputs_tensor_info_fail(builder, tensor_without_dtype)
tensor_empty = meta_graph_pb2.TensorInfo()
- self._validate_inputs_tensor_info(builder, tensor_empty)
- self._validate_outputs_tensor_info(builder, tensor_empty)
+ self._validate_inputs_tensor_info_fail(builder, tensor_empty)
+ self._validate_outputs_tensor_info_fail(builder, tensor_empty)
+
+ def testSignatureDefValidationSucceedsWithName(self):
+ tensor_with_name = meta_graph_pb2.TensorInfo()
+ tensor_with_name.name = "foo"
+ tensor_with_name.dtype = types_pb2.DT_FLOAT
+
+ export_dir = self._get_export_dir("test_signature_def_validation_name_1")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+ self._validate_inputs_tensor_info_accept(builder, tensor_with_name)
+
+ export_dir = self._get_export_dir("test_signature_def_validation_name_2")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+ self._validate_outputs_tensor_info_accept(builder, tensor_with_name)
+
+ def testSignatureDefValidationSucceedsWithCoo(self):
+ tensor_with_coo = meta_graph_pb2.TensorInfo()
+ # TODO(soergel) test validation of each of the fields of coo_sparse
+ tensor_with_coo.coo_sparse.values_tensor_name = "foo"
+ tensor_with_coo.dtype = types_pb2.DT_FLOAT
+
+ export_dir = self._get_export_dir("test_signature_def_validation_coo_1")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+ self._validate_inputs_tensor_info_accept(builder, tensor_with_coo)
+
+ export_dir = self._get_export_dir("test_signature_def_validation_coo_2")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+ self._validate_outputs_tensor_info_accept(builder, tensor_with_coo)
def testAssets(self):
export_dir = self._get_export_dir("test_assets")