aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/tutorials
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-10 09:12:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-10 10:23:18 -0700
commit8018346e12f9fef76cdc7accc248de17514f6d38 (patch)
treee91d27ba111b493ebd706ca221fda30bf30be3a4 /tensorflow/examples/tutorials
parent520f0efd53acf574c2393560692a3b7b98b8bd64 (diff)
Replace tf.flags usage with argparse everywhere
Change: 135688498
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r--tensorflow/examples/tutorials/mnist/fully_connected_feed.py61
1 files changed, 48 insertions, 13 deletions
diff --git a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
index 2773d8d28c..38ae88ee5b 100644
--- a/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
@@ -14,11 +14,12 @@
# ==============================================================================
"""Trains and Evaluates the MNIST network using a feed dictionary."""
-# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+# pylint: disable=missing-docstring
+import argparse
import os.path
import time
@@ -28,19 +29,8 @@ import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist
-
# Basic model parameters as external flags.
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
-flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
-flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
-flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
-flags.DEFINE_integer('batch_size', 100, 'Batch size. '
- 'Must divide evenly into the dataset sizes.')
-flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
-flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
- 'for unit testing.')
+FLAGS = None
def placeholder_inputs(batch_size):
@@ -229,4 +219,49 @@ def main(_):
if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--learning_rate',
+ type=float,
+ default=0.01,
+ help='Initial learning rate.'
+ )
+ parser.add_argument(
+ '--max_steps',
+ type=int,
+ default=2000,
+ help='Number of steps to run trainer.'
+ )
+ parser.add_argument(
+ '--hidden1',
+ type=int,
+ default=128,
+ help='Number of units in hidden layer 1.'
+ )
+ parser.add_argument(
+ '--hidden2',
+ type=int,
+ default=32,
+ help='Number of units in hidden layer 2.'
+ )
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=100,
+ help='Batch size. Must divide evenly into the dataset sizes.'
+ )
+ parser.add_argument(
+ '--train_dir',
+ type=str,
+ default='data',
+ help='Directory to put the training data.'
+ )
+ parser.add_argument(
+ '--fake_data',
+ default=False,
+ help='If true, uses fake data for unit testing.',
+ action='store_true'
+ )
+ FLAGS = parser.parse_args()
+
tf.app.run()