aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Karmel Allison <karmel@google.com>2018-05-04 16:01:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-05 08:30:01 -0700
commit008a3b69a601dc68fd940eb8a03b0c445714a339 (patch)
treedf7a92de37594adc3d8a3aef72baea1ea137fb1c /tensorflow/python/saved_model
parentab48fb528221152299fb08da8116d2eca54b8423 (diff)
Add the ability to export separate SavedModels for train and eval mode to Estimator with two new methods, available in tf.contrib: export_all_saved_models and export_saved_model_for_mode.
PiperOrigin-RevId: 195485922
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/builder_impl.py54
-rw-r--r--tensorflow/python/saved_model/constants.py6
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py90
-rw-r--r--tensorflow/python/saved_model/signature_constants.py6
-rw-r--r--tensorflow/python/saved_model/signature_def_utils.py2
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py56
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py95
-rw-r--r--tensorflow/python/saved_model/tag_constants.py5
8 files changed, 298 insertions, 16 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 3447d917e9..071033b066 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -168,6 +168,25 @@ class SavedModelBuilder(object):
raise TypeError("main_op needs to be an Operation: %r" % main_op)
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
+ def _add_train_op(self, train_op):
+ """Add train op to the SavedModel.
+
+ Note that this functionality is in development, and liable to be
+ moved elsewhere.
+
+ Args:
+ train_op: Op or group of ops that are used for training. These are
+ stored as a collection with key TRAIN_OP_KEY, but not executed.
+
+ Raises:
+ TypeError if Train op is not of type `Operation`.
+ """
+ if train_op is not None:
+ if (not isinstance(train_op, ops.Tensor) and
+ not isinstance(train_op, ops.Operation)):
+ raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
+ ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
+
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
"""Tags the meta graph def and adds it to the SavedModel.
@@ -238,6 +257,20 @@ class SavedModelBuilder(object):
for outputs_key in outputs:
self._validate_tensor_info(outputs[outputs_key])
+ def _add_collections(
+ self, assets_collection, legacy_init_op, main_op, train_op):
+ """Add asset and op collections to be saved."""
+ # Save asset files and write them to disk, if any.
+ self._save_and_write_assets(assets_collection)
+
+ if main_op is None:
+ # Add legacy init op to the SavedModel.
+ self._maybe_add_legacy_init_op(legacy_init_op)
+ else:
+ self._add_main_op(main_op)
+
+ self._add_train_op(train_op)
+
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -285,14 +318,8 @@ class SavedModelBuilder(object):
# properly populated.
self._validate_signature_def_map(signature_def_map)
- # Save asset files and write them to disk, if any.
- self._save_and_write_assets(assets_collection)
-
- if main_op is None:
- # Add legacy init op to the SavedModel.
- self._maybe_add_legacy_init_op(legacy_init_op)
- else:
- self._add_main_op(main_op)
+ # Add assets and ops
+ self._add_collections(assets_collection, legacy_init_op, main_op, None)
# Initialize a saver to generate a sharded output for all saveables in the
# current scope.
@@ -351,6 +378,7 @@ class SavedModelBuilder(object):
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+
"""
# pylint: enable=line-too-long
if self._has_saved_variables:
@@ -362,8 +390,8 @@ class SavedModelBuilder(object):
# properly populated.
self._validate_signature_def_map(signature_def_map)
- # Save asset files and write them to disk, if any.
- self._save_and_write_assets(assets_collection)
+ # Add assets and ops
+ self._add_collections(assets_collection, legacy_init_op, main_op, None)
# Create the variables sub-directory, if it does not exist.
variables_dir = os.path.join(
@@ -376,12 +404,6 @@ class SavedModelBuilder(object):
compat.as_text(variables_dir),
compat.as_text(constants.VARIABLES_FILENAME))
- if main_op is None:
- # Add legacy init op to the SavedModel.
- self._maybe_add_legacy_init_op(legacy_init_op)
- else:
- self._add_main_op(main_op)
-
# Initialize a saver to generate a sharded output for all saveables in the
# current scope.
saver = tf_saver.Saver(
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index 34206c6f6d..61c6ffbd0d 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -41,6 +41,10 @@ MAIN_OP_KEY = "saved_model_main_op"
tf_export("saved_model.constants.MAIN_OP_KEY").export_constant(
__name__, "MAIN_OP_KEY")
+# CollectionDef key for the SavedModel train op.
+# Not exported while export_all_saved_models is in contrib.
+TRAIN_OP_KEY = "saved_model_train_op"
+
# Schema version for SavedModel.
SAVED_MODEL_SCHEMA_VERSION = 1
tf_export("saved_model.constants.SAVED_MODEL_SCHEMA_VERSION").export_constant(
@@ -65,3 +69,5 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
VARIABLES_FILENAME = "variables"
tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant(
__name__, "VARIABLES_FILENAME")
+
+
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 804255375e..a4d994fd43 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -734,6 +734,96 @@ class SavedModelTest(test.TestCase):
builder.add_meta_graph_and_variables(
sess, ["foo"], legacy_init_op=legacy_init_op)
+ def testTrainOp(self):
+ export_dir = self._get_export_dir("test_train_op")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Add `v1` and `v2` variables to the graph.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+ v2 = variables.Variable(2, name="v2")
+ ops.add_to_collection("v", v2)
+
+ sess.run(variables.global_variables_initializer())
+ train_op = state_ops.assign_add(v1, v2)
+
+ sess.run(train_op)
+ # TODO(karmel): remove explicit call when in the public method.
+ builder._add_train_op(train_op)
+ builder.add_meta_graph_and_variables(sess, ["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ self.assertEqual(3, ops.get_collection("v")[0].eval())
+ self.assertEqual(2, ops.get_collection("v")[1].eval())
+ self.assertIsInstance(
+ ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
+
+ def testTrainOpGroup(self):
+ export_dir = self._get_export_dir("test_train_op_group")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Add `v1` and `v2` variables to the graph.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+ v2 = variables.Variable(2, name="v2")
+ ops.add_to_collection("v", v2)
+
+ sess.run(variables.global_variables_initializer())
+ train_op = control_flow_ops.group()
+
+ sess.run(train_op)
+ # TODO(karmel): remove explicit call when in the public method.
+ builder._add_train_op(train_op)
+ builder.add_meta_graph_and_variables(sess, ["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ self.assertEqual(1, ops.get_collection("v")[0].eval())
+ self.assertEqual(2, ops.get_collection("v")[1].eval())
+ self.assertIsInstance(
+ ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation)
+
+ def testTrainOpAfterVariables(self):
+ export_dir = self._get_export_dir("test_train_op_after_variables")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Add `v1` and `v2` variables to the graph.
+ v1 = variables.Variable(1, name="v1")
+ ops.add_to_collection("v", v1)
+ v2 = variables.Variable(2, name="v2")
+ ops.add_to_collection("v", v2)
+
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(sess, ["pre_foo"])
+
+ train_op = state_ops.assign_add(v1, v2)
+ sess.run(train_op)
+ # TODO(karmel): remove explicit call when in the public method.
+ builder._add_train_op(train_op)
+ builder.add_meta_graph(["foo"])
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["foo"], export_dir)
+ self.assertIsInstance(
+ ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ loader.load(sess, ["pre_foo"], export_dir)
+ self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
+
def testMultipleAssets(self):
export_dir = self._get_export_dir("test_multiple_assets")
builder = saved_model_builder.SavedModelBuilder(export_dir)
diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py
index 819f351291..99007a9634 100644
--- a/tensorflow/python/saved_model/signature_constants.py
+++ b/tensorflow/python/saved_model/signature_constants.py
@@ -94,3 +94,9 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant(
__name__, "REGRESS_OUTPUTS")
################################################################################
+# Train/Eval API constants.
+# Not exported while export_all_saved_models is in contrib.
+
+SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training"
+
+SUPERVISED_EVAL_METHOD_NAME = "tensorflow/supervised/eval"
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
index ea0f52f17e..27d6b70e9d 100644
--- a/tensorflow/python/saved_model/signature_def_utils.py
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -26,6 +26,8 @@ from tensorflow.python.saved_model.signature_def_utils_impl import classificatio
from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def
+from tensorflow.python.saved_model.signature_def_utils_impl import supervised_train_signature_def
# pylint: enable=unused-import
del absolute_import
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index d033159188..f8ad788f77 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -185,6 +185,62 @@ def predict_signature_def(inputs, outputs):
return signature_def
+def supervised_train_signature_def(
+ inputs, loss, predictions=None, metrics=None):
+ return _supervised_signature_def(
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME, inputs, loss=loss,
+ predictions=predictions, metrics=metrics)
+
+
+def supervised_eval_signature_def(
+ inputs, loss, predictions=None, metrics=None):
+ return _supervised_signature_def(
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME, inputs, loss=loss,
+ predictions=predictions, metrics=metrics)
+
+
+def _supervised_signature_def(
+ method_name, inputs, loss=None, predictions=None,
+ metrics=None):
+ """Creates a signature for training and eval data.
+
+ This function produces signatures that describe the inputs and outputs
+ of a supervised process, such as training or evaluation, that
+ results in loss, metrics, and the like. Note that this function only requires
+ inputs to be not None.
+
+ Args:
+ method_name: Method name of the SignatureDef as a string.
+ inputs: dict of string to `Tensor`.
+ loss: dict of string to `Tensor` representing computed loss.
+ predictions: dict of string to `Tensor` representing the output predictions.
+ metrics: dict of string to `Tensor` representing metric ops.
+
+ Returns:
+ A train- or eval-flavored signature_def.
+
+ Raises:
+ ValueError: If inputs or outputs is `None`.
+ """
+ if inputs is None or not inputs:
+ raise ValueError('{} inputs cannot be None or empty.'.format(method_name))
+
+ signature_inputs = {key: utils.build_tensor_info(tensor)
+ for key, tensor in inputs.items()}
+
+ signature_outputs = {}
+ for output_set in (loss, predictions, metrics):
+ if output_set is not None:
+ sig_out = {key: utils.build_tensor_info(tensor)
+ for key, tensor in output_set.items()}
+ signature_outputs.update(sig_out)
+
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs, method_name)
+
+ return signature_def
+
+
@tf_export('saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
"""Determine whether a SignatureDef can be served by TensorFlow Serving."""
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
index b2bd14db8c..ebc5450633 100644
--- a/tensorflow/python/saved_model/signature_def_utils_test.py
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -180,6 +180,101 @@ class SignatureDefUtilsTest(test.TestCase):
self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
+ def testTrainSignatureDef(self):
+ self._testSupervisedSignatureDef(
+ signature_def_utils_impl.supervised_train_signature_def,
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME)
+
+ def testEvalSignatureDef(self):
+ self._testSupervisedSignatureDef(
+ signature_def_utils_impl.supervised_eval_signature_def,
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME)
+
+ def _testSupervisedSignatureDef(self, fn_to_test, method_name):
+ inputs = {
+ "input-1": constant_op.constant("a", name="input-1"),
+ "input-2": constant_op.constant("b", name="input-2"),
+ }
+ loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
+ predictions = {
+ "classes": constant_op.constant([100], name="classes"),
+ }
+ metrics_val = constant_op.constant(100.0, name="metrics_val")
+ metrics = {
+ "metrics/value": metrics_val,
+ "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
+ }
+
+ signature_def = fn_to_test(inputs, loss, predictions, metrics)
+
+ self.assertEqual(method_name, signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(2, len(signature_def.inputs))
+ input1_tensor_info_actual = (signature_def.inputs["input-1"])
+ self.assertEqual("input-1:0", input1_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
+ input2_tensor_info_actual = (signature_def.inputs["input-2"])
+ self.assertEqual("input-2:0", input2_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(4, len(signature_def.outputs))
+ self.assertEqual("loss-1:0", signature_def.outputs["loss-1"].name)
+ self.assertEqual(types_pb2.DT_FLOAT, signature_def.outputs["loss-1"].dtype)
+
+ self.assertEqual("classes:0", signature_def.outputs["classes"].name)
+ self.assertEqual(1, len(signature_def.outputs["classes"].tensor_shape.dim))
+
+ self.assertEqual(
+ "metrics_val:0", signature_def.outputs["metrics/value"].name)
+ self.assertEqual(
+ types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)
+
+ self.assertEqual(
+ "metrics_op:0", signature_def.outputs["metrics/update_op"].name)
+ self.assertEqual(
+ types_pb2.DT_FLOAT, signature_def.outputs["metrics/value"].dtype)
+
+ def testTrainSignatureDefMissingInputs(self):
+ self._testSupervisedSignatureDefMissingInputs(
+ signature_def_utils_impl.supervised_train_signature_def,
+ signature_constants.SUPERVISED_TRAIN_METHOD_NAME)
+
+ def testEvalSignatureDefMissingInputs(self):
+ self._testSupervisedSignatureDefMissingInputs(
+ signature_def_utils_impl.supervised_eval_signature_def,
+ signature_constants.SUPERVISED_EVAL_METHOD_NAME)
+
+ def _testSupervisedSignatureDefMissingInputs(self, fn_to_test, method_name):
+ inputs = {
+ "input-1": constant_op.constant("a", name="input-1"),
+ "input-2": constant_op.constant("b", name="input-2"),
+ }
+ loss = {"loss-1": constant_op.constant(0.45, name="loss-1")}
+ predictions = {
+ "classes": constant_op.constant([100], name="classes"),
+ }
+ metrics_val = constant_op.constant(100, name="metrics_val")
+ metrics = {
+ "metrics/value": metrics_val,
+ "metrics/update_op": array_ops.identity(metrics_val, name="metrics_op"),
+ }
+
+ with self.assertRaises(ValueError):
+ signature_def = fn_to_test(
+ {}, loss=loss, predictions=predictions, metrics=metrics)
+
+ signature_def = fn_to_test(inputs, loss=loss)
+ self.assertEqual(method_name, signature_def.method_name)
+ self.assertEqual(1, len(signature_def.outputs))
+
+ signature_def = fn_to_test(inputs, metrics=metrics, loss=loss)
+ self.assertEqual(method_name, signature_def.method_name)
+ self.assertEqual(3, len(signature_def.outputs))
+
def testGetShapeAndTypes(self):
inputs = {
"input-1": constant_op.constant(["a", "b"]),
diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py
index 5a797da791..c82154e7b9 100644
--- a/tensorflow/python/saved_model/tag_constants.py
+++ b/tensorflow/python/saved_model/tag_constants.py
@@ -32,6 +32,9 @@ TRAINING = "train"
tf_export("saved_model.tag_constants.TRAINING").export_constant(
__name__, "TRAINING")
+# Tag for the `eval` graph. Not exported while the export logic is in contrib.
+EVAL = "eval"
+
# Tag for the `gpu` graph.
GPU = "gpu"
tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
@@ -39,3 +42,5 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
# Tag for the `tpu` graph.
TPU = "tpu"
tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU")
+
+