aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/saved_model
diff options
context:
space:
mode:
authorGravatar Kiril Gorovoy <kgorovoy@google.com>2017-01-27 09:06:29 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-27 09:24:36 -0800
commit4c420657c8b854354c21d30ae78624c30151eb25 (patch)
tree5fd9eec8f54593b785be680f3b2fae882b054fae /tensorflow/examples/saved_model
parent54b2180f996d69ca6558cdc07939a5887ba52d80 (diff)
Add additional named signatures to test half_plus_two model.
Change: 145801253
Diffstat (limited to 'tensorflow/examples/saved_model')
-rw-r--r--tensorflow/examples/saved_model/saved_model_half_plus_two.py99
1 files changed, 65 insertions, 34 deletions
diff --git a/tensorflow/examples/saved_model/saved_model_half_plus_two.py b/tensorflow/examples/saved_model/saved_model_half_plus_two.py
index f466778296..55a9013118 100644
--- a/tensorflow/examples/saved_model/saved_model_half_plus_two.py
+++ b/tensorflow/examples/saved_model/saved_model_half_plus_two.py
@@ -45,6 +45,7 @@ import sys
import tensorflow as tf
+from tensorflow.python.lib.io import file_io
FLAGS = None
@@ -60,15 +61,52 @@ def _write_assets(assets_directory, assets_filename):
Returns:
The path to which the assets file was written.
"""
- if not tf.python_io.file_exists(assets_directory):
- tf.python_io.recursive_create_dir(assets_directory)
+ if not file_io.file_exists(assets_directory):
+ file_io.recursive_create_dir(assets_directory)
path = os.path.join(
tf.compat.as_bytes(assets_directory), tf.compat.as_bytes(assets_filename))
- tf.python_io.write_string_to_file(path, "asset-file-contents")
+ file_io.write_string_to_file(path, "asset-file-contents")
return path
+def _build_regression_signature(input_tensor, output_tensor):
+ """Helper function for building a regression SignatureDef."""
+ input_tensor_info = tf.TensorInfo()
+ input_tensor_info.name = input_tensor.name
+ signature_inputs = {
+ tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor_info
+ }
+ output_tensor_info = tf.TensorInfo()
+ output_tensor_info.name = tf.identity(output_tensor).name
+ signature_outputs = {
+ tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor_info
+ }
+ return tf.saved_model.signature_def_utils.build_signature_def(
+ signature_inputs, signature_outputs,
+ tf.saved_model.signature_constants.REGRESS_METHOD_NAME)
+
+
+# Possibly extend this to allow passing in 'classes', but for now this is
+# sufficient for testing purposes.
+def _build_classification_signature(input_tensor, scores_tensor):
+ """Helper function for building a classification SignatureDef."""
+ input_tensor_info = tf.TensorInfo()
+ input_tensor_info.name = input_tensor.name
+ signature_inputs = {
+ tf.saved_model.signature_constants.CLASSIFY_INPUTS: input_tensor_info
+ }
+ output_tensor_info = tf.TensorInfo()
+ output_tensor_info.name = tf.identity(scores_tensor).name
+ signature_outputs = {
+ tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
+ output_tensor_info
+ }
+ return tf.saved_model.signature_def_utils.build_signature_def(
+ signature_inputs, signature_outputs,
+ tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
+
+
def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
"""Generates SavedModel for half plus two.
@@ -90,14 +128,20 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
# Parse the tensorflow.Example looking for a feature named "x" with a single
# floating point value.
- feature_configs = {"x": tf.FixedLenFeature([1], dtype=tf.float32)}
+ feature_configs = {
+ "x": tf.FixedLenFeature(
+ [1], dtype=tf.float32),
+ "x2": tf.FixedLenFeature(
+ [1], dtype=tf.float32, default_value=[0.0])
+ }
tf_example = tf.parse_example(serialized_tf_example, feature_configs)
# Use tf.identity() to assign name
x = tf.identity(tf_example["x"], name="x")
y = tf.add(tf.multiply(a, x), b, name="y")
+ y2 = tf.add(tf.multiply(a, x), c, name="y2")
- x2 = tf.placeholder(tf.float32, name="x2")
- tf.add(tf.multiply(a, x2), c, name="y2")
+ x2 = tf.identity(tf_example["x2"], name="x2")
+ y3 = tf.add(tf.multiply(a, x2), c, name="y3")
# Create an assets file that can be saved and restored as part of the
# SavedModel.
@@ -116,34 +160,15 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
collections=[])
assign_filename_op = filename_tensor.assign(original_assets_filename)
- # Set up the signature for regression with input and output tensor
- # specification.
- input_tensor = tf.TensorInfo()
- input_tensor.name = serialized_tf_example.name
- signature_inputs = {
- tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor}
-
- output_tensor = tf.TensorInfo()
- output_tensor.name = tf.identity(y).name
- signature_outputs = {
- tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor}
- signature_def = tf.saved_model.signature_def_utils.build_signature_def(
- signature_inputs, signature_outputs,
- tf.saved_model.signature_constants.REGRESS_METHOD_NAME)
-
# Set up the signature for Predict with input and output tensor
# specification.
predict_input_tensor = tf.TensorInfo()
predict_input_tensor.name = x.name
- predict_signature_inputs = {
- "x": predict_input_tensor
- }
+ predict_signature_inputs = {"x": predict_input_tensor}
predict_output_tensor = tf.TensorInfo()
predict_output_tensor.name = y.name
- predict_signature_outputs = {
- "y": predict_output_tensor
- }
+ predict_signature_outputs = {"y": predict_output_tensor}
predict_signature_def = (
tf.saved_model.signature_def_utils.build_signature_def(
predict_signature_inputs, predict_signature_outputs,
@@ -151,15 +176,21 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
# Initialize all variables and then save the SavedModel.
sess.run(tf.global_variables_initializer())
- signature_def_map = {
- tf.saved_model.signature_constants.REGRESS_METHOD_NAME:
- signature_def,
- tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
- predict_signature_def
- }
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
- signature_def_map=signature_def_map,
+ signature_def_map={
+ "regress_x_to_y":
+ _build_regression_signature(serialized_tf_example, y),
+ "regress_x_to_y2":
+ _build_regression_signature(serialized_tf_example, y2),
+ "regress_x2_to_y3":
+ _build_regression_signature(x2, y3),
+ "classify_x_to_y":
+ _build_classification_signature(serialized_tf_example, y),
+ tf.saved_model.signature_constants.
+ DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ predict_signature_def
+ },
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=tf.group(assign_filename_op))
builder.save(as_text)