diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-10-10 09:12:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-10 10:23:18 -0700 |
commit | 8018346e12f9fef76cdc7accc248de17514f6d38 (patch) | |
tree | e91d27ba111b493ebd706ca221fda30bf30be3a4 /tensorflow/examples/image_retraining | |
parent | 520f0efd53acf574c2393560692a3b7b98b8bd64 (diff) |
Replace tf.flags usage with argparse everywhere
Change: 135688498
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r-- | tensorflow/examples/image_retraining/retrain.py | 204 | ||||
-rw-r--r-- | tensorflow/examples/image_retraining/retrain_test.py | 4 |
2 files changed, 147 insertions, 61 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 7812117e5d..d52a23fd15 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -64,12 +64,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse from datetime import datetime import glob import hashlib import os.path import random import re +import struct import sys import tarfile @@ -82,74 +84,15 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import gfile from tensorflow.python.util import compat - -import struct - -FLAGS = tf.app.flags.FLAGS +FLAGS = None # Input and output file flags. -tf.app.flags.DEFINE_string('image_dir', '', - """Path to folders of labeled images.""") -tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb', - """Where to save the trained graph.""") -tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt', - """Where to save the trained graph's labels.""") -tf.app.flags.DEFINE_string('summaries_dir', '/tmp/retrain_logs', - """Where to save summary logs for TensorBoard.""") # Details of the training configuration. -tf.app.flags.DEFINE_integer('how_many_training_steps', 4000, - """How many training steps to run before ending.""") -tf.app.flags.DEFINE_float('learning_rate', 0.01, - """How large a learning rate to use when training.""") -tf.app.flags.DEFINE_integer( - 'testing_percentage', 10, - """What percentage of images to use as a test set.""") -tf.app.flags.DEFINE_integer( - 'validation_percentage', 10, - """What percentage of images to use as a validation set.""") -tf.app.flags.DEFINE_integer('eval_step_interval', 10, - """How often to evaluate the training results.""") -tf.app.flags.DEFINE_integer('train_batch_size', 100, - """How many images to train on at a time.""") -tf.app.flags.DEFINE_integer('test_batch_size', 500, - """How many images to test on at a time. This""" - """ test set is only used infrequently to verify""" - """ the overall accuracy of the model.""") -tf.app.flags.DEFINE_integer( - 'validation_batch_size', 100, - """How many images to use in an evaluation batch. This validation set is""" - """ used much more often than the test set, and is an early indicator of""" - """ how accurate the model is during training.""") # File-system cache locations. -tf.app.flags.DEFINE_string('model_dir', '/tmp/imagenet', - """Path to classify_image_graph_def.pb, """ - """imagenet_synset_to_human_label_map.txt, and """ - """imagenet_2012_challenge_label_map_proto.pbtxt.""") -tf.app.flags.DEFINE_string( - 'bottleneck_dir', '/tmp/bottleneck', - """Path to cache bottleneck layer values as files.""") -tf.app.flags.DEFINE_string('final_tensor_name', 'final_result', - """The name of the output classification layer in""" - """ the retrained graph.""") # Controls the distortions used during training. -tf.app.flags.DEFINE_boolean( - 'flip_left_right', False, - """Whether to randomly flip half of the training images horizontally.""") -tf.app.flags.DEFINE_integer( - 'random_crop', 0, - """A percentage determining how much of a margin to randomly crop off the""" - """ training images.""") -tf.app.flags.DEFINE_integer( - 'random_scale', 0, - """A percentage determining how much to randomly scale up the size of the""" - """ training images by.""") -tf.app.flags.DEFINE_integer( - 'random_brightness', 0, - """A percentage determining how much to randomly multiply the training""" - """ image input pixels up or down by.""") # These are all parameters that are tied to the particular model architecture # we're using for Inception v3. These include things like tensor names and their @@ -927,4 +870,145 @@ def main(_): if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--image_dir', + type=str, + default='', + help='Path to folders of labeled images.' + ) + parser.add_argument( + '--output_graph', + type=str, + default='/tmp/output_graph.pb', + help='Where to save the trained graph.' + ) + parser.add_argument( + '--output_labels', + type=str, + default='/tmp/output_labels.txt', + help='Where to save the trained graph\'s labels.' + ) + parser.add_argument( + '--summaries_dir', + type=str, + default='/tmp/retrain_logs', + help='Where to save summary logs for TensorBoard.' + ) + parser.add_argument( + '--how_many_training_steps', + type=int, + default=4000, + help='How many training steps to run before ending.' + ) + parser.add_argument( + '--learning_rate', + type=float, + default=0.01, + help='How large a learning rate to use when training.' + ) + parser.add_argument( + '--testing_percentage', + type=int, + default=10, + help='What percentage of images to use as a test set.' + ) + parser.add_argument( + '--validation_percentage', + type=int, + default=10, + help='What percentage of images to use as a validation set.' + ) + parser.add_argument( + '--eval_step_interval', + type=int, + default=10, + help='How often to evaluate the training results.' + ) + parser.add_argument( + '--train_batch_size', + type=int, + default=100, + help='How many images to train on at a time.' + ) + parser.add_argument( + '--test_batch_size', + type=int, + default=500, + help="""\ + How many images to test on at a time. This test set is only used + infrequently to verify the overall accuracy of the model.\ + """ + ) + parser.add_argument( + '--validation_batch_size', + type=int, + default=100, + help="""\ + How many images to use in an evaluation batch. This validation set is + used much more often than the test set, and is an early indicator of how + accurate the model is during training.\ + """ + ) + parser.add_argument( + '--model_dir', + type=str, + default='/tmp/imagenet', + help="""\ + Path to classify_image_graph_def.pb, + imagenet_synset_to_human_label_map.txt, and + imagenet_2012_challenge_label_map_proto.pbtxt.\ + """ + ) + parser.add_argument( + '--bottleneck_dir', + type=str, + default='/tmp/bottleneck', + help='Path to cache bottleneck layer values as files.' + ) + parser.add_argument( + '--final_tensor_name', + type=str, + default='final_result', + help="""\ + The name of the output classification layer in the retrained graph.\ + """ + ) + parser.add_argument( + '--flip_left_right', + default=False, + help="""\ + Whether to randomly flip half of the training images horizontally.\ + """, + action='store_true' + ) + parser.add_argument( + '--random_crop', + type=int, + default=0, + help="""\ + A percentage determining how much of a margin to randomly crop off the + training images.\ + """ + ) + parser.add_argument( + '--random_scale', + type=int, + default=0, + help="""\ + A percentage determining how much to randomly scale up the size of the + training images by.\ + """ + ) + parser.add_argument( + '--random_brightness', + type=int, + default=0, + help="""\ + A percentage determining how much to randomly multiply the training image + input pixels up or down by.\ + """ + ) + FLAGS = parser.parse_args() + tf.app.run() diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index 072998ae60..3b802e54d1 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -65,7 +65,9 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortJPGInput:0')) self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0')) - def testAddFinalTrainingOps(self): + @tf.test.mock.patch('tensorflow.examples.' + 'image_retraining.retrain.FLAGS', learning_rate=0.01) + def testAddFinalTrainingOps(self, flags_mock): with tf.Graph().as_default(): with tf.Session() as sess: bottleneck = tf.placeholder( |