aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/lstm_test.py
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-01-11 15:01:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-11 15:08:53 -0800
commit963674de719abd35bc58f31d1e1d57d56fd89484 (patch)
tree6f95022357a06e49412683620a9720cc94c2173b /tensorflow/compiler/tests/lstm_test.py
parent37af1b8790d633b9002ab04a0e664ca3c1dbe508 (diff)
More conversions of flags library to argparse.
Add argv to benchmark/main function so they can handle passing command line arguments. Change: 144254260
Diffstat (limited to 'tensorflow/compiler/tests/lstm_test.py')
-rw-r--r--tensorflow/compiler/tests/lstm_test.py75
1 files changed, 53 insertions, 22 deletions
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py
index 9ffeb6c2a2..31093c6571 100644
--- a/tensorflow/compiler/tests/lstm_test.py
+++ b/tensorflow/compiler/tests/lstm_test.py
@@ -18,7 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import os
+import sys
import numpy as np
@@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import test
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_integer('batch_size', 128,
- 'Inputs are fed in batches of this size, for both '
- 'inference and training. Larger values cause the matmul '
- 'in each LSTM cell to have higher dimensionality.')
-flags.DEFINE_integer('seq_length', 60,
- 'Length of the unrolled sequence of LSTM cells in a layer.'
- 'Larger values cause more LSTM matmuls to be run.')
-flags.DEFINE_integer('num_inputs', 1024,
- 'Dimension of inputs that are fed into each LSTM cell.')
-flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
-flags.DEFINE_string('device', 'gpu',
- 'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
- 'For details see documentation for tf.Graph.device.')
-
-flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
- '*.pbtxt format to this directory.')
-
def _DumpGraph(graph, basename):
if FLAGS.dump_graph_dir:
@@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark):
if __name__ == '__main__':
- test.main()
+ parser = argparse.ArgumentParser()
+ parser.register('type', 'bool', lambda v: v.lower() == 'true')
+ parser.add_argument(
+ '--batch_size',
+ type=int,
+ default=128,
+ help="""\
+ Inputs are fed in batches of this size, for both inference and training.
+ Larger values cause the matmul in each LSTM cell to have higher
+ dimensionality.\
+ """
+ )
+ parser.add_argument(
+ '--seq_length',
+ type=int,
+ default=60,
+ help="""\
+ Length of the unrolled sequence of LSTM cells in a layer.Larger values
+ cause more LSTM matmuls to be run.\
+ """
+ )
+ parser.add_argument(
+ '--num_inputs',
+ type=int,
+ default=1024,
+ help='Dimension of inputs that are fed into each LSTM cell.'
+ )
+ parser.add_argument(
+ '--num_nodes',
+ type=int,
+ default=1024,
+ help='Number of nodes in each LSTM cell.'
+ )
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='gpu',
+ help="""\
+ TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see
+ documentation for tf.Graph.device.\
+ """
+ )
+ parser.add_argument(
+ '--dump_graph_dir',
+ type=str,
+ default='',
+ help='If non-empty, dump graphs in *.pbtxt format to this directory.'
+ )
+ global FLAGS # pylint:disable=global-at-module-level
+ FLAGS, unparsed = parser.parse_known_args()
+ test.main(argv=[sys.argv[0]] + unparsed)