author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-12-14 13:05:24 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-14 13:08:58 -0800
commit ccef6a711dcadfc57b80783216ee025bfcae4b47 (patch)
tree   4128755c60a7277b55bd5a289d225a3ca146e04b
parent a99b32fb149d028cd31fe638f81c6ca56c6e3b57 (diff)
Add RNN performance information.
Update cudnn_rnn_ops_benchmark, which had rotted against the current API.

PiperOrigin-RevId: 179084042
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                                           5
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py  43
-rw-r--r--  tensorflow/docs_src/performance/performance_guide.md                         52
3 files changed, 72 insertions(+), 28 deletions(-)
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index fce2c03e69..0751624bc4 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -146,10 +146,10 @@ cuda_py_test(
cuda_py_test(
name = "cudnn_rnn_ops_benchmark",
- size = "large",
+ size = "small",
srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"],
additional_deps = [
- ":cudnn_rnn_ops_py",
+ ":cudnn_rnn_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
@@ -164,7 +164,6 @@ cuda_py_test(
"//tensorflow/python:variables",
],
tags = [
- "manual",
"noasan", # http://b/62067814
"nomsan",
"notsan",
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
index ff409ac718..4fc5ff1bd1 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import time
+from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
-from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
@@ -29,8 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -44,19 +43,19 @@ class CudnnRNNBenchmark(test.Benchmark):
"large": {
"num_layers": 4,
"num_units": 1024,
- "seq_length": 40,
+ "seq_length": 50,
"batch_size": 64,
},
"medium": {
"num_layers": 4,
"num_units": 512,
- "seq_length": 30,
+ "seq_length": 50,
"batch_size": 64,
},
"small": {
"num_layers": 4,
"num_units": 128,
- "seq_length": 20,
+ "seq_length": 50,
"batch_size": 64,
},
}
@@ -71,7 +70,7 @@ class CudnnRNNBenchmark(test.Benchmark):
def _BenchmarkOp(self, op, desc):
burn_in_steps = 10
- benchmark_steps = 40
+ benchmark_steps = 20
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
for i in xrange(burn_in_steps + benchmark_steps):
@@ -126,16 +125,12 @@ class CudnnRNNBenchmark(test.Benchmark):
seq_length = config["seq_length"]
with ops.Graph().as_default(), ops.device("/device:GPU:0"):
- inputs = seq_length * [
- array_ops.zeros([batch_size, num_units], dtypes.float32)
- ]
- initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127)
-
- cell = rnn_cell.LSTMCell(
- num_units=num_units, initializer=initializer, state_is_tuple=True)
- multi_cell = rnn_cell.MultiRNNCell(
- [cell() for _ in range(num_layers)])
- outputs, final_state = core_rnn.static_rnn(
+ inputs = array_ops.zeros([batch_size, seq_length, num_units],
+ dtypes.float32)
+
+ multi_cell = contrib_rnn.MultiRNNCell(
+ [contrib_rnn.BasicLSTMCell(num_units) for _ in range(num_layers)])
+ outputs, final_state = rnn.dynamic_rnn(
multi_cell, inputs, dtype=dtypes.float32)
trainable_variables = ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES)
@@ -154,14 +149,12 @@ class CudnnRNNBenchmark(test.Benchmark):
seq_length = config["seq_length"]
with ops.Graph().as_default(), ops.device("/device:GPU:0"):
- inputs = seq_length * [
- array_ops.zeros([batch_size, num_units], dtypes.float32)
- ]
- cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop
-
- multi_cell = rnn_cell.MultiRNNCell(
- [cell() for _ in range(num_layers)])
- outputs, final_state = core_rnn.static_rnn(
+ inputs = array_ops.zeros([batch_size, seq_length, num_units],
+ dtypes.float32)
+
+ multi_cell = contrib_rnn.MultiRNNCell(
+ [lstm_ops.LSTMBlockCell(num_units) for _ in range(num_layers)])
+ outputs, final_state = rnn.dynamic_rnn(
multi_cell, inputs, dtype=dtypes.float32)
trainable_variables = ops.get_collection(
ops.GraphKeys.TRAINABLE_VARIABLES)
diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md
index 17f71a6d77..3ebafb9074 100644
--- a/tensorflow/docs_src/performance/performance_guide.md
+++ b/tensorflow/docs_src/performance/performance_guide.md
@@ -18,6 +18,7 @@ following sections:
* [Input pipeline optimizations](#input-pipeline-optimization)
* [Data formats](#data-formats)
* [Common fused Ops](#common-fused-ops)
+* [RNN Performance](#rnn-performance)
* [Building and installing from source](#building-and-installing-from-source)
### Input pipeline optimization
@@ -197,6 +198,57 @@ since before TensorFlow 1.0.
bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW')
```
+### RNN Performance
+
+There are many ways to specify an RNN computation in TensorFlow, and they have
+trade-offs with respect to model flexibility and performance. The
+@{tf.nn.rnn_cell.BasicLSTMCell} should be considered a reference implementation
+and used only as a last resort when no other options will work.
+
+When using one of the cells, rather than the fully fused RNN layers, you have a
+choice of whether to use @{tf.nn.static_rnn} or @{tf.nn.dynamic_rnn}. There
+shouldn't generally be a performance difference at runtime, but large unroll
+amounts can increase the graph size of the @{tf.nn.static_rnn} and cause long
+compile times. An additional advantage of @{tf.nn.dynamic_rnn} is that it can
+optionally swap memory from the GPU to the CPU to enable training of very long
+sequences. Depending on the model and hardware configuration, this can come at
+a performance cost. It is also possible to run multiple iterations of
+@{tf.nn.dynamic_rnn} and the underlying @{tf.while_loop} construct in parallel,
+although this is rarely useful with RNN models as they are inherently
+sequential.
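+
+As a minimal sketch of the memory-swapping option (the shapes and cell size
+here are illustrative, not taken from any benchmark):
+
+```python
+import tensorflow as tf
+
+batch_size, max_time, input_size, num_units = 32, 1000, 128, 256
+inputs = tf.zeros([batch_size, max_time, input_size], tf.float32)
+cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
+# swap_memory=True lets the underlying while_loop stash activations in host
+# memory, trading transfer bandwidth for the ability to train very long
+# sequences that would otherwise exhaust GPU memory.
+outputs, final_state = tf.nn.dynamic_rnn(
+    cell, inputs, dtype=tf.float32, swap_memory=True)
+```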
+
+On NVIDIA GPUs, the use of @{tf.contrib.cudnn_rnn} should always be preferred
+unless you want layer normalization, which it doesn't support. It is often at
+least an order of magnitude faster than @{tf.contrib.rnn.BasicLSTMCell} and
+@{tf.contrib.rnn.LSTMBlockCell} and uses 3-4x less memory than
+@{tf.contrib.rnn.BasicLSTMCell}. Unfortunately, @{tf.contrib.cudnn_rnn} is not
+compatible with @{tf.train.SyncReplicasOptimizer} so you should either use a
+different synchronization mechanism (consider an all-reduce based strategy) or
+use the @{tf.contrib.rnn.LSTMBlockFusedCell} (at a significant performance
+penalty).
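+
+A hedged sketch of the corresponding layer-style API (TF 1.4-era
+`tf.contrib.cudnn_rnn.CudnnLSTM`; the shapes are illustrative and the inputs
+are time-major):
+
+```python
+import tensorflow as tf
+
+max_time, batch_size, input_size = 100, 64, 512
+inputs = tf.zeros([max_time, batch_size, input_size], tf.float32)
+# A single fused kernel runs all layers over the whole sequence on the GPU.
+lstm = tf.contrib.cudnn_rnn.CudnnLSTM(num_layers=4, num_units=512)
+outputs, output_states = lstm(inputs)
+```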
+
+If you need to run one step of the RNN at a time, as might be the case in
+reinforcement learning with a recurrent policy, then you should use the
+@{tf.contrib.rnn.LSTMBlockCell} with your own environment interaction loop
+inside a @{tf.while_loop} construct. Running one step of the RNN at a time and
+returning to Python is possible, but it will be slower.
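+
+A simplified sketch of that pattern (the "environment" is faked here with a
+constant step input; a real recurrent policy would compute it from
+observations inside the loop body):
+
+```python
+import tensorflow as tf
+
+batch_size, num_units, num_steps = 16, 128, 50
+cell = tf.contrib.rnn.LSTMBlockCell(num_units)
+init_state = cell.zero_state(batch_size, tf.float32)
+step_input = tf.zeros([batch_size, num_units], tf.float32)
+
+def body(t, c, h):
+  # One RNN step per while_loop iteration, entirely inside the graph.
+  _, new_state = cell(step_input, tf.nn.rnn_cell.LSTMStateTuple(c, h))
+  return t + 1, new_state.c, new_state.h
+
+_, final_c, final_h = tf.while_loop(
+    lambda t, c, h: t < num_steps, body,
+    [tf.constant(0), init_state.c, init_state.h])
+```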
+
+On CPUs, on mobile devices, and in cases where @{tf.contrib.cudnn_rnn} is not
+available for your GPU, the fastest and most memory-efficient option is
+@{tf.contrib.rnn.LSTMBlockFusedCell}.
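+
+A minimal sketch (inputs to the fused cell are time-major, and `dtype` must be
+supplied when no initial state is passed; shapes are illustrative):
+
+```python
+import tensorflow as tf
+
+max_time, batch_size, input_size, num_units = 100, 64, 128, 256
+inputs = tf.zeros([max_time, batch_size, input_size], tf.float32)
+fused_cell = tf.contrib.rnn.LSTMBlockFusedCell(num_units)
+# The fused cell consumes the whole sequence in a single op invocation.
+outputs, final_state = fused_cell(inputs, dtype=tf.float32)
+```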
+
+For all of the less common cell types like @{tf.contrib.rnn.NASCell},
+@{tf.contrib.rnn.PhasedLSTMCell}, @{tf.contrib.rnn.UGRNNCell},
+@{tf.contrib.rnn.GLSTMCell}, @{tf.contrib.rnn.Conv1DLSTMCell},
+@{tf.contrib.rnn.Conv2DLSTMCell}, @{tf.contrib.rnn.LayerNormBasicLSTMCell},
+etc., one should be aware that they are implemented in the graph like
+@{tf.contrib.rnn.BasicLSTMCell} and as such will suffer from the same poor
+performance and high memory usage. One should consider whether those
+trade-offs are worth it before using these cells. For example, while layer
+normalization can speed up convergence, because cuDNN is 20x faster, the
+fastest wall clock time to convergence is usually obtained without it.
+
+
### Building and installing from source
The default TensorFlow binaries target the broadest range of hardware to make