diff options
author | Kiril Gorovoy <kgorovoy@google.com> | 2017-01-27 09:06:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-27 09:24:36 -0800 |
commit | 4c420657c8b854354c21d30ae78624c30151eb25 (patch) | |
tree | 5fd9eec8f54593b785be680f3b2fae882b054fae /tensorflow/examples/saved_model | |
parent | 54b2180f996d69ca6558cdc07939a5887ba52d80 (diff) |
Add additional named signatures to test half_plus_two model.
Change: 145801253
Diffstat (limited to 'tensorflow/examples/saved_model')
-rw-r--r-- | tensorflow/examples/saved_model/saved_model_half_plus_two.py | 99 |
1 files changed, 65 insertions, 34 deletions
diff --git a/tensorflow/examples/saved_model/saved_model_half_plus_two.py b/tensorflow/examples/saved_model/saved_model_half_plus_two.py index f466778296..55a9013118 100644 --- a/tensorflow/examples/saved_model/saved_model_half_plus_two.py +++ b/tensorflow/examples/saved_model/saved_model_half_plus_two.py @@ -45,6 +45,7 @@ import sys import tensorflow as tf +from tensorflow.python.lib.io import file_io FLAGS = None @@ -60,15 +61,52 @@ def _write_assets(assets_directory, assets_filename): Returns: The path to which the assets file was written. """ - if not tf.python_io.file_exists(assets_directory): - tf.python_io.recursive_create_dir(assets_directory) + if not file_io.file_exists(assets_directory): + file_io.recursive_create_dir(assets_directory) path = os.path.join( tf.compat.as_bytes(assets_directory), tf.compat.as_bytes(assets_filename)) - tf.python_io.write_string_to_file(path, "asset-file-contents") + file_io.write_string_to_file(path, "asset-file-contents") return path +def _build_regression_signature(input_tensor, output_tensor): + """Helper function for building a regression SignatureDef.""" + input_tensor_info = tf.TensorInfo() + input_tensor_info.name = input_tensor.name + signature_inputs = { + tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor_info + } + output_tensor_info = tf.TensorInfo() + output_tensor_info.name = tf.identity(output_tensor).name + signature_outputs = { + tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor_info + } + return tf.saved_model.signature_def_utils.build_signature_def( + signature_inputs, signature_outputs, + tf.saved_model.signature_constants.REGRESS_METHOD_NAME) + + +# Possibly extend this to allow passing in 'classes', but for now this is +# sufficient for testing purposes. +def _build_classification_signature(input_tensor, scores_tensor): + """Helper function for building a classification SignatureDef.""" + input_tensor_info = tf.TensorInfo() + input_tensor_info.name = input_tensor.name + signature_inputs = { + tf.saved_model.signature_constants.CLASSIFY_INPUTS: input_tensor_info + } + output_tensor_info = tf.TensorInfo() + output_tensor_info.name = tf.identity(scores_tensor).name + signature_outputs = { + tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES: + output_tensor_info + } + return tf.saved_model.signature_def_utils.build_signature_def( + signature_inputs, signature_outputs, + tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME) + + def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): """Generates SavedModel for half plus two. @@ -90,14 +128,20 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): # Parse the tensorflow.Example looking for a feature named "x" with a single # floating point value. - feature_configs = {"x": tf.FixedLenFeature([1], dtype=tf.float32)} + feature_configs = { + "x": tf.FixedLenFeature( + [1], dtype=tf.float32), + "x2": tf.FixedLenFeature( + [1], dtype=tf.float32, default_value=[0.0]) + } tf_example = tf.parse_example(serialized_tf_example, feature_configs) # Use tf.identity() to assign name x = tf.identity(tf_example["x"], name="x") y = tf.add(tf.multiply(a, x), b, name="y") + y2 = tf.add(tf.multiply(a, x), c, name="y2") - x2 = tf.placeholder(tf.float32, name="x2") - tf.add(tf.multiply(a, x2), c, name="y2") + x2 = tf.identity(tf_example["x2"], name="x2") + y3 = tf.add(tf.multiply(a, x2), c, name="y3") # Create an assets file that can be saved and restored as part of the # SavedModel. @@ -116,34 +160,15 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): collections=[]) assign_filename_op = filename_tensor.assign(original_assets_filename) - # Set up the signature for regression with input and output tensor - # specification. - input_tensor = tf.TensorInfo() - input_tensor.name = serialized_tf_example.name - signature_inputs = { - tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor} - - output_tensor = tf.TensorInfo() - output_tensor.name = tf.identity(y).name - signature_outputs = { - tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor} - signature_def = tf.saved_model.signature_def_utils.build_signature_def( - signature_inputs, signature_outputs, - tf.saved_model.signature_constants.REGRESS_METHOD_NAME) - # Set up the signature for Predict with input and output tensor # specification. predict_input_tensor = tf.TensorInfo() predict_input_tensor.name = x.name - predict_signature_inputs = { - "x": predict_input_tensor - } + predict_signature_inputs = {"x": predict_input_tensor} predict_output_tensor = tf.TensorInfo() predict_output_tensor.name = y.name - predict_signature_outputs = { - "y": predict_output_tensor - } + predict_signature_outputs = {"y": predict_output_tensor} predict_signature_def = ( tf.saved_model.signature_def_utils.build_signature_def( predict_signature_inputs, predict_signature_outputs, @@ -151,15 +176,21 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): # Initialize all variables and then save the SavedModel. sess.run(tf.global_variables_initializer()) - signature_def_map = { - tf.saved_model.signature_constants.REGRESS_METHOD_NAME: - signature_def, - tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: - predict_signature_def - } builder.add_meta_graph_and_variables( sess, [tf.saved_model.tag_constants.SERVING], - signature_def_map=signature_def_map, + signature_def_map={ + "regress_x_to_y": + _build_regression_signature(serialized_tf_example, y), + "regress_x_to_y2": + _build_regression_signature(serialized_tf_example, y2), + "regress_x2_to_y3": + _build_regression_signature(x2, y3), + "classify_x_to_y": + _build_classification_signature(serialized_tf_example, y), + tf.saved_model.signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY: + predict_signature_def + }, assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), legacy_init_op=tf.group(assign_filename_op)) builder.save(as_text) |