diff options
author | 2017-01-12 10:54:44 -0800 | |
---|---|---|
committer | 2017-01-12 11:04:32 -0800 | |
commit | d5062e83509ce76910ac90955c55ea6868394274 (patch) | |
tree | 472a322a351b1c16100101a9a11a4bf713bcc152 /tensorflow/examples/saved_model | |
parent | e10f56e6b71429b71312b3bf47201f7295ae1ec2 (diff) |
Expose saved_model via tf.saved_model.
This also exposes meta_Graph_pb2.TensorInfo as tf.TensorInfo.
Change: 144344131
Diffstat (limited to 'tensorflow/examples/saved_model')
-rw-r--r-- | tensorflow/examples/saved_model/BUILD | 36 | ||||
-rw-r--r-- | tensorflow/examples/saved_model/saved_model_half_plus_two.py | 189 |
2 files changed, 225 insertions, 0 deletions
diff --git a/tensorflow/examples/saved_model/BUILD b/tensorflow/examples/saved_model/BUILD new file mode 100644 index 0000000000..844e99dcd4 --- /dev/null +++ b/tensorflow/examples/saved_model/BUILD @@ -0,0 +1,36 @@ +# Description: SavedModel half plus two example. + +package( + default_visibility = ["//tensorflow:internal"], + features = [ + "-layering_check", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "g3doc/sitemap.md", + ], + ), + visibility = ["//visibility:public"], +) + +py_binary( + name = "saved_model_half_plus_two", + srcs = [ + "saved_model_half_plus_two.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/examples/saved_model/saved_model_half_plus_two.py b/tensorflow/examples/saved_model/saved_model_half_plus_two.py new file mode 100644 index 0000000000..f466778296 --- /dev/null +++ b/tensorflow/examples/saved_model/saved_model_half_plus_two.py @@ -0,0 +1,189 @@ +## Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Exports an example linear regression inference graph. + +Exports a TensorFlow graph to `/tmp/saved_model/half_plus_two/` based on the +`SavedModel` format. + +This graph calculates, + +\\( + y = a*x + b +\\) + +and/or, independently, + +\\( + y2 = a*x2 + c +\\) + +where `a`, `b` and `c` are variables with `a=0.5` and `b=2` and `c=3`. + +Output from this program is typically used to exercise SavedModel load and +execution code. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys + +import tensorflow as tf + + +FLAGS = None + + +def _write_assets(assets_directory, assets_filename): + """Writes asset files to be used with SavedModel for half plus two. + + Args: + assets_directory: The directory to which the assets should be written. + assets_filename: Name of the file to which the asset contents should be + written. + + Returns: + The path to which the assets file was written. + """ + if not tf.python_io.file_exists(assets_directory): + tf.python_io.recursive_create_dir(assets_directory) + + path = os.path.join( + tf.compat.as_bytes(assets_directory), tf.compat.as_bytes(assets_filename)) + tf.python_io.write_string_to_file(path, "asset-file-contents") + return path + + +def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): + """Generates SavedModel for half plus two. + + Args: + export_dir: The directory to which the SavedModel should be written. + as_text: Writes the SavedModel protocol buffer in text format to disk. + """ + builder = tf.saved_model.builder.SavedModelBuilder(export_dir) + + with tf.Session(graph=tf.Graph()) as sess: + # Set up the model parameters as variables to exercise variable loading + # functionality upon restore. + a = tf.Variable(0.5, name="a") + b = tf.Variable(2.0, name="b") + c = tf.Variable(3.0, name="c") + + # Create a placeholder for serialized tensorflow.Example messages to be fed. + serialized_tf_example = tf.placeholder(tf.string, name="tf_example") + + # Parse the tensorflow.Example looking for a feature named "x" with a single + # floating point value. + feature_configs = {"x": tf.FixedLenFeature([1], dtype=tf.float32)} + tf_example = tf.parse_example(serialized_tf_example, feature_configs) + # Use tf.identity() to assign name + x = tf.identity(tf_example["x"], name="x") + y = tf.add(tf.multiply(a, x), b, name="y") + + x2 = tf.placeholder(tf.float32, name="x2") + tf.add(tf.multiply(a, x2), c, name="y2") + + # Create an assets file that can be saved and restored as part of the + # SavedModel. + original_assets_directory = "/tmp/original/export/assets" + original_assets_filename = "foo.txt" + original_assets_filepath = _write_assets(original_assets_directory, + original_assets_filename) + + # Set up the assets collection. + assets_filepath = tf.constant(original_assets_filepath) + tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, assets_filepath) + filename_tensor = tf.Variable( + original_assets_filename, + name="filename_tensor", + trainable=False, + collections=[]) + assign_filename_op = filename_tensor.assign(original_assets_filename) + + # Set up the signature for regression with input and output tensor + # specification. + input_tensor = tf.TensorInfo() + input_tensor.name = serialized_tf_example.name + signature_inputs = { + tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor} + + output_tensor = tf.TensorInfo() + output_tensor.name = tf.identity(y).name + signature_outputs = { + tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor} + signature_def = tf.saved_model.signature_def_utils.build_signature_def( + signature_inputs, signature_outputs, + tf.saved_model.signature_constants.REGRESS_METHOD_NAME) + + # Set up the signature for Predict with input and output tensor + # specification. + predict_input_tensor = tf.TensorInfo() + predict_input_tensor.name = x.name + predict_signature_inputs = { + "x": predict_input_tensor + } + + predict_output_tensor = tf.TensorInfo() + predict_output_tensor.name = y.name + predict_signature_outputs = { + "y": predict_output_tensor + } + predict_signature_def = ( + tf.saved_model.signature_def_utils.build_signature_def( + predict_signature_inputs, predict_signature_outputs, + tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) + + # Initialize all variables and then save the SavedModel. + sess.run(tf.global_variables_initializer()) + signature_def_map = { + tf.saved_model.signature_constants.REGRESS_METHOD_NAME: + signature_def, + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + predict_signature_def + } + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map, + assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), + legacy_init_op=tf.group(assign_filename_op)) + builder.save(as_text) + + +def main(_): + _generate_saved_model_for_half_plus_two(FLAGS.output_dir) + print("SavedModel generated at: %s" % FLAGS.output_dir) + + _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True) + print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--output_dir", + type=str, + default="/tmp/saved_model_half_plus_two", + help="Directory where to ouput SavedModel.") + parser.add_argument( + "--output_dir_pbtxt", + type=str, + default="/tmp/saved_model_half_plus_two_pbtxt", + help="Directory where to ouput the text format of SavedModel.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) |