diff options
author | 2017-01-11 15:01:00 -0800 | |
---|---|---|
committer | 2017-01-11 15:08:53 -0800 | |
commit | 963674de719abd35bc58f31d1e1d57d56fd89484 (patch) | |
tree | 6f95022357a06e49412683620a9720cc94c2173b /tensorflow/compiler/tests/lstm_test.py | |
parent | 37af1b8790d633b9002ab04a0e664ca3c1dbe508 (diff) |
More conversions of flags library to argparse.
Add argv to the benchmark/main functions so they can handle
command line arguments passed to them.
Change: 144254260
Diffstat (limited to 'tensorflow/compiler/tests/lstm_test.py')
-rw-r--r-- | tensorflow/compiler/tests/lstm_test.py | 75 |
1 file changed, 53 insertions, 22 deletions
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 9ffeb6c2a2..31093c6571 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -18,7 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import os +import sys import numpy as np @@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import flags as flags_lib from tensorflow.python.platform import test -flags = flags_lib -FLAGS = flags.FLAGS - -flags.DEFINE_integer('batch_size', 128, - 'Inputs are fed in batches of this size, for both ' - 'inference and training. Larger values cause the matmul ' - 'in each LSTM cell to have higher dimensionality.') -flags.DEFINE_integer('seq_length', 60, - 'Length of the unrolled sequence of LSTM cells in a layer.' - 'Larger values cause more LSTM matmuls to be run.') -flags.DEFINE_integer('num_inputs', 1024, - 'Dimension of inputs that are fed into each LSTM cell.') -flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.') -flags.DEFINE_string('device', 'gpu', - 'TensorFlow device to assign ops to, e.g. "gpu", "cpu". ' - 'For details see documentation for tf.Graph.device.') - -flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in ' - '*.pbtxt format to this directory.') - def _DumpGraph(graph, basename): if FLAGS.dump_graph_dir: @@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark): if __name__ == '__main__': - test.main() + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--batch_size', + type=int, + default=128, + help="""\ + Inputs are fed in batches of this size, for both inference and training. 
+ Larger values cause the matmul in each LSTM cell to have higher + dimensionality.\ + """ + ) + parser.add_argument( + '--seq_length', + type=int, + default=60, + help="""\ + Length of the unrolled sequence of LSTM cells in a layer.Larger values + cause more LSTM matmuls to be run.\ + """ + ) + parser.add_argument( + '--num_inputs', + type=int, + default=1024, + help='Dimension of inputs that are fed into each LSTM cell.' + ) + parser.add_argument( + '--num_nodes', + type=int, + default=1024, + help='Number of nodes in each LSTM cell.' + ) + parser.add_argument( + '--device', + type=str, + default='gpu', + help="""\ + TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see + documentation for tf.Graph.device.\ + """ + ) + parser.add_argument( + '--dump_graph_dir', + type=str, + default='', + help='If non-empty, dump graphs in *.pbtxt format to this directory.' + ) + global FLAGS # pylint:disable=global-at-module-level + FLAGS, unparsed = parser.parse_known_args() + test.main(argv=[sys.argv[0]] + unparsed) |