aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/image_retraining
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-10 09:12:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-10 10:23:18 -0700
commit8018346e12f9fef76cdc7accc248de17514f6d38 (patch)
treee91d27ba111b493ebd706ca221fda30bf30be3a4 /tensorflow/examples/image_retraining
parent520f0efd53acf574c2393560692a3b7b98b8bd64 (diff)
Replace tf.flags usage with argparse everywhere
Change: 135688498
Diffstat (limited to 'tensorflow/examples/image_retraining')
-rw-r--r--tensorflow/examples/image_retraining/retrain.py204
-rw-r--r--tensorflow/examples/image_retraining/retrain_test.py4
2 files changed, 147 insertions, 61 deletions
diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py
index 7812117e5d..d52a23fd15 100644
--- a/tensorflow/examples/image_retraining/retrain.py
+++ b/tensorflow/examples/image_retraining/retrain.py
@@ -64,12 +64,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
from datetime import datetime
import glob
import hashlib
import os.path
import random
import re
+import struct
import sys
import tarfile
@@ -82,74 +84,15 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.util import compat
-
-import struct
-
-FLAGS = tf.app.flags.FLAGS
+FLAGS = None
# Input and output file flags.
-tf.app.flags.DEFINE_string('image_dir', '',
- """Path to folders of labeled images.""")
-tf.app.flags.DEFINE_string('output_graph', '/tmp/output_graph.pb',
- """Where to save the trained graph.""")
-tf.app.flags.DEFINE_string('output_labels', '/tmp/output_labels.txt',
- """Where to save the trained graph's labels.""")
-tf.app.flags.DEFINE_string('summaries_dir', '/tmp/retrain_logs',
- """Where to save summary logs for TensorBoard.""")
# Details of the training configuration.
-tf.app.flags.DEFINE_integer('how_many_training_steps', 4000,
- """How many training steps to run before ending.""")
-tf.app.flags.DEFINE_float('learning_rate', 0.01,
- """How large a learning rate to use when training.""")
-tf.app.flags.DEFINE_integer(
- 'testing_percentage', 10,
- """What percentage of images to use as a test set.""")
-tf.app.flags.DEFINE_integer(
- 'validation_percentage', 10,
- """What percentage of images to use as a validation set.""")
-tf.app.flags.DEFINE_integer('eval_step_interval', 10,
- """How often to evaluate the training results.""")
-tf.app.flags.DEFINE_integer('train_batch_size', 100,
- """How many images to train on at a time.""")
-tf.app.flags.DEFINE_integer('test_batch_size', 500,
- """How many images to test on at a time. This"""
- """ test set is only used infrequently to verify"""
- """ the overall accuracy of the model.""")
-tf.app.flags.DEFINE_integer(
- 'validation_batch_size', 100,
- """How many images to use in an evaluation batch. This validation set is"""
- """ used much more often than the test set, and is an early indicator of"""
- """ how accurate the model is during training.""")
# File-system cache locations.
-tf.app.flags.DEFINE_string('model_dir', '/tmp/imagenet',
- """Path to classify_image_graph_def.pb, """
- """imagenet_synset_to_human_label_map.txt, and """
- """imagenet_2012_challenge_label_map_proto.pbtxt.""")
-tf.app.flags.DEFINE_string(
- 'bottleneck_dir', '/tmp/bottleneck',
- """Path to cache bottleneck layer values as files.""")
-tf.app.flags.DEFINE_string('final_tensor_name', 'final_result',
- """The name of the output classification layer in"""
- """ the retrained graph.""")
# Controls the distortions used during training.
-tf.app.flags.DEFINE_boolean(
- 'flip_left_right', False,
- """Whether to randomly flip half of the training images horizontally.""")
-tf.app.flags.DEFINE_integer(
- 'random_crop', 0,
- """A percentage determining how much of a margin to randomly crop off the"""
- """ training images.""")
-tf.app.flags.DEFINE_integer(
- 'random_scale', 0,
- """A percentage determining how much to randomly scale up the size of the"""
- """ training images by.""")
-tf.app.flags.DEFINE_integer(
- 'random_brightness', 0,
- """A percentage determining how much to randomly multiply the training"""
- """ image input pixels up or down by.""")
# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and their
@@ -927,4 +870,145 @@ def main(_):
if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--image_dir',
+ type=str,
+ default='',
+ help='Path to folders of labeled images.'
+ )
+ parser.add_argument(
+ '--output_graph',
+ type=str,
+ default='/tmp/output_graph.pb',
+ help='Where to save the trained graph.'
+ )
+ parser.add_argument(
+ '--output_labels',
+ type=str,
+ default='/tmp/output_labels.txt',
+ help='Where to save the trained graph\'s labels.'
+ )
+ parser.add_argument(
+ '--summaries_dir',
+ type=str,
+ default='/tmp/retrain_logs',
+ help='Where to save summary logs for TensorBoard.'
+ )
+ parser.add_argument(
+ '--how_many_training_steps',
+ type=int,
+ default=4000,
+ help='How many training steps to run before ending.'
+ )
+ parser.add_argument(
+ '--learning_rate',
+ type=float,
+ default=0.01,
+ help='How large a learning rate to use when training.'
+ )
+ parser.add_argument(
+ '--testing_percentage',
+ type=int,
+ default=10,
+ help='What percentage of images to use as a test set.'
+ )
+ parser.add_argument(
+ '--validation_percentage',
+ type=int,
+ default=10,
+ help='What percentage of images to use as a validation set.'
+ )
+ parser.add_argument(
+ '--eval_step_interval',
+ type=int,
+ default=10,
+ help='How often to evaluate the training results.'
+ )
+ parser.add_argument(
+ '--train_batch_size',
+ type=int,
+ default=100,
+ help='How many images to train on at a time.'
+ )
+ parser.add_argument(
+ '--test_batch_size',
+ type=int,
+ default=500,
+ help="""\
+ How many images to test on at a time. This test set is only used
+ infrequently to verify the overall accuracy of the model.\
+ """
+ )
+ parser.add_argument(
+ '--validation_batch_size',
+ type=int,
+ default=100,
+ help="""\
+ How many images to use in an evaluation batch. This validation set is
+ used much more often than the test set, and is an early indicator of how
+ accurate the model is during training.\
+ """
+ )
+ parser.add_argument(
+ '--model_dir',
+ type=str,
+ default='/tmp/imagenet',
+ help="""\
+ Path to classify_image_graph_def.pb,
+ imagenet_synset_to_human_label_map.txt, and
+ imagenet_2012_challenge_label_map_proto.pbtxt.\
+ """
+ )
+ parser.add_argument(
+ '--bottleneck_dir',
+ type=str,
+ default='/tmp/bottleneck',
+ help='Path to cache bottleneck layer values as files.'
+ )
+ parser.add_argument(
+ '--final_tensor_name',
+ type=str,
+ default='final_result',
+ help="""\
+ The name of the output classification layer in the retrained graph.\
+ """
+ )
+ parser.add_argument(
+ '--flip_left_right',
+ default=False,
+ help="""\
+ Whether to randomly flip half of the training images horizontally.\
+ """,
+ action='store_true'
+ )
+ parser.add_argument(
+ '--random_crop',
+ type=int,
+ default=0,
+ help="""\
+ A percentage determining how much of a margin to randomly crop off the
+ training images.\
+ """
+ )
+ parser.add_argument(
+ '--random_scale',
+ type=int,
+ default=0,
+ help="""\
+ A percentage determining how much to randomly scale up the size of the
+ training images by.\
+ """
+ )
+ parser.add_argument(
+ '--random_brightness',
+ type=int,
+ default=0,
+ help="""\
+ A percentage determining how much to randomly multiply the training image
+ input pixels up or down by.\
+ """
+ )
+ FLAGS = parser.parse_args()
+
tf.app.run()
diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py
index 072998ae60..3b802e54d1 100644
--- a/tensorflow/examples/image_retraining/retrain_test.py
+++ b/tensorflow/examples/image_retraining/retrain_test.py
@@ -65,7 +65,9 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase):
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortJPGInput:0'))
self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0'))
- def testAddFinalTrainingOps(self):
+ @tf.test.mock.patch('tensorflow.examples.'
+ 'image_retraining.retrain.FLAGS', learning_rate=0.01)
+ def testAddFinalTrainingOps(self, flags_mock):
with tf.Graph().as_default():
with tf.Session() as sess:
bottleneck = tf.placeholder(