diff options
author | 2016-10-10 09:12:03 -0800 | |
---|---|---|
committer | 2016-10-10 10:23:18 -0700 | |
commit | 8018346e12f9fef76cdc7accc248de17514f6d38 (patch) | |
tree | e91d27ba111b493ebd706ca221fda30bf30be3a4 /tensorflow/examples/tutorials | |
parent | 520f0efd53acf574c2393560692a3b7b98b8bd64 (diff) |
Replace tf.flags usage with argparse everywhere
Change: 135688498
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r-- | tensorflow/examples/tutorials/mnist/fully_connected_feed.py | 61 |
1 files changed, 48 insertions, 13 deletions
diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py index 2773d8d28c..38ae88ee5b 100644 --- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py +++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py @@ -14,11 +14,12 @@ # ============================================================================== """Trains and Evaluates the MNIST network using a feed dictionary.""" -# pylint: disable=missing-docstring from __future__ import absolute_import from __future__ import division from __future__ import print_function +# pylint: disable=missing-docstring +import argparse import os.path import time @@ -28,19 +29,8 @@ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import mnist - # Basic model parameters as external flags. -flags = tf.app.flags -FLAGS = flags.FLAGS -flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') -flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.') -flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.') -flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.') -flags.DEFINE_integer('batch_size', 100, 'Batch size. ' - 'Must divide evenly into the dataset sizes.') -flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.') -flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data ' - 'for unit testing.') +FLAGS = None def placeholder_inputs(batch_size): @@ -229,4 +219,49 @@ def main(_): if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--learning_rate', + type=float, + default=0.01, + help='Initial learning rate.' + ) + parser.add_argument( + '--max_steps', + type=int, + default=2000, + help='Number of steps to run trainer.' + ) + parser.add_argument( + '--hidden1', + type=int, + default=128, + help='Number of units in hidden layer 1.' + ) + parser.add_argument( + '--hidden2', + type=int, + default=32, + help='Number of units in hidden layer 2.' + ) + parser.add_argument( + '--batch_size', + type=int, + default=100, + help='Batch size. Must divide evenly into the dataset sizes.' + ) + parser.add_argument( + '--train_dir', + type=str, + default='data', + help='Directory to put the training data.' + ) + parser.add_argument( + '--fake_data', + default=False, + help='If true, uses fake data for unit testing.', + action='store_true' + ) + FLAGS = parser.parse_args() + tf.app.run() |