author     James Qin <jamesqin@google.com>                    2017-06-20 16:54:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-06-20 16:57:32 -0700
commit     a936b239cf91a9727d15ab94fbe1b08e68c685b3
tree       b94288c80d6ef866f041ef186a379de4a764ffd4
parent     4be287671af6190ca52ce04bd93bf356bb7fce20
Support reusing cuDNNLSTM-trained checkpoints with multi-layer LSTM(Block)Cell

* Add tensor stitching/partitioning in RNNParamsSaveable for saving/restoring
  properly shaped and formatted weights/biases to share with LSTM(Block)Cell.
* Add remapped names for canonical tensors during saving.
* Unit tests.

PiperOrigin-RevId: 159634913
-rw-r--r--  tensorflow/contrib/cudnn_rnn/BUILD                                      |   2
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py  | 241
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py                | 229
-rw-r--r--  tensorflow/contrib/rnn/python/ops/lstm_ops.py                           |  12
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py                                  |   5
5 files changed, 465 insertions, 24 deletions
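For context, the intended workflow can be outlined as follows. This is a minimal sketch assuming the TF 1.x contrib APIs touched by this change; the hyperparameters and checkpoint path are illustrative, and a GPU build is needed to actually run the cuDNN ops.

```python
# Minimal sketch (assumptions: TF 1.x contrib APIs at this revision; GPU build).
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import lstm_ops

num_layers, num_units, input_size = 2, 4, 5
save_path = "/tmp/cudnn-lstm-ckpt"  # illustrative path

# 1) cuDNN training graph: register the opaque param buffer through the
#    extended RNNParamsSaveable so it is saved under canonical cell names,
#    prefixed with the variable scope of the canonical graph ("rnn").
with tf.Graph().as_default():
  model = cudnn_rnn_ops.CudnnLSTM(num_layers, num_units, input_size)
  params = tf.Variable(
      tf.random_uniform([model.params_size()]), validate_shape=False)
  saveable = cudnn_rnn_ops.RNNParamsSaveable(
      model, model.params_to_canonical, model.canonical_to_params, [params],
      "rnn")
  tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable)
  saver = tf.train.Saver()
  # ... run training, then: saver.save(sess, save_path)

# 2) Canonical inference graph: LSTMBlockCell with forget_bias=0 and
#    clip_cell=False so its parameterization matches what cuDNN trained.
with tf.Graph().as_default():
  cell = tf.contrib.rnn.MultiRNNCell([
      lstm_ops.LSTMBlockCell(num_units, forget_bias=0.0, clip_cell=False)
      for _ in range(num_layers)
  ])
  inputs = tf.placeholder(tf.float32, [None, None, input_size])  # time-major
  outputs, state = tf.nn.dynamic_rnn(
      cell, inputs, dtype=tf.float32, time_major=True, scope="rnn")
  saver = tf.train.Saver()
  # saver.restore(sess, save_path) now loads the cuDNN-trained weights into
  # the LSTMBlockCell variables; LSTMCell(num_units, forget_bias=0.0) works
  # the same way.
```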
diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD
index b1caac476a..fc473d3380 100644
--- a/tensorflow/contrib/cudnn_rnn/BUILD
+++ b/tensorflow/contrib/cudnn_rnn/BUILD
@@ -87,6 +87,8 @@ cuda_py_test(
additional_deps = [
":cudnn_rnn_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/rnn:rnn_py",
+ "//tensorflow/python/ops/losses:losses",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
index 08ec3076e4..0e51ab9935 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py
@@ -20,8 +20,13 @@ from __future__ import print_function
import os
import unittest
+import numpy as np
+
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
+from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework.test_util import TensorFlowTestCase
@@ -29,10 +34,14 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import rnn
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
from tensorflow.python.training import saver as saver_lib
@@ -69,7 +78,8 @@ class CudnnRNNTest(TensorFlowTestCase):
model: a CudnnRNN model.
"""
params_saveable = cudnn_rnn_ops.RNNParamsSaveable(
- model.params_to_canonical, model.canonical_to_params, [params])
+ model, model.params_to_canonical, model.canonical_to_params, [params],
+ "rnn")
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable)
def _testSaveRestoreVariable(self, rnn_mode):
@@ -93,6 +103,218 @@ class CudnnRNNTest(TensorFlowTestCase):
params_v_restored = sess.run(params)
self.assertAllEqual(params_v, params_v_restored)
+ def _create_equivalent_canonical_rnn(self,
+ cudnn_model,
+ inputs,
+ use_block_cell,
+ scope="rnn"):
+    if cudnn_model.rnn_mode != "lstm":
+ raise ValueError("%s is not supported!" % cudnn_model.rnn_mode)
+
+ num_units = cudnn_model.num_units
+ num_layers = cudnn_model.num_layers
+
+    # To reuse cuDNN-trained models, the following must be set:
+    #   forget_bias, clip_cell = 0, False
+    # In LSTMCell and LSTMBlockCell, forget_bias is added on top of the
+    # learned bias, whereas cuDNN does not apply this additional bias.
+ if use_block_cell:
+ # pylint: disable=g-long-lambda
+ single_cell = lambda: lstm_ops.LSTMBlockCell(num_units, forget_bias=0,
+ clip_cell=False)
+ # pylint: enable=g-long-lambda
+ else:
+ single_cell = lambda: rnn_cell_impl.LSTMCell(num_units, forget_bias=0)
+ cell = rnn_cell_impl.MultiRNNCell(
+ [single_cell() for _ in range(num_layers)])
+ return rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope)
+
+ def _build_forward_cudnn_model(self,
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_data,
+ is_training=False):
+ input_data_shape = input_data.get_shape().with_rank(3)
+ batch_size = input_data_shape[1].value
+ input_size = input_data_shape[2].value
+ model = self._CreateModel(rnn_mode, num_layers, num_units, input_size)
+
+ # Set zero init input states
+ input_h = constant_op.constant(
+ np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32)
+ has_input_c = (rnn_mode == "lstm")
+ if has_input_c:
+ input_c = constant_op.constant(
+ np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32)
+
+ # Set rnn params
+ params_size_t = model.params_size()
+ params = variables.Variable(
+ random_ops.random_uniform([params_size_t]), validate_shape=False)
+ args = {
+ "input_data": input_data,
+ "input_h": input_h,
+ "params": params,
+ "is_training": is_training
+ }
+ if has_input_c:
+ args["input_c"] = input_c
+ # Build cell
+ output_tuple = model(**args)
+
+ # Create savable objects for params
+ self._create_params_savable(params, model)
+
+ return output_tuple, model, params
+
+ @unittest.skipUnless(test.is_built_with_cuda(),
+ "Test only applicable when running on GPUs")
+ def testCheckpointReusableByCanonicalLSTMCells(self):
+ configs = [
+ {
+ "num_layers": 1,
+ "seq_length": 3,
+ "num_units": 4,
+ "input_size": 5,
+ "batch_size": 6,
+ "rnn_mode": "lstm"
+ },
+ {
+ "num_layers": 2,
+ "seq_length": 8,
+ "num_units": 4,
+ "input_size": 8,
+ "batch_size": 16,
+ "rnn_mode": "lstm"
+ },
+ {
+ "num_layers": 2,
+ "seq_length": 3,
+ "num_units": 4,
+ "input_size": 5,
+ "batch_size": 6,
+ "rnn_mode": "lstm"
+ },
+ {
+ "num_layers": 1,
+ "seq_length": 2,
+ "num_units": 2,
+ "input_size": 4,
+ "batch_size": 1,
+ "rnn_mode": "lstm"
+ },
+ ]
+ for cfg in configs:
+ self._testCheckpointReusableByCanonicalLSTMCells(
+ cfg["num_layers"],
+ cfg["seq_length"],
+ cfg["num_units"],
+ cfg["input_size"],
+ cfg["batch_size"],
+ cfg["rnn_mode"],
+ use_block_cell=False)
+ self._testCheckpointReusableByCanonicalLSTMCells(
+ cfg["num_layers"],
+ cfg["seq_length"],
+ cfg["num_units"],
+ cfg["input_size"],
+ cfg["batch_size"],
+ cfg["rnn_mode"],
+ use_block_cell=True)
+
+ def _testCheckpointReusableByCanonicalLSTMCells(
+ self, num_layers, seq_length, num_units, input_size, batch_size, rnn_mode,
+ use_block_cell):
+ np.random.seed(0)
+ # Train graph
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(299)
+ input_data = array_ops.placeholder(
+ dtypes.float32, shape=[seq_length, batch_size, input_size])
+ output_tuple, cudnn_model, cudnn_params = self._build_forward_cudnn_model(
+ rnn_mode, num_layers, num_units, input_data, is_training=True)
+ target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ total_sum = sum(map(math_ops.reduce_sum, output_tuple))
+
+ loss_op = losses.log_loss(labels=target_output, predictions=total_sum)
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2)
+ train_op = optimizer.minimize(loss_op)
+
+ saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
+
+ # Train Cudnn model
+ with self.test_session(
+ use_gpu=True, graph=ops.get_default_graph()) as sess:
+ sess.run(variables.global_variables_initializer())
+ # Train 128 steps
+ num_steps = 128
+ for _ in range(num_steps):
+ inputs = np.random.rand(seq_length, batch_size,
+ input_size).astype(np.float32)
+ targets = np.random.rand()
+ sess.run(
+ train_op, feed_dict={input_data: inputs,
+ target_output: targets})
+
+ save_path = os.path.join(self.get_temp_dir(),
+ ("cudnn-rnn-%s-test" % rnn_mode))
+ save_v = saver.save(sess, save_path)
+ self.assertEqual(save_path, save_v)
+ cudnn_params_v = sess.run(cudnn_params)
+
+ # cuDNN inference graph
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(299)
+ cudnn_inputs = array_ops.placeholder(
+ dtypes.float32, shape=[seq_length, batch_size, input_size])
+ (cudnn_output_tuple, cudnn_model,
+ cudnn_params) = self._build_forward_cudnn_model(
+ rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False)
+ saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
+
+ inference_input = np.random.rand(seq_length, batch_size,
+ input_size).astype(np.float32)
+ with self.test_session(
+ use_gpu=True, graph=ops.get_default_graph()) as sess:
+ sess.run(variables.global_variables_initializer())
+ saver.restore(sess, save_path)
+ restored_cudnn_params_v = sess.run(cudnn_params)
+ self.assertAllEqual(cudnn_params_v, restored_cudnn_params_v)
+
+ # Cudnn inference
+ (cudnn_output, cudnn_output_h, cudnn_output_c) = sess.run(
+ cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input})
+
+ # LSTMBlockCell inference graph
+ with ops.Graph().as_default():
+ random_seed.set_random_seed(299)
+ cell_inputs = array_ops.placeholder(
+ dtypes.float32, shape=[seq_length, batch_size, input_size])
+ (output, states) = self._create_equivalent_canonical_rnn(
+ cudnn_model, cell_inputs, use_block_cell)
+ saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2)
+
+ with self.test_session(
+ use_gpu=True, graph=ops.get_default_graph()) as sess:
+ saver.restore(sess, save_path)
+
+ # BlockCell inference
+ output_v, states_v = sess.run(
+ [output, states], feed_dict={cell_inputs: inference_input})
+
+      # Outputs across time steps are packed into one tensor.
+ self.assertAllClose(cudnn_output, output_v, atol=1e-6, rtol=1e-6)
+
+ for i in range(num_layers):
+ # output_h
+ self.assertAllClose(
+ cudnn_output_h[i, :], states_v[i].h, atol=1e-6, rtol=1e-6)
+ # output_c
+ self.assertAllClose(
+ cudnn_output_c[i, :], states_v[i].c, atol=1e-6, rtol=1e-6)
+
def _testSaveRestoreOutput(self, rnn_mode):
num_layers = 2
num_units = 7
@@ -187,9 +409,13 @@ class CudnnRNNTest(TensorFlowTestCase):
batch_size, seq_length, dir_count, dropout,
expected, tolerance):
random_seed.set_random_seed(5678)
- model = self._CreateModel(rnn_mode, num_layers, num_units, input_size,
- input_mode="auto_select",
- dropout=dropout)
+ model = self._CreateModel(
+ rnn_mode,
+ num_layers,
+ num_units,
+ input_size,
+ input_mode="auto_select",
+ dropout=dropout)
has_input_c = (rnn_mode == "lstm")
params_size_t = model.params_size()
input_data = array_ops.ones([seq_length, batch_size, input_size])
@@ -216,7 +442,7 @@ class CudnnRNNTest(TensorFlowTestCase):
if has_input_c:
output_c_sum = math_ops.reduce_sum(output_c)
total_sum += output_c_sum
- with self.test_session(use_gpu=True) as sess:
+ with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess:
sess.run(variables.global_variables_initializer())
total_sum_v = sess.run([total_sum])
@@ -310,8 +536,8 @@ class CudnnRNNTest(TensorFlowTestCase):
os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
has_input_c = (rnn_mode == "lstm")
random_seed.set_random_seed(1234)
- model = self._CreateModel(rnn_mode, num_layers, num_units, input_size,
- dropout=dropout)
+ model = self._CreateModel(
+ rnn_mode, num_layers, num_units, input_size, dropout=dropout)
params_size_t = model.params_size()
input_data = variables.Variable(
random_ops.random_uniform([seq_length, batch_size, input_size]))
@@ -417,6 +643,7 @@ class CudnnRNNTest(TensorFlowTestCase):
},
},
]
+ ops.reset_default_graph()
with ops.Graph().as_default():
for config in test_configs:
rnn_mode = config["rnn_mode"]
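The closing assertions of the new test compare the stacked cuDNN states against the per-layer LSTMStateTuples returned by dynamic_rnn. A small NumPy shape sketch of the correspondence they rely on, using the sizes of one of the test configurations:

```python
# Shape sketch only; sizes taken from one of the test configs above.
import numpy as np

num_layers, seq_length, batch_size, num_units = 2, 3, 6, 4

# cuDNN returns one stacked tensor per state kind.
cudnn_output = np.zeros([seq_length, batch_size, num_units])
cudnn_output_h = np.zeros([num_layers, batch_size, num_units])
cudnn_output_c = np.zeros([num_layers, batch_size, num_units])

# dynamic_rnn over a MultiRNNCell returns one (c, h) LSTMStateTuple per layer,
# each member of shape [batch_size, num_units].
states = tuple((np.zeros([batch_size, num_units]),   # c
                np.zeros([batch_size, num_units]))   # h
               for _ in range(num_layers))

# Outputs are compared as-is; states are compared layer by layer.
for i in range(num_layers):
  c_i, h_i = states[i]
  assert cudnn_output_h[i].shape == h_i.shape
  assert cudnn_output_c[i].shape == c_i.shape
```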
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index cc0c7b0829..0437467f3f 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -16,7 +16,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import itertools
from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
from tensorflow.contrib.util import loader
@@ -46,9 +45,11 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation that handles the RNN params variable."""
def __init__(self,
+ cudnn_rnn,
params_to_canonical,
canonical_to_params,
param_variables,
+ base_variable_scope=None,
name="params_canonical"):
"""Creates a RNNParamsSaveable object.
@@ -75,6 +76,7 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate.
Args:
+ cudnn_rnn: cudnn RNN class instance.
params_to_canonical: a function to convert params from a specific format
for cuDNN or other RNN ops to the canonical format.
_CudnnRNN.params_to_canonical() should be provided here.
@@ -87,25 +89,42 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
For cuDNN RNN ops, this is a single merged variable for both weights
and biases; for other RNN ops, this might be multiple unmerged or
partially merged variables respectively for weights and biases.
+      base_variable_scope: a string, the name of the outer variable scope,
+        used as part of the prefix of the names of the saved variables.
name: the name of the RNNParamsSaveable object.
"""
# There is only a single merged parameter variable for cuDNN when saving.
+ self._cudnn_rnn = cudnn_rnn
weights, biases = params_to_canonical(param_variables[0])
+    weights, biases = self._transform_canonical(weights, biases)
+    weight_names, bias_names = self._transformed_canonical_names(
+ weights, biases)
self._canonical_to_params = canonical_to_params
self._variables = param_variables
# We currently don't use slice_spec. It might be useful in a distributed
# setting where each parameter server node stores a slice of variable,
# instead of having the master pull all slices and then save them.
slice_spec = ""
+ params = weights + biases
+    param_names = weight_names + bias_names
+ if base_variable_scope:
+ param_names = ["%s/%s" % (base_variable_scope, pn) for pn in param_names]
specs = [
- saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param.name)
- for param in itertools.chain(weights, biases)
+ saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name)
+ for param, param_name in zip(params, param_names)
]
super(RNNParamsSaveable, self).__init__(None, specs, name)
def restore(self, restored_tensors, restored_shapes):
- weights = restored_tensors[:len(restored_tensors) // 2]
- biases = restored_tensors[len(restored_tensors) // 2:]
+ if (self._cudnn_rnn.direction == "unidirectional" and
+ self._cudnn_rnn.rnn_mode == "lstm"):
+ assert len(restored_tensors) % 4 == 0
+ weights = restored_tensors[:len(restored_tensors) // 4]
+ biases = restored_tensors[len(restored_tensors) // 4:]
+ else:
+ weights = restored_tensors[:len(restored_tensors) // 2]
+ biases = restored_tensors[len(restored_tensors) // 2:]
+ weights, biases = self._untransform_canonical(weights, biases)
params = self._canonical_to_params(weights, biases)
if not isinstance(params, tuple):
params = (params,)
@@ -115,6 +134,159 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
]
return control_flow_ops.group(*assign_ops)
+ def _switch_inner(self, array, base_idx):
+ array[base_idx + 1], array[base_idx + 2] = (array[base_idx + 2],
+ array[base_idx + 1])
+
+ def _transform_canonical(self, weights, biases):
+ if (self._cudnn_rnn.direction != "unidirectional" or
+ self._cudnn_rnn.rnn_mode != "lstm"):
+ return weights, biases
+ return self._transform_lstm_canonical(weights, biases)
+
+ def _transformed_canonical_names(self, weights, biases):
+ """Return canonical names for fused weight and bias tensors."""
+ if (self._cudnn_rnn.direction != "unidirectional" or
+ self._cudnn_rnn.rnn_mode != "lstm"):
+ assert len(weights) == len(biases)
+ return ([w.name for w in weights], [b.name for b in biases])
+ else:
+ w_names, b_names = [], []
+ assert len(weights) * 3 == len(biases)
+ num_layers = self._cudnn_rnn.num_layers
+ # TODO(jamesqin): get rid of multi_rnn_cell when num_layers is 1
+ for i in range(num_layers):
+        # One fused weight tensor per layer.
+ w_names.append("multi_rnn_cell/cell_%d/lstm_cell/kernel" % i)
+        # Three fused bias tensors per layer:
+ # the 1st is for LSTMBlockCell restore; the latter two sum up to the
+ # 1st, and are used for cuDNN restore.
+ b_names.append("multi_rnn_cell/cell_%d/lstm_cell/bias" % i)
+ b_names.extend([
+ "multi_rnn_cell/cell_%d/lstm_cell/bias_cudnn_%d" % (i, j)
+ for j in range(2)
+ ])
+ return w_names, b_names
+
+ def _transform_lstm_canonical(self, weights, biases):
+ """Create fused lstm canonical params.
+
+    Produce properly shaped monolithic weight and bias tensors that can be
+    shared between cuDNN and platform-independent LSTM cells (without
+    peepholes).
+ Args:
+ weights: a list of Tensors recovered from cuDNN params_to_canonical.
+ biases: a list of Tensors recovered from cuDNN params_to_canonical.
+ Returns:
+      Two lists of tensors, one for weights and one for biases.
+      The weight list contains num_layers tensors and the bias list contains
+      3 * num_layers tensors. Both the original and the combined biases are
+      kept, since the cuDNN biases cannot be recovered from the fused version.
+ """
+ transformed_weights, transformed_biases = [], []
+ for i in range(self._cudnn_rnn.num_layers):
+ base_idx = i * 8
+ num_units = self._cudnn_rnn.num_units
+ input_size = self._cudnn_rnn.input_size if i == 0 else num_units
+ # cuDNN tensor shapes per time_step:
+ # input.shape: [batch_size, input_size],
+ # input_weights.shape: [num_units, input_size] (first layer)
+ # [num_units, num_units] (other layers)
+ # state_weights.shape: [num_units, num_units]
+ # biases.shape: [num_units]
+ #
+ # General LSTM cells compute gate functions using:
+ # [x, h_prev] * weights + biases
+ # Therefore for each layer, they expect
+ # weight.shape: [input_size + num_units, 4 * num_units] (first_layer)
+ # [num_units + num_units, 4 * num_units] (other layers)
+ # bias.shape: [4 * num_units]
+
+ # Stitch weights together in this layer.
+ stitched_w = []
+ for j in range(4):
+ stitched_w.append(
+ array_ops.concat(
+ [
+ array_ops.reshape(weights[base_idx + j],
+ [num_units, input_size]),
+ array_ops.reshape(weights[base_idx + j + 4],
+ [num_units, num_units])
+ ],
+ axis=1))
+ # cuDNN weights are in ifco order, convert to icfo order.
+ self._switch_inner(stitched_w, 0)
+ transformed_weights.append(
+ array_ops.transpose(array_ops.concat(stitched_w, axis=0)))
+
+ # Stitch biases together in this layer.
+ # Convert to icfo order.
+ self._switch_inner(biases, base_idx)
+ self._switch_inner(biases, base_idx + 4)
+ # The bias for layer input.
+ b_in = array_ops.concat(biases[base_idx:base_idx + 4], axis=0)
+ # The bias for recurrent input.
+ b_rec = array_ops.concat(biases[base_idx + 4:base_idx + 8], axis=0)
+
+ transformed_biases.extend([b_in + b_rec, b_in, b_rec])
+ return transformed_weights, transformed_biases
+
+ def _untransform_canonical(self, transformed_weights, transformed_biases):
+ if (self._cudnn_rnn.direction != "unidirectional" or
+ self._cudnn_rnn.rnn_mode != "lstm"):
+ return transformed_weights, transformed_biases
+ return self._untransform_lstm_canonical(transformed_weights,
+ transformed_biases)
+
+ def _untransform_lstm_canonical(self, transformed_weights,
+ transformed_biases):
+ """The reverse procedure of _transform_lstm_canonical().
+
+ Args:
+ transformed_weights: a list of tensors, one for each layer.
+      transformed_biases: a list of tensors, 3 for each layer: the 2nd is for
+        the layer input, the 3rd for the recurrent input, and the 1st is the
+        sum of the latter two.
+ Returns:
+ Two lists of tensors for weights and biases respectively.
+ There are 8 tensors per weight and per bias for each layer:
+ tensor 0-3 are applied to the input from the previous layer;
+ tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate;
+ tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate;
+ tensor 3 and 7 the output gate.
+ """
+ weights, biases = [], []
+ assert 3 * len(transformed_weights) == len(transformed_biases)
+ for i in range(len(transformed_weights)):
+ num_units = self._cudnn_rnn.num_units
+ input_size = self._cudnn_rnn.input_size if i == 0 else num_units
+ # weights applied on layer inputs.
+ wi = array_ops.slice(transformed_weights[i], [0, 0],
+ [input_size, 4 * num_units])
+ # weights applied on recurrent inputs.
+ wr = array_ops.slice(transformed_weights[i], [input_size, 0],
+ [num_units, 4 * num_units])
+ wi_list = array_ops.split(wi, 4, axis=1)
+ wr_list = array_ops.split(wr, 4, axis=1)
+
+ for j in range(len(wi_list)):
+ wi_list[j] = array_ops.reshape(array_ops.transpose(wi_list[j]), [-1])
+ wr_list[j] = array_ops.reshape(array_ops.transpose(wr_list[j]), [-1])
+ # canonical weights are in icfo order, convert to ifco order for cuDNN.
+ self._switch_inner(wi_list, 0)
+ self._switch_inner(wr_list, 0)
+ weights.extend(wi_list)
+ weights.extend(wr_list)
+
+ base_idx = 3 * i
+ bi_list = array_ops.split(transformed_biases[base_idx + 1], 4, axis=0)
+ br_list = array_ops.split(transformed_biases[base_idx + 2], 4, axis=0)
+      # canonical biases are in icfo order, convert to ifco order for cuDNN.
+ self._switch_inner(bi_list, 0)
+ self._switch_inner(br_list, 0)
+ biases.extend(bi_list)
+ biases.extend(br_list)
+ return weights, biases
+
_cudnn_rnn_common_doc_string = """
Cudnn RNN has an opaque parameter buffer that can be used for inference and
@@ -199,6 +371,26 @@ class _CudnnRNN(object):
if self._seed is None and self._seed2 is None:
self._seed, self._seed2 = 0, 0
+ @property
+ def input_size(self):
+ return self._input_size
+
+ @property
+ def num_units(self):
+ return self._num_units
+
+ @property
+ def num_layers(self):
+ return self._num_layers
+
+ @property
+ def rnn_mode(self):
+ return self._rnn_mode
+
+ @property
+ def direction(self):
+ return self._direction
+
def params_size(self):
"""Calculates the size of the opaque parameter buffer needed for this model.
@@ -222,9 +414,12 @@ class _CudnnRNN(object):
"""Runs the forward step for the RNN model.
Args:
- input_data: the input sequence to the RNN model.
- input_h: the initial hidden state for h.
+ input_data: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
input_c: the initial hidden state for c. This is only relevant for LSTM.
+ A Tensor of the same shape as input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
@@ -308,7 +503,7 @@ class CudnnLSTM(_CudnnRNN):
num_layers,
num_units,
input_size,
- input_mode="auto_select",
+ input_mode="linear_input",
direction="unidirectional",
dropout=0.,
seed=0):
@@ -344,9 +539,12 @@ class CudnnLSTM(_CudnnRNN):
"""Runs the forward step for the Cudnn LSTM model.
Args:
- input_data: the input sequence to the LSTM model.
- input_h: the initial hidden state for h.
- input_c: the initial hidden state for c.
+ input_data: the input sequence to the LSTM model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
+ input_c: the initial hidden state for c. A Tensor of the same shape as
+ input_h.
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
@@ -368,7 +566,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
num_layers,
num_units,
input_size,
- input_mode="auto_select",
+ input_mode="linear_input",
direction="unidirectional",
dropout=0.,
seed=0):
@@ -390,6 +588,7 @@ class _CudnnRNNNoInputC(_CudnnRNN):
dropout: whether to enable dropout. When it is 0, dropout is disabled.
seed: the seed used for initializing dropout.
"""
+
super(_CudnnRNNNoInputC, self).__init__(
self._rnn_mode,
num_layers,
@@ -404,8 +603,10 @@ class _CudnnRNNNoInputC(_CudnnRNN):
"""Runs the forward step for the Cudnn LSTM model.
Args:
- input_data: the input sequence to the LSTM model.
- input_h: the initial hidden state for h.
+ input_data: the input sequence to the RNN model. A Tensor of shape [?,
+ batch_size, input_size].
+ input_h: the initial hidden state for h. A Tensor of shape [num_layers,
+ batch_size, num_units].
params: the parameter buffer created for this model.
is_training: whether this operation will be used in training or inference.
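The per-layer stitching that _transform_lstm_canonical performs for the unidirectional LSTM case can be illustrated in plain NumPy. This is a sketch of the shape bookkeeping only (random values, first-layer shapes assumed), not the library code itself:

```python
# NumPy illustration of the weight/bias stitching described above (sketch only).
import numpy as np

num_units, input_size = 4, 5  # first layer; deeper layers use num_units twice

# cuDNN canonical form: 8 weight and 8 bias tensors per layer in i, f, c, o
# order; tensors 0-3 act on the layer input, tensors 4-7 on the recurrent input.
w_in = [np.random.rand(num_units, input_size) for _ in range(4)]
w_rec = [np.random.rand(num_units, num_units) for _ in range(4)]
b_in = [np.random.rand(num_units) for _ in range(4)]
b_rec = [np.random.rand(num_units) for _ in range(4)]

def switch_inner(gates):
  # Reorder from cuDNN's i, f, c, o to the canonical cells' i, c, f, o.
  gates[1], gates[2] = gates[2], gates[1]

# Stitch input and recurrent weights per gate, reorder gates, then fuse into
# the single kernel expected by LSTMCell / LSTMBlockCell.
stitched = [np.concatenate([wi, wr], axis=1) for wi, wr in zip(w_in, w_rec)]
switch_inner(stitched)
kernel = np.concatenate(stitched, axis=0).T
assert kernel.shape == (input_size + num_units, 4 * num_units)

# Fuse biases: the canonical cell uses one bias, the sum of cuDNN's input and
# recurrent bias halves; both halves are also saved so the cuDNN buffer can be
# reconstructed on restore.
switch_inner(b_in)
switch_inner(b_rec)
bias = np.concatenate(b_in) + np.concatenate(b_rec)
assert bias.shape == (4 * num_units,)
```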
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index c41b5793fc..97b9dcc905 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -58,7 +58,7 @@ def _lstm_block_cell(x,
```python
xh = [x, h_prev]
- [i, f, ci, o] = xh * w + b
+ [i, ci, f, o] = xh * w + b
f = f + forget_bias
if not use_peephole:
@@ -93,7 +93,7 @@ def _lstm_block_cell(x,
The weight matrix for output gate peephole connection.
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
cell_clip: An optional `float`. Defaults to `3`.
- Value to clip the 'cs' value to.
+      Value to clip the 'cs' value to. Disable by setting it to a negative
+      value.
use_peephole: An optional `bool`. Defaults to `False`.
Whether to use peephole weights.
name: A name for the operation (optional).
@@ -341,17 +341,24 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
def __init__(self,
num_units,
forget_bias=1.0,
+ clip_cell=True,
use_peephole=False):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
+ clip_cell: boolean, whether to apply cell clipping. See
+ `_lstm_block_cell()` for details.
use_peephole: Whether to use peephole connections or not.
+
+    When restoring from CudnnLSTM-trained checkpoints, the following must be
+    set:
+ forget_bias, clip_cell, use_peephole = 0, False, False
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._use_peephole = use_peephole
+ self._clip_cell = clip_cell
self._names = {
"W": "kernel",
"b": "bias",
@@ -400,6 +407,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
wco=wco,
wcf=wcf,
forget_bias=self._forget_bias,
+ cell_clip=None if self._clip_cell else -1,
use_peephole=self._use_peephole)
new_state = rnn_cell_impl.LSTMStateTuple(cs, h)
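To make the new clip_cell flag concrete, a short sketch of the two configurations (per the docstring above, cell_clip=None keeps the op's default clip of 3, while a negative value disables clipping); this assumes the contrib lstm_ops API at this revision:

```python
# Sketch only: assumes the contrib lstm_ops API at this revision.
from tensorflow.contrib.rnn.python.ops import lstm_ops

# Default behaviour: the cell state is clipped (cell_clip defaults to 3).
default_cell = lstm_ops.LSTMBlockCell(num_units=4)

# cuDNN-compatible behaviour for restoring CudnnLSTM-trained checkpoints:
# no extra forget bias and no cell clipping (clip_cell=False maps to
# cell_clip=-1 in the underlying op).
cudnn_compatible_cell = lstm_ops.LSTMBlockCell(
    num_units=4, forget_bias=0.0, clip_cell=False)
```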
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 49a4aba473..ca69cddae2 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -345,6 +345,8 @@ class BasicLSTMCell(RNNCell):
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
+      Must be set to `0.0` manually when restoring from CudnnLSTM-trained
+ checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
@@ -444,7 +446,8 @@ class LSTMCell(RNNCell):
Use a variable_scope partitioner instead.
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
- the training.
+      the training. It must be set to `0.0` manually when restoring from
+      CudnnLSTM-trained checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.