path: root/tensorflow/python/ops/rnn.py
author    Eugene Brevdo <ebrevdo@google.com>          2016-11-23 12:27:29 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-11-23 12:45:10 -0800
commit 92da8abfd35b93488ed7a55308b8f589ee23b622 (patch)
tree   481798091b54e269a6ac7fb7a2c185d34a512a60 /tensorflow/python/ops/rnn.py
parent 75254c37f8289dd70739fe8c8c5826782b086563 (diff)
Cleanup and consistency for variable handling in RNNCells.
In the run-up to TF 1.0, we are making RNNCells' variable names compatible with those of tf layers. This is a breaking change for those who wish to reload their old RNN model checkpoints in newly created graphs. After this change is in, variables created with RNNCells will have slightly different names than before; loading old checkpoints to run with newly created graphs requires renaming at load time.

Loading and executing old graphs with old checkpoints will continue to work without any problems. Creating and loading new checkpoints with graphs built after this change is in will also work without any problems. The only people affected by this change are those who want to load old RNN model checkpoints into graphs created after this change is in.

Renaming on checkpoint load can be performed with tf.contrib.framework.variables.assign_from_checkpoint. Example usage is available here [1] if you use Saver and/or Supervisor, and [2] if you are using the newer tf.learn classes.

Examples of renamed parameters:

LSTMCell without sharding:
  my_scope/LSTMCell/W_0      -> my_scope/lstm_cell/weights
  my_scope/LSTMCell/W_F_diag -> my_scope/lstm_cell/w_f_diag
  my_scope/LSTMCell/B        -> my_scope/lstm_cell/biases

LSTMCell with sharding:
  my_scope/LSTMCell/W_0      -> my_scope/lstm_cell/weights/part_0
  my_scope/LSTMCell/W_1      -> my_scope/lstm_cell/weights/part_1
  my_scope/LSTMCell/W_2      -> my_scope/lstm_cell/weights/part_2
  my_scope/LSTMCell/W_F_diag -> my_scope/lstm_cell/w_f_diag
  my_scope/LSTMCell/B        -> my_scope/lstm_cell/biases

BasicLSTMCell:
  my_scope/BasicLSTMCell/Linear/Matrix -> my_scope/basic_lstm_cell/weights
  my_scope/BasicLSTMCell/Linear/Bias   -> my_scope/basic_lstm_cell/biases

MultiRNNCell:
  my_scope/MultiRNNCell/Cell0/LSTMCell/W_0      -> my_scope/multi_rnn_cell/cell_0/lstm_cell/weights
  my_scope/MultiRNNCell/Cell0/LSTMCell/W_F_diag -> my_scope/multi_rnn_cell/cell_0/lstm_cell/w_f_diag
  my_scope/MultiRNNCell/Cell0/LSTMCell/B        -> my_scope/multi_rnn_cell/cell_0/lstm_cell/biases

1. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/README.md
2. https://github.com/tensorflow/tensorflow/blob/86f5ab7474825da756838b34e1b4eac93f5fc68a/tensorflow/contrib/framework/python/ops/variables_test.py#L810

Change: 140060366
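As a rough illustration of the load-time renaming described above, the sketch below restores an old-style checkpoint into a graph built after this change by keying a tf.train.Saver on the old checkpoint names. The OLD_TO_NEW mapping entries and the checkpoint path are hypothetical; the exact names depend on the cells and scopes in your model, and module/API details vary slightly across TF releases of this era.

# Hypothetical sketch: restore an old-style RNN checkpoint into a newly
# built graph by mapping old checkpoint variable names to the renamed
# variables. Assumes TF 1.x-style APIs.
import tensorflow as tf

# Illustrative mapping from old checkpoint names to new in-graph names;
# adjust to match the variables actually present in your checkpoint.
OLD_TO_NEW = {
    "my_scope/LSTMCell/W_0": "my_scope/lstm_cell/weights",
    "my_scope/LSTMCell/W_F_diag": "my_scope/lstm_cell/w_f_diag",
    "my_scope/LSTMCell/B": "my_scope/lstm_cell/biases",
}

def build_restore_saver():
  """Returns a Saver whose var_list keys are the *old* checkpoint names."""
  new_name_to_var = {v.op.name: v for v in tf.global_variables()}
  var_list = {}
  for old_name, new_name in OLD_TO_NEW.items():
    if new_name in new_name_to_var:
      # Saver accepts a dict mapping names-in-checkpoint to variables, so
      # the old names are read from the checkpoint while the values land
      # in the renamed variables of the new graph.
      var_list[old_name] = new_name_to_var[new_name]
  return tf.train.Saver(var_list=var_list)

# Usage (assumes the new graph has already been constructed):
# saver = build_restore_saver()
# with tf.Session() as sess:
#   saver.restore(sess, "/path/to/old_checkpoint")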
Diffstat (limited to 'tensorflow/python/ops/rnn.py')
-rw-r--r--  tensorflow/python/ops/rnn.py | 32
1 file changed, 17 insertions(+), 15 deletions(-)
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index f67f4f35e8..b1270a1937 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -113,7 +113,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
dtype.
sequence_length: Specifies the length of each sequence in inputs.
An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
@@ -139,7 +139,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "RNN") as varscope:
+ with vs.variable_scope(scope or "rnn") as varscope:
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -246,7 +246,7 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
be a single string.
sequence_length: (optional) An int32/int64 vector size [batch_size].
See the documentation for rnn() for more details about sequence_length.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
@@ -508,7 +508,8 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
either of the initial states are not provided.
sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
containing the actual lengths for each of the sequences.
- scope: VariableScope for the created subgraph; defaults to "BiRNN"
+ scope: VariableScope for the created subgraph; defaults to
+ "bidirectional_rnn"
Returns:
A tuple (outputs, output_state_fw, output_state_bw) where:
@@ -531,14 +532,14 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
if not inputs:
raise ValueError("inputs must not be empty")
- with vs.variable_scope(scope or "BiRNN"):
+ with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
- with vs.variable_scope("FW") as fw_scope:
+ with vs.variable_scope("fw") as fw_scope:
output_fw, output_state_fw = rnn(cell_fw, inputs, initial_state_fw, dtype,
sequence_length, scope=fw_scope)
# Backward direction
- with vs.variable_scope("BW") as bw_scope:
+ with vs.variable_scope("bw") as bw_scope:
reversed_inputs = _reverse_seq(inputs, sequence_length)
tmp, output_state_bw = rnn(cell_bw, reversed_inputs, initial_state_bw,
dtype, sequence_length, scope=bw_scope)
@@ -610,7 +611,8 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
accepts input and emits output in batch-major form.
dtype: (optional) The data type for the initial state. Required if
either of the initial states are not provided.
- scope: VariableScope for the created subgraph; defaults to "BiRNN"
+ scope: VariableScope for the created subgraph; defaults to
+ "bidirectional_rnn"
Returns:
A tuple (outputs, output_states) where:
@@ -642,9 +644,9 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
if not isinstance(cell_bw, rnn_cell.RNNCell):
raise TypeError("cell_bw must be an instance of RNNCell")
- with vs.variable_scope(scope or "BiRNN"):
+ with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
- with vs.variable_scope("FW") as fw_scope:
+ with vs.variable_scope("fw") as fw_scope:
output_fw, output_state_fw = dynamic_rnn(
cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
initial_state=initial_state_fw, dtype=dtype,
@@ -659,7 +661,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
time_dim = 0
batch_dim = 1
- with vs.variable_scope("BW") as bw_scope:
+ with vs.variable_scope("bw") as bw_scope:
inputs_reverse = array_ops.reverse_sequence(
input=inputs, seq_lengths=sequence_length,
seq_dim=time_dim, batch_dim=batch_dim)
@@ -746,7 +748,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
transposes at the beginning and end of the RNN calculation. However,
most TensorFlow data is batch-major, so by default this function
accepts input and emits output in batch-major form.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
@@ -801,7 +803,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "RNN") as varscope:
+ with vs.variable_scope(scope or "rnn") as varscope:
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
input_shape = tuple(array_ops.shape(input_) for input_ in flat_input)
@@ -1161,7 +1163,7 @@ def raw_rnn(cell, loop_fn,
but needed for back prop from GPU to CPU. This allows training RNNs
which would typically not fit on a single GPU, with very minimal (or no)
performance penalty.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A tuple `(emit_ta, final_state, final_loop_state)` where:
@@ -1201,7 +1203,7 @@ def raw_rnn(cell, loop_fn,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "RNN") as varscope:
+ with vs.variable_scope(scope or "rnn") as varscope:
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
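To see the effect of the new lower-case default scopes introduced by this diff, the hedged sketch below builds a small dynamic RNN without an explicit scope and prints the resulting variable names, which should fall under "rnn/basic_lstm_cell/..." rather than the old "RNN/BasicLSTMCell/..." prefix. Module paths (tf.nn.rnn_cell, tf.nn.dynamic_rnn) are assumed TF 1.x-style and moved around in some releases of this period.

# Hypothetical sketch of the new default scope names after this change.
import tensorflow as tf

batch_size, max_time, depth, num_units = 4, 10, 8, 16
inputs = tf.placeholder(tf.float32, [batch_size, max_time, depth])

cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
# No explicit scope: dynamic_rnn now defaults to "rnn" instead of "RNN".
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

for v in tf.global_variables():
  print(v.op.name)
# Expected, per the renaming in the commit message, something like:
#   rnn/basic_lstm_cell/weights
#   rnn/basic_lstm_cell/biases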