diff options
author | 2017-12-14 13:05:24 -0800 | |
---|---|---|
committer | 2017-12-14 13:08:58 -0800 | |
commit | ccef6a711dcadfc57b80783216ee025bfcae4b47 (patch) | |
tree | 4128755c60a7277b55bd5a289d225a3ca146e04b | |
parent | a99b32fb149d028cd31fe638f81c6ca56c6e3b57 (diff) |
Add RNN performance information.
Update cudnn_rnn_ops_benchmark as it had API rotted.
PiperOrigin-RevId: 179084042
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py | 43 | ||||
-rw-r--r-- | tensorflow/docs_src/performance/performance_guide.md | 52 |
3 files changed, 72 insertions, 28 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fce2c03e69..0751624bc4 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -146,10 +146,10 @@ cuda_py_test( cuda_py_test( name = "cudnn_rnn_ops_benchmark", - size = "large", + size = "small", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -164,7 +164,6 @@ cuda_py_test( "//tensorflow/python:variables", ], tags = [ - "manual", "noasan", # http://b/62067814 "nomsan", "notsan", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index ff409ac718..4fc5ff1bd1 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -20,8 +20,8 @@ from __future__ import print_function import time +from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -29,8 +29,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -44,19 +43,19 @@ class CudnnRNNBenchmark(test.Benchmark): "large": { "num_layers": 4, "num_units": 1024, - "seq_length": 40, + "seq_length": 50, "batch_size": 64, }, "medium": { "num_layers": 4, "num_units": 512, - "seq_length": 30, + "seq_length": 50, "batch_size": 64, }, "small": { "num_layers": 4, "num_units": 128, - "seq_length": 20, + "seq_length": 50, "batch_size": 64, }, } @@ -71,7 +70,7 @@ class CudnnRNNBenchmark(test.Benchmark): def _BenchmarkOp(self, op, desc): burn_in_steps = 10 - benchmark_steps = 40 + benchmark_steps = 20 with session.Session() as sess: sess.run(variables.global_variables_initializer()) for i in xrange(burn_in_steps + benchmark_steps): @@ -126,16 +125,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) - - cell = rnn_cell.LSTMCell( - num_units=num_units, initializer=initializer, state_is_tuple=True) - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [contrib_rnn.BasicLSTMCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) @@ -154,14 +149,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop - - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [lstm_ops.LSTMBlockCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index 17f71a6d77..3ebafb9074 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -18,6 +18,7 @@ following sections: * [Input pipeline optimizations](#input-pipeline-optimization) * [Data formats](#data-formats) * [Common fused Ops](#common-fused-ops) +* [RNN Performance](#rnn-performance) * [Building and installing from source](#building-and-installing-from-source) ### Input pipeline optimization @@ -197,6 +198,57 @@ since before TensorFlow 1.0. bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW') ``` +### RNN Performance + +There are many ways to specify an RNN computation in Tensorflow and they have +have trade-offs with respect to model flexibility and performance. The +@{tf.nn.rnn_cell.BasicLSTMCell} should be considered a reference implementation +and used only as a last resort when no other options will work. + +When using one of the cells, rather than the fully fused RNN layers, you have a +choice of whether to use @{tf.nn.static_rnn} or @{tf.nn.dynamic_rnn}. There +shouldn't generally be a performance difference at runtime, but large unroll +amounts can increase the graph size of the @{tf.nn.static_rnn} and cause long +compile times. An additional advantage of @{tf.nn.dynamic_rnn} is that it can +optionally swap memory from the GPU to the CPU to enable training of very long +sequences. Depending on the model and hardware configuration, this can come at +a performance cost. It is also possible to run multiple iterations of +@{tf.nn.dynamic_rnn} and the underlying @{tf.while_loop} construct in parallel, +although this is rarely useful with RNN models as they are inherently +sequential. + +On NVIDIA GPUs, the use of @{tf.contrib.cudnn_rnn} should always be preferred +unless you want layer normalization, which it doesn't support. It is often at +least an order of magnitude faster than @{tf.contrib.rnn.BasicLSTMCell} and +@{tf.contrib.rnn.LSTMBlockCell} and uses 3-4x less memory than +@{tf.contrib.rnn.BasicLSTMCell}. Unfortunately, @{tf.contrib.cudnn_rnn} is not +compatible with @{tf.train.SyncReplicasOptimizer} so you should either use a +different synchronization mechanism (consider an all-reduce based strategy) or +use the @{tf.contrib.rnn.LSTMBlockFusedCell} (at a significant performance +penalty). + +If you need to run one step of the RNN at a time, as might be the case in +reinforcement learning with a recurrent policy, then you should use the +@{tf.contrib.rnn.LSTMBlockCell} with your own environment interaction loop +inside a @{tf.while_loop} construct. Running one step of the RNN at a time and +returning to python is possible but it will be slower. + +On CPUs, mobile devices, and if @{tf.contrib.cudnn_rnn} is not available on +your GPU, the fastest and most memory efficient option is +@{tf.contrib.rnn.LSTMBlockFusedCell}. + +For all of the less common cell types like @{tf.contrib.rnn.NASCell}, +@{tf.contrib.rnn.PhasedLSTMCell}, @{tf.contrib.rnn.UGRNNCell}, +@{tf.contrib.rnn.GLSTMCell}, @{tf.contrib.rnn.Conv1DLSTMCell}, +@{tf.contrib.rnn.Conv2DLSTMCell}, @{tf.contrib.rnn.LayerNormBasicLSTMCell}, +etc., one should be aware that they are implemented in the graph like +@{tf.contrib.rnn.BasicLSTMCell} and as such will suffer from the same poor +performance and high memory usage. One should consider whether or not those +trade-offs are worth it before using these cells. For example, while layer +normalization can speed up convergence, because cuDNN is 20x faster the fastest +wall clock time to convergence is usually obtained without it. + + ### Building and installing from source The default TensorFlow binaries target the broadest range of hardware to make |