aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle/example/export_half_plus_two.py
blob: e4b1947e038d5df38930e96f1de48c2bc3d57a8b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Exports a toy linear regression inference graph.

Exports a TensorFlow graph to /tmp/half_plus_two/ in the session_bundle
Exporter format (see tensorflow.contrib.session_bundle.exporter).

This graph calculates
  y = a*x + b
where a and b are variables with a=0.5 and b=2.0.

Output from this program is typically used to exercise Session
loading and execution code.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter


def Export(export_path="/tmp/half_plus_two"):
  """Builds the toy half-plus-two regression graph and exports it.

  The graph computes y = a*x + b, where `a` (0.5) and `b` (2.0) are variables
  and `x` is a float32 placeholder fed at inference time. It also defines an
  asset-path variable and two filename assets so that loaders can exercise
  asset handling.

  Args:
    export_path: Base export directory handed to `Exporter.export`. Defaults
      to "/tmp/half_plus_two", the path this example originally hard-coded.
  """
  # Build into a fresh graph rather than the process-wide default graph, so
  # calling Export() more than once (e.g. with different export paths) does
  # not produce uniquified op names such as "x_1"/"y_1", extra variables, or
  # duplicate ASSET_FILEPATHS entries in the exported model.
  with tf.Graph().as_default():
    with tf.Session() as sess:
      # Make model parameters a&b variables instead of constants to
      # exercise the variable reloading mechanisms.
      a = tf.Variable(0.5, name="a")
      b = tf.Variable(2.0, name="b")

      # Calculate y = a*x + b, where the placeholder 'x' is fed at inference
      # time.
      x = tf.placeholder(tf.float32, name="x")
      y = tf.add(tf.mul(a, x), b, name="y")

      # Set up a standard sharded Saver covering only the model parameters.
      save = tf.train.Saver({"a": a, "b": b}, sharded=True)

      # asset_path holds the base directory of assets used in training (e.g.
      # vocabulary files). Ops reading asset files should reference this
      # variable, which stores the original asset path at training time and
      # the overridden assets directory at restore time.
      original_asset_path = tf.constant("/tmp/original/export/assets")
      # collections=[] keeps this variable out of the standard variable
      # collections, so initialize_all_variables() below skips it; it is
      # assigned through init_op instead.
      asset_path = tf.Variable(original_asset_path,
                               name="asset_path",
                               trainable=False,
                               collections=[])
      assign_asset_path = asset_path.assign(original_asset_path)

      # Use a fixed global step number.
      global_step_tensor = tf.Variable(123, name="global_step")

      # Create a RegressionSignature for our input and output.
      signature = exporter.regression_signature(input_tensor=x,
                                                output_tensor=y)

      # Create two filename assets and corresponding tensors.
      # TODO(b/26254158) Consider adding validation of file existence as well
      # as hashes (e.g. sha1) for consistency.
      original_filename1 = tf.constant("hello1.txt")
      tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename1)
      filename1 = tf.Variable(original_filename1,
                              name="filename1",
                              trainable=False,
                              collections=[])
      assign_filename1 = filename1.assign(original_filename1)
      original_filename2 = tf.constant("hello2.txt")
      tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename2)
      filename2 = tf.Variable(original_filename2,
                              name="filename2",
                              trainable=False,
                              collections=[])
      assign_filename2 = filename2.assign(original_filename2)

      # The init op groups the assignments of every variable created with
      # collections=[] above.
      init_op = tf.group(assign_asset_path, assign_filename1, assign_filename2)

      # CopyAssets is passed as the export callback for asset files. In this
      # toy example it only logs the files instead of copying them. Its
      # `export_path` parameter is whatever directory the Exporter passes in,
      # not the outer argument of the same name.
      def CopyAssets(filepaths, export_path):
        print("copying asset files to: %s" % export_path)
        for filepath in filepaths:
          print("copying asset file: %s" % filepath)

      # Initialize a, b and global_step (the collections=[] variables are left
      # to init_op), then run the export.
      tf.initialize_all_variables().run()
      export = exporter.Exporter(save)
      export.init(
          sess.graph.as_graph_def(),
          init_op=init_op,
          default_graph_signature=signature,
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
          assets_callback=CopyAssets)
      export.export(export_path, global_step_tensor, sess)


def main(_):
  """Program entry point; the single positional argument is ignored."""
  return Export()


if __name__ == "__main__":
  # Script entry: tf.app.run() is expected to dispatch to main() above.
  tf.app.run()