aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-02-22 14:24:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-22 14:29:27 -0800
commitdce9a49c19f406ba45919e8c94474e55dc5ccd54 (patch)
tree928db8a52603e00aef76985cda16b8bceb9debb2 /tensorflow/examples/image_retraining
parentcb7e1963c625fd9713e7475d85621f95be6762f1 (diff)
Merge changes from github.
PiperOrigin-RevId: 186674197
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py55
1 files changed, 54 insertions, 1 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 868310cbc0..25e09fecbf 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -41,7 +41,6 @@ The subfolder names are important, since they define what label is applied to
each image, but the filenames themselves don't matter. Once your images are
prepared, you can run the training with a command like this:
-
```bash
bazel build tensorflow/examples/image_retraining:retrain && \
bazel-bin/tensorflow/examples/image_retraining/retrain \
@@ -70,12 +69,14 @@ on resource-limited platforms, you can try the `--architecture` flag with a
Mobilenet model. For example:
Run floating-point version of mobilenet:
+
```bash
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos --architecture mobilenet_1.0_224
```
Run quantized version of mobilenet:
+
```bash
python tensorflow/examples/image_retraining/retrain.py \
--image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized
@@ -96,6 +97,12 @@ Visualize the summaries with this command:
tensorboard --logdir /tmp/retrain_logs
+To use with Tensorflow Serving:
+
+```bash
+tensorflow_model_server --port=9000 --model_name=inception \
+ --model_base_path=/tmp/saved_models/
+```
"""
from __future__ import absolute_import
from __future__ import division
@@ -1004,6 +1011,45 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean,
return jpeg_data, mul_image
+def export_model(sess, architecture, saved_model_dir):
+ """Exports model for serving.
+
+ Args:
+ sess: Current active TensorFlow Session.
+ architecture: Model architecture.
+ saved_model_dir: Directory in which to save exported model and variables.
+ """
+ if architecture == 'inception_v3':
+ input_tensor = 'DecodeJpeg/contents:0'
+ elif architecture.startswith('mobilenet_'):
+ input_tensor = 'input:0'
+ else:
+ raise ValueError('Unknown architecture', architecture)
+ in_image = sess.graph.get_tensor_by_name(input_tensor)
+ inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
+
+ out_classes = sess.graph.get_tensor_by_name('final_result:0')
+ outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
+
+ signature = tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=inputs,
+ outputs=outputs,
+ method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+
+ legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
+
+ # Save out the SavedModel.
+ builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
+ builder.add_meta_graph_and_variables(
+ sess, [tf.saved_model.tag_constants.SERVING],
+ signature_def_map={
+ tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
+ signature
+ },
+ legacy_init_op=legacy_init_op)
+ builder.save()
+
+
def main(_):
# Needed to make sure the logging output is visible.
# See https://github.com/tensorflow/tensorflow/issues/3047
@@ -1179,6 +1225,8 @@ def main(_):
with gfile.FastGFile(FLAGS.output_labels, 'w') as f:
f.write('\n'.join(image_lists.keys()) + '\n')
+ export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@@ -1362,5 +1410,10 @@ if __name__ == '__main__':
takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html
for more information on Mobilenet.\
""")
+ parser.add_argument(
+ '--saved_model_dir',
+ type=str,
+ default='/tmp/saved_models/1/',
+ help='Where to save the exported graph.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)