author     Eugene Brevdo <ebrevdo@google.com>              2016-11-23 12:27:29 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org> 2016-11-23 12:45:10 -0800
commit     92da8abfd35b93488ed7a55308b8f589ee23b622 (patch)
tree       481798091b54e269a6ac7fb7a2c185d34a512a60
parent     75254c37f8289dd70739fe8c8c5826782b086563 (diff)
Cleanup and consistency for variable handling in RNNCells.
In the run-up to TF 1.0, we are making RNNCells' variable names compatible with
those of tf layers.

This is a breaking change for those who wish to reload their old RNN model
checkpoints in newly created graphs. After this change is in, variables created
with RNNCells will have slightly different names than before; loading old
checkpoints to run with newly created graphs requires renaming at load time.

Loading and executing old graphs with old checkpoints will continue to work
without any problems. Creating and loading new checkpoints with graphs built
after this change is in will also work without any problems. The only people
affected by this change are those who want to load old RNN model checkpoints
into graphs created after this change is in.

Renaming on checkpoint load can be performed with
tf.contrib.framework.variables.assign_from_checkpoint (a short sketch follows
below). Example usage is available here [1] if you use Saver and/or Supervisor,
and here [2] if you are using the newer tf.learn classes.

Examples of renamed parameters:

LSTMCell without sharding:
  my_scope/LSTMCell/W_0       -> my_scope/lstm_cell/weights
  my_scope/LSTMCell/W_F_diag  -> my_scope/lstm_cell/w_f_diag
  my_scope/LSTMCell/B         -> my_scope/lstm_cell/biases

LSTMCell with sharding:
  my_scope/LSTMCell/W_0       -> my_scope/lstm_cell/weights/part_0
  my_scope/LSTMCell/W_1       -> my_scope/lstm_cell/weights/part_1
  my_scope/LSTMCell/W_2       -> my_scope/lstm_cell/weights/part_2
  my_scope/LSTMCell/W_F_diag  -> my_scope/lstm_cell/w_f_diag
  my_scope/LSTMCell/B         -> my_scope/lstm_cell/biases

BasicLSTMCell:
  my_scope/BasicLSTMCell/Linear/Matrix  -> my_scope/basic_lstm_cell/weights
  my_scope/BasicLSTMCell/Linear/Bias    -> my_scope/basic_lstm_cell/biases

MultiRNNCell:
  my_scope/MultiRNNCell/Cell0/LSTMCell/W_0       -> my_scope/multi_rnn_cell/cell_0/lstm_cell/weights
  my_scope/MultiRNNCell/Cell0/LSTMCell/W_F_diag  -> my_scope/multi_rnn_cell/cell_0/lstm_cell/w_f_diag
  my_scope/MultiRNNCell/Cell0/LSTMCell/B         -> my_scope/multi_rnn_cell/cell_0/lstm_cell/biases

1. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/README.md
2. https://github.com/tensorflow/tensorflow/blob/86f5ab7474825da756838b34e1b4eac93f5fc68a/tensorflow/contrib/framework/python/ops/variables_test.py#L810

Change: 140060366
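The sketch below (not part of this change) illustrates one way to do the
rename at load time with tf.contrib.framework.assign_from_checkpoint. It
assumes a graph already built with the post-change names and an old checkpoint
written with the pre-change names; the checkpoint path and the two-entry name
map are placeholders to adapt to your own model.

  import tensorflow as tf

  # Build the model with the new (post-change) variable names, e.g. an LSTM
  # under "my_scope". Model construction is elided here.
  # ...

  # Map old checkpoint names to the variables that now exist under new names.
  new_vars = {v.op.name: v for v in tf.global_variables()}
  name_map = {
      "my_scope/LSTMCell/W_0": new_vars["my_scope/lstm_cell/weights"],
      "my_scope/LSTMCell/B": new_vars["my_scope/lstm_cell/biases"],
  }

  # assign_from_checkpoint accepts a dict mapping checkpoint names to graph
  # variables and returns an assign op plus a feed dict of the loaded values.
  assign_op, feed_dict = tf.contrib.framework.assign_from_checkpoint(
      "/tmp/old_model.ckpt", name_map)  # placeholder checkpoint path

  with tf.Session() as sess:
      sess.run(assign_op, feed_dict)
      # Re-save from here to produce a checkpoint with the new names.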
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py |   8
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py      |   2
-rw-r--r--  tensorflow/contrib/rnn/python/ops/lstm_ops.py               |  54
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn.py                    |  11
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py               |  99
-rw-r--r--  tensorflow/python/kernel_tests/rnn_cell_test.py             |  37
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py                  | 103
-rw-r--r--  tensorflow/python/ops/rnn.py                                |  32
-rw-r--r--  tensorflow/python/ops/rnn_cell.py                           | 182
9 files changed, 255 insertions(+), 273 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
index 164ad822b8..77da735284 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py
@@ -81,9 +81,9 @@ class LSTMBlockCellTest(tf.test.TestCase):
basic_names = {v.name: v.get_shape() for v in tf.trainable_variables()}
with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()):
- cell = tf.contrib.rnn.LSTMBlockCell(10, use_compatible_names=True)
+ cell = tf.contrib.rnn.LSTMBlockCell(10)
pcell = tf.contrib.rnn.LSTMBlockCell(
- 10, use_peephole=True, use_compatible_names=True)
+ 10, use_peephole=True)
inputs = [tf.zeros([4, 5])] * 6
tf.nn.rnn(cell, inputs, dtype=tf.float32, scope="basic")
tf.nn.rnn(pcell, inputs, dtype=tf.float32, scope="peephole")
@@ -93,8 +93,8 @@ class LSTMBlockCellTest(tf.test.TestCase):
cell = tf.contrib.rnn.LSTMBlockFusedCell(10)
pcell = tf.contrib.rnn.LSTMBlockFusedCell(10, use_peephole=True)
inputs = [tf.zeros([4, 5])] * 6
- cell(inputs, dtype=tf.float32, scope="basic/LSTMCell")
- pcell(inputs, dtype=tf.float32, scope="peephole/LSTMCell")
+ cell(inputs, dtype=tf.float32, scope="basic/lstm_cell")
+ pcell(inputs, dtype=tf.float32, scope="peephole/lstm_cell")
fused_names = {v.name: v.get_shape() for v in tf.trainable_variables()}
self.assertEqual(basic_names, block_names)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
index 89890c18b5..91b3d7f417 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py
@@ -383,7 +383,7 @@ class StackBidirectionalRNNTest(tf.test.TestCase):
# check that all the variables names starts with the proper scope.
tf.global_variables_initializer()
all_vars = tf.all_variables()
- prefix = prefix or "StackRNN"
+ prefix = prefix or "stack_bidirectional_rnn"
scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
tf.logging.info("StackRNN with scope: %s (%s)"
% (prefix, "scope" if use_outer_scope else "str"))
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 3e8998f117..3e30f24310 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -334,44 +334,31 @@ class LSTMBlockCell(rnn_cell.RNNCell):
Unlike `rnn_cell.LSTMCell`, this is a monolithic op and should be much faster.
The weight and bias matrixes should be compatible as long as the variable
- scope matches, and you use `use_compatible_names=True`.
+ scope matches.
"""
def __init__(self,
num_units,
forget_bias=1.0,
- use_peephole=False,
- use_compatible_names=False):
+ use_peephole=False):
"""Initialize the basic LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
use_peephole: Whether to use peephole connections or not.
- use_compatible_names: If True, use the same variable naming as
- rnn_cell.LSTMCell
"""
self._num_units = num_units
self._forget_bias = forget_bias
self._use_peephole = use_peephole
- if use_compatible_names:
- self._names = {
- "W": "W_0",
- "b": "B",
- "wci": "W_I_diag",
- "wco": "W_O_diag",
- "wcf": "W_F_diag",
- "scope": "LSTMCell"
- }
- else:
- self._names = {
- "W": "W",
- "b": "b",
- "wci": "wci",
- "wco": "wco",
- "wcf": "wcf",
- "scope": "LSTMBlockCell"
- }
+ self._names = {
+ "W": "weights",
+ "b": "biases",
+ "wci": "w_i_diag",
+ "wco": "w_o_diag",
+ "wcf": "w_f_diag",
+ "scope": "lstm_cell"
+ }
@property
def state_size(self):
@@ -385,15 +372,15 @@ class LSTMBlockCell(rnn_cell.RNNCell):
"""Long short-term memory cell (LSTM)."""
with vs.variable_scope(scope or self._names["scope"]):
x_shape = x.get_shape().with_rank(2)
- if not x_shape[1]:
- raise ValueError("Expecting x_shape[1] to be sets: %s" % str(x_shape))
+ if not x_shape[1].value:
+ raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape))
if len(states_prev) != 2:
raise ValueError("Expecting states_prev to be a tuple with length 2.")
- input_size = x_shape[1]
+ input_size = x_shape[1].value
w = vs.get_variable(self._names["W"], [input_size + self._num_units,
self._num_units * 4])
b = vs.get_variable(
- self._names["b"], [w.get_shape().with_rank(2)[1]],
+ self._names["b"], [w.get_shape().with_rank(2)[1].value],
initializer=init_ops.constant_initializer(0.0))
if self._use_peephole:
wci = vs.get_variable(self._names["wci"], [self._num_units])
@@ -490,7 +477,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
Raises:
ValueError: in case of shape mismatches
"""
- with vs.variable_scope(scope or type(self).__name__):
+ with vs.variable_scope(scope or "lstm_block_wrapper"):
is_list = isinstance(inputs, list)
if is_list:
inputs = array_ops.pack(inputs)
@@ -634,15 +621,16 @@ class LSTMBlockFusedCell(LSTMBlockWrapper):
time_len = array_ops.shape(inputs)[0]
input_size = inputs_shape[2].value
w = vs.get_variable(
- "W_0", [input_size + self._num_units, self._num_units * 4], dtype=dtype)
+ "weights",
+ [input_size + self._num_units, self._num_units * 4], dtype=dtype)
b = vs.get_variable(
- "B", [w.get_shape().with_rank(2)[1]],
+ "biases", [w.get_shape().with_rank(2)[1]],
initializer=init_ops.constant_initializer(0.0),
dtype=dtype)
if self._use_peephole:
- wci = vs.get_variable("W_I_diag", [self._num_units], dtype=dtype)
- wco = vs.get_variable("W_O_diag", [self._num_units], dtype=dtype)
- wcf = vs.get_variable("W_F_diag", [self._num_units], dtype=dtype)
+ wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype)
+ wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype)
+ wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype)
else:
wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype)
diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py
index d4df308042..aa9dd98fee 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn.py
@@ -95,7 +95,7 @@ def stack_bidirectional_rnn(cells_fw,
states_bw = []
prev_layer = inputs
- with vs.variable_scope(scope or "StackRNN"):
+ with vs.variable_scope(scope or "stack_bidirectional_rnn"):
for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
initial_state_fw = None
initial_state_bw = None
@@ -104,7 +104,7 @@ def stack_bidirectional_rnn(cells_fw,
if initial_states_bw:
initial_state_bw = initial_states_bw[i]
- with vs.variable_scope("Layer%d" % i):
+ with vs.variable_scope("cell_%d" % i) as cell_scope:
prev_layer, state_fw, state_bw = tf.nn.bidirectional_rnn(
cell_fw,
cell_bw,
@@ -112,7 +112,8 @@ def stack_bidirectional_rnn(cells_fw,
initial_state_fw=initial_state_fw,
initial_state_bw=initial_state_bw,
sequence_length=sequence_length,
- dtype=dtype)
+ dtype=dtype,
+ scope=cell_scope)
states_fw.append(state_fw)
states_bw.append(state_bw)
@@ -192,7 +193,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw,
states_bw = []
prev_layer = inputs
- with vs.variable_scope(scope or "StackRNN"):
+ with vs.variable_scope(scope or "stack_bidirectional_rnn"):
for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
initial_state_fw = None
initial_state_bw = None
@@ -201,7 +202,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw,
if initial_states_bw:
initial_state_bw = initial_states_bw[i]
- with vs.variable_scope("Layer%d" % i):
+ with vs.variable_scope("cell_%d" % i):
outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
cell_fw,
cell_bw,
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index c1c25ba094..9890e712c1 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -200,8 +200,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell):
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(scope or type(self).__name__,
- initializer=self._initializer): # "LSTMCell"
+ with vs.variable_scope(scope or "coupled_input_forget_gate_lstm_cell",
+ initializer=self._initializer):
concat_w = _get_concat_variable(
"W", [input_size.value + num_proj, 3 * self._num_units],
dtype, self._num_unit_shards)
@@ -328,7 +328,7 @@ class TimeFreqLSTMCell(rnn_cell.RNNCell):
freq_inputs = self._make_tf_features(inputs)
dtype = inputs.dtype
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
- with vs.variable_scope(scope or type(self).__name__,
+ with vs.variable_scope(scope or "time_freq_lstm_cell",
initializer=self._initializer): # "TimeFreqLSTMCell"
concat_w = _get_concat_variable(
"W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
@@ -546,7 +546,7 @@ class GridLSTMCell(rnn_cell.RNNCell):
"""
batch_size = int(inputs.get_shape()[0])
freq_inputs = self._make_tf_features(inputs)
- with vs.variable_scope(scope or type(self).__name__,
+ with vs.variable_scope(scope or "grid_lstm_cell",
initializer=self._initializer): # "GridLSTMCell"
m_out_lst = []
state_out_lst = []
@@ -968,29 +968,29 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
bwd_inputs = fwd_inputs
# Forward processing
- with vs.variable_scope((scope or type(self).__name__) + "/fwd",
- initializer=self._initializer):
- fwd_m_out_lst = []
- fwd_state_out_lst = []
- for block in range(len(fwd_inputs)):
- fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
- fwd_inputs[block], block, state, batch_size,
- state_prefix="fwd_state", state_is_tuple=True)
- fwd_m_out_lst.extend(fwd_m_out_lst_current)
- fwd_state_out_lst.extend(fwd_state_out_lst_current)
- # Backward processing
- bwd_m_out_lst = []
- bwd_state_out_lst = []
- with vs.variable_scope((scope or type(self).__name__) + "/bwd",
+ with vs.variable_scope(scope or "bidirectional_grid_lstm_cell",
initializer=self._initializer):
- for block in range(len(bwd_inputs)):
- # Reverse the blocks
- bwd_inputs_reverse = bwd_inputs[block][::-1]
- bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
- bwd_inputs_reverse, block, state, batch_size,
- state_prefix="bwd_state", state_is_tuple=True)
- bwd_m_out_lst.extend(bwd_m_out_lst_current)
- bwd_state_out_lst.extend(bwd_state_out_lst_current)
+ with vs.variable_scope("fwd"):
+ fwd_m_out_lst = []
+ fwd_state_out_lst = []
+ for block in range(len(fwd_inputs)):
+ fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
+ fwd_inputs[block], block, state, batch_size,
+ state_prefix="fwd_state", state_is_tuple=True)
+ fwd_m_out_lst.extend(fwd_m_out_lst_current)
+ fwd_state_out_lst.extend(fwd_state_out_lst_current)
+ # Backward processing
+ bwd_m_out_lst = []
+ bwd_state_out_lst = []
+ with vs.variable_scope("bwd"):
+ for block in range(len(bwd_inputs)):
+ # Reverse the blocks
+ bwd_inputs_reverse = bwd_inputs[block][::-1]
+ bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
+ bwd_inputs_reverse, block, state, batch_size,
+ state_prefix="bwd_state", state_is_tuple=True)
+ bwd_m_out_lst.extend(bwd_m_out_lst_current)
+ bwd_state_out_lst.extend(bwd_state_out_lst_current)
state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
# Outputs are always concated as it is never used separately.
m_out = array_ops.concat(1, fwd_m_out_lst + bwd_m_out_lst)
@@ -1071,7 +1071,7 @@ class AttentionCellWrapper(rnn_cell.RNNCell):
def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell with attention (LSTMA)."""
- with vs.variable_scope(scope or type(self).__name__):
+ with vs.variable_scope(scope or "attention_cell_wrapper"):
if self._state_is_tuple:
state, attns, attn_states = state
else:
@@ -1094,7 +1094,7 @@ class AttentionCellWrapper(rnn_cell.RNNCell):
else:
new_state_cat = new_state
new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
- with vs.variable_scope("AttnOutputProjection"):
+ with vs.variable_scope("attn_output_projection"):
output = _linear([lstm_output, new_attns], self._attn_size, True)
new_attn_states = array_ops.concat(1, [new_attn_states,
array_ops.expand_dims(output, 1)])
@@ -1111,9 +1111,10 @@ class AttentionCellWrapper(rnn_cell.RNNCell):
softmax = nn_ops.softmax
tanh = math_ops.tanh
- with vs.variable_scope("Attention"):
- k = vs.get_variable("AttnW", [1, 1, self._attn_size, self._attn_vec_size])
- v = vs.get_variable("AttnV", [self._attn_vec_size])
+ with vs.variable_scope("attention"):
+ k = vs.get_variable(
+ "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
+ v = vs.get_variable("attn_v", [self._attn_vec_size])
hidden = array_ops.reshape(attn_states,
[-1, self._attn_length, 1, self._attn_size])
hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
@@ -1191,30 +1192,30 @@ class LayerNormBasicLSTMCell(rnn_cell.RNNCell):
return self._num_units
def _norm(self, inp, scope):
- with vs.variable_scope(scope) as scope:
- shape = inp.get_shape()[-1:]
- gamma_init = init_ops.constant_initializer(self._g)
- beta_init = init_ops.constant_initializer(self._b)
- gamma = vs.get_variable("gamma", shape=shape, initializer=gamma_init) # pylint: disable=unused-variable
- beta = vs.get_variable("beta", shape=shape, initializer=beta_init) # pylint: disable=unused-variable
- normalized = layers.layer_norm(inp, reuse=True, scope=scope)
- return normalized
-
- def _linear(self, args, scope="linear"):
+ shape = inp.get_shape()[-1:]
+ gamma_init = init_ops.constant_initializer(self._g)
+ beta_init = init_ops.constant_initializer(self._b)
+ with vs.variable_scope(scope):
+ # Initialize beta and gamma for use by layer_norm.
+ vs.get_variable("gamma", shape=shape, initializer=gamma_init)
+ vs.get_variable("beta", shape=shape, initializer=beta_init)
+ normalized = layers.layer_norm(inp, reuse=True, scope=scope)
+ return normalized
+
+ def _linear(self, args):
out_size = 4 * self._num_units
proj_size = args.get_shape()[-1]
- with vs.variable_scope(scope) as scope:
- weights = vs.get_variable("weights", [proj_size, out_size])
- out = math_ops.matmul(args, weights)
- if not self._layer_norm:
- bias = vs.get_variable("b", [out_size])
- out += bias
- return out
+ weights = vs.get_variable("weights", [proj_size, out_size])
+ out = math_ops.matmul(args, weights)
+ if not self._layer_norm:
+ bias = vs.get_variable("biases", [out_size])
+ out = nn_ops.bias_add(out, bias)
+ return out
def __call__(self, inputs, state, scope=None):
"""LSTM cell with layer normalization and recurrent dropout."""
- with vs.variable_scope(scope or type(self).__name__) as scope: # LayerNormBasicLSTMCell # pylint: disable=unused-variables
+ with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"):
c, h = state
args = array_ops.concat(1, [inputs, h])
concat = self._linear(args)
diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py
index 69eeb116a6..e4e239169a 100644
--- a/tensorflow/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/python/kernel_tests/rnn_cell_test.py
@@ -95,6 +95,20 @@ class RNNCellTest(tf.test.TestCase):
res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]),
m.name: 0.1 * np.ones([1, 8])})
self.assertEqual(len(res), 2)
+ variables = tf.global_variables()
+ self.assertEqual(4, len(variables))
+ self.assertEquals(
+ variables[0].op.name,
+ "root/multi_rnn_cell/cell_0/basic_lstm_cell/weights")
+ self.assertEquals(
+ variables[1].op.name,
+ "root/multi_rnn_cell/cell_0/basic_lstm_cell/biases")
+ self.assertEquals(
+ variables[2].op.name,
+ "root/multi_rnn_cell/cell_1/basic_lstm_cell/weights")
+ self.assertEquals(
+ variables[3].op.name,
+ "root/multi_rnn_cell/cell_1/basic_lstm_cell/biases")
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
expected_mem = np.array([[0.68967271, 0.68967271,
@@ -204,6 +218,26 @@ class RNNCellTest(tf.test.TestCase):
self.assertTrue(
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
+ def testLSTMCellVariables(self):
+ with self.test_session():
+ num_units = 8
+ num_proj = 6
+ state_size = num_units + num_proj
+ batch_size = 3
+ input_size = 2
+ with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
+ x = tf.zeros([batch_size, input_size])
+ m = tf.zeros([batch_size, state_size])
+ cell = tf.nn.rnn_cell.LSTMCell(
+ num_units=num_units, num_proj=num_proj, forget_bias=1.0,
+ state_is_tuple=False)
+ cell(x, m) # Execute to create variables
+ variables = tf.global_variables()
+ self.assertEquals(variables[0].op.name, "root/lstm_cell/weights")
+ self.assertEquals(variables[1].op.name, "root/lstm_cell/biases")
+ self.assertEquals(
+ variables[2].op.name, "root/lstm_cell/projection/weights")
+
def testOutputProjectionWrapper(self):
with self.test_session() as sess:
with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)):
@@ -354,6 +388,7 @@ class SlimRNNCellTest(tf.test.TestCase):
# pylint: enable=protected-access
slim_outputs, slim_state = slim_cell(inputs, initial_state)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units)
+ tf.get_variable_scope().reuse_variables()
outputs, state = rnn_cell(inputs, initial_state)
self.assertEqual(slim_outputs.get_shape(), outputs.get_shape())
self.assertEqual(slim_state.get_shape(), state.get_shape())
@@ -377,7 +412,7 @@ def basic_rnn_cell(inputs, state, num_units, scope=None):
init_state.set_shape([batch_size, num_units])
return init_output, init_state
else:
- with tf.variable_scope(scope, "BasicRNNCell", [inputs, state]):
+ with tf.variable_scope(scope, "basic_rnn_cell", [inputs, state]):
output = tf.tanh(linear([inputs, state],
num_units, True))
return output, output
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 5a74adad76..d3897afb92 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -260,8 +260,8 @@ class RNNTest(tf.test.TestCase):
# check that all the variables names starts
# with the proper scope.
tf.global_variables_initializer()
- all_vars = tf.all_variables()
- prefix = prefix or "RNN"
+ all_vars = tf.global_variables()
+ prefix = prefix or "rnn"
scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
tf.logging.info("RNN with scope: %s (%s)"
% (prefix, "scope" if use_outer_scope else "str"))
@@ -333,8 +333,8 @@ class GRUTest(tf.test.TestCase):
# check that all the variables names starts
# with the proper scope.
- all_vars = tf.all_variables()
- prefix = prefix or "RNN"
+ all_vars = tf.global_variables()
+ prefix = prefix or "rnn"
scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
tf.logging.info("RNN with scope: %s (%s)"
% (prefix, "scope" if use_outer_scope else "str"))
@@ -567,13 +567,14 @@ class LSTMTest(tf.test.TestCase):
cell_tuple = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer, state_is_tuple=True)
- outputs_notuple, state_notuple = tf.nn.rnn(
- cell_notuple, inputs, dtype=tf.float32,
- sequence_length=sequence_length)
- tf.get_variable_scope().reuse_variables()
- outputs_tuple, state_tuple = tf.nn.rnn(
- cell_tuple, inputs, dtype=tf.float32,
- sequence_length=sequence_length)
+ with tf.variable_scope("root") as scope:
+ outputs_notuple, state_notuple = tf.nn.rnn(
+ cell_notuple, inputs, dtype=tf.float32,
+ sequence_length=sequence_length, scope=scope)
+ scope.reuse_variables()
+ outputs_tuple, state_tuple = tf.nn.rnn(
+ cell_tuple, inputs, dtype=tf.float32,
+ sequence_length=sequence_length, scope=scope)
self.assertEqual(len(outputs_notuple), len(inputs))
self.assertEqual(len(outputs_tuple), len(inputs))
self.assertTrue(isinstance(state_tuple, tuple))
@@ -624,31 +625,6 @@ class LSTMTest(tf.test.TestCase):
input_value = np.random.randn(batch_size, input_size)
sess.run(outputs, feed_dict={inputs[0]: input_value})
- def _testTooManyShards(self, use_gpu):
- num_units = 3
- input_size = 5
- num_proj = 4
- num_proj_shards = 4
- num_unit_shards = 2
- max_length = 8
- with self.test_session(use_gpu=use_gpu, graph=tf.Graph()):
- initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
-
- inputs = max_length * [
- tf.placeholder(tf.float32, shape=(None, input_size))]
-
- cell = tf.nn.rnn_cell.LSTMCell(
- num_units,
- use_peepholes=True,
- num_proj=num_proj,
- num_unit_shards=num_unit_shards,
- num_proj_shards=num_proj_shards,
- initializer=initializer,
- state_is_tuple=False)
-
- with self.assertRaises(ValueError):
- tf.nn.rnn(cell, inputs, dtype=tf.float32)
-
def _testDoubleInput(self, use_gpu):
num_units = 3
input_size = 5
@@ -871,13 +847,14 @@ class LSTMTest(tf.test.TestCase):
cell = tf.nn.rnn_cell.LSTMCell(
num_units, use_peepholes=True,
num_proj=num_proj, initializer=initializer, state_is_tuple=True)
- outputs_static, state_static = tf.nn.rnn(
- cell, inputs, dtype=tf.float32,
- sequence_length=sequence_length)
- tf.get_variable_scope().reuse_variables()
- outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
- cell, inputs_c, dtype=tf.float32, time_major=True,
- sequence_length=sequence_length)
+ with tf.variable_scope("root") as scope:
+ outputs_static, state_static = tf.nn.rnn(
+ cell, inputs, dtype=tf.float32,
+ sequence_length=sequence_length, scope=scope)
+ scope.reuse_variables()
+ outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
+ cell, inputs_c, dtype=tf.float32, time_major=True,
+ sequence_length=sequence_length, scope=scope)
self.assertTrue(isinstance(state_static, tf.nn.rnn_cell.LSTMStateTuple))
self.assertTrue(isinstance(state_dynamic, tf.nn.rnn_cell.LSTMStateTuple))
self.assertEqual(state_static[0], state_static.c)
@@ -932,13 +909,14 @@ class LSTMTest(tf.test.TestCase):
self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0])
self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1])
- outputs_static, state_static = tf.nn.rnn(
- cell, inputs, dtype=tf.float32,
- sequence_length=sequence_length)
- tf.get_variable_scope().reuse_variables()
- outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
- cell, inputs_c, dtype=tf.float32, time_major=True,
- sequence_length=sequence_length)
+ with tf.variable_scope("root") as scope:
+ outputs_static, state_static = tf.nn.rnn(
+ cell, inputs, dtype=tf.float32,
+ sequence_length=sequence_length, scope=scope)
+ scope.reuse_variables()
+ outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn(
+ cell, inputs_c, dtype=tf.float32, time_major=True,
+ sequence_length=sequence_length, scope=scope)
tf.global_variables_initializer().run()
@@ -1126,10 +1104,6 @@ class LSTMTest(tf.test.TestCase):
self._testProjSharding(use_gpu=False)
self._testProjSharding(use_gpu=True)
- def testTooManyShards(self):
- self._testTooManyShards(use_gpu=False)
- self._testTooManyShards(use_gpu=True)
-
def testShardNoShardEquivalentOutput(self):
self._testShardNoShardEquivalentOutput(use_gpu=False)
self._testShardNoShardEquivalentOutput(use_gpu=True)
@@ -1415,8 +1389,8 @@ class BidirectionalRNNTest(tf.test.TestCase):
# check that all the variables names starts
# with the proper scope.
tf.global_variables_initializer()
- all_vars = tf.all_variables()
- prefix = prefix or "BiRNN"
+ all_vars = tf.global_variables()
+ prefix = prefix or "bidirectional_rnn"
scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
tf.logging.info("BiRNN with scope: %s (%s)"
% (prefix, "scope" if use_outer_scope else "str"))
@@ -1667,13 +1641,16 @@ class RawRNNTest(tf.test.TestCase):
lambda: inputs_ta.read(time_))
return (elements_finished, next_input, next_state, emit_output, None)
- outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn)
+ reuse_scope = tf.get_variable_scope()
+
+ outputs_ta, final_state, _ = tf.nn.raw_rnn(
+ cell, loop_fn, scope=reuse_scope)
outputs = outputs_ta.pack()
- tf.get_variable_scope().reuse_variables()
+ reuse_scope.reuse_variables()
outputs_dynamic_rnn, final_state_dynamic_rnn = tf.nn.dynamic_rnn(
cell, inputs, time_major=True, dtype=tf.float32,
- sequence_length=sequence_length)
+ sequence_length=sequence_length, scope=reuse_scope)
variables = tf.trainable_variables()
gradients = tf.gradients([outputs, final_state], [inputs] + variables)
@@ -1854,8 +1831,8 @@ class RawRNNTest(tf.test.TestCase):
# check that all the variables names starts
# with the proper scope.
- all_vars = tf.all_variables()
- prefix = prefix or "RNN"
+ all_vars = tf.global_variables()
+ prefix = prefix or "rnn"
scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
tf.logging.info("RNN with scope: %s (%s)"
% (prefix, "scope" if use_outer_scope else "str"))
@@ -1917,8 +1894,8 @@ class StateSaverRNNTest(tf.test.TestCase):
# check that all the variables names starts
# with the proper scope.
- all_vars = tf.all_variables()
- prefix = prefix or "RNN"
+ all_vars = tf.global_variables()
+ prefix = prefix or "rnn"
scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")]
tf.logging.info("RNN with scope: %s (%s)"
% (prefix, "scope" if use_outer_scope else "str"))
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index f67f4f35e8..b1270a1937 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -113,7 +113,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
dtype.
sequence_length: Specifies the length of each sequence in inputs.
An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
@@ -139,7 +139,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "RNN") as varscope:
+ with vs.variable_scope(scope or "rnn") as varscope:
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
@@ -246,7 +246,7 @@ def state_saving_rnn(cell, inputs, state_saver, state_name,
be a single string.
sequence_length: (optional) An int32/int64 vector size [batch_size].
See the documentation for rnn() for more details about sequence_length.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
@@ -508,7 +508,8 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
either of the initial states are not provided.
sequence_length: (optional) An int32/int64 vector, size `[batch_size]`,
containing the actual lengths for each of the sequences.
- scope: VariableScope for the created subgraph; defaults to "BiRNN"
+ scope: VariableScope for the created subgraph; defaults to
+ "bidirectional_rnn"
Returns:
A tuple (outputs, output_state_fw, output_state_bw) where:
@@ -531,14 +532,14 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs,
if not inputs:
raise ValueError("inputs must not be empty")
- with vs.variable_scope(scope or "BiRNN"):
+ with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
- with vs.variable_scope("FW") as fw_scope:
+ with vs.variable_scope("fw") as fw_scope:
output_fw, output_state_fw = rnn(cell_fw, inputs, initial_state_fw, dtype,
sequence_length, scope=fw_scope)
# Backward direction
- with vs.variable_scope("BW") as bw_scope:
+ with vs.variable_scope("bw") as bw_scope:
reversed_inputs = _reverse_seq(inputs, sequence_length)
tmp, output_state_bw = rnn(cell_bw, reversed_inputs, initial_state_bw,
dtype, sequence_length, scope=bw_scope)
@@ -610,7 +611,8 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
accepts input and emits output in batch-major form.
dtype: (optional) The data type for the initial state. Required if
either of the initial states are not provided.
- scope: VariableScope for the created subgraph; defaults to "BiRNN"
+ scope: VariableScope for the created subgraph; defaults to
+ "bidirectional_rnn"
Returns:
A tuple (outputs, output_states) where:
@@ -642,9 +644,9 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
if not isinstance(cell_bw, rnn_cell.RNNCell):
raise TypeError("cell_bw must be an instance of RNNCell")
- with vs.variable_scope(scope or "BiRNN"):
+ with vs.variable_scope(scope or "bidirectional_rnn"):
# Forward direction
- with vs.variable_scope("FW") as fw_scope:
+ with vs.variable_scope("fw") as fw_scope:
output_fw, output_state_fw = dynamic_rnn(
cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
initial_state=initial_state_fw, dtype=dtype,
@@ -659,7 +661,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
time_dim = 0
batch_dim = 1
- with vs.variable_scope("BW") as bw_scope:
+ with vs.variable_scope("bw") as bw_scope:
inputs_reverse = array_ops.reverse_sequence(
input=inputs, seq_lengths=sequence_length,
seq_dim=time_dim, batch_dim=batch_dim)
@@ -746,7 +748,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
transposes at the beginning and end of the RNN calculation. However,
most TensorFlow data is batch-major, so by default this function
accepts input and emits output in batch-major form.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A pair (outputs, state) where:
@@ -801,7 +803,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "RNN") as varscope:
+ with vs.variable_scope(scope or "rnn") as varscope:
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
input_shape = tuple(array_ops.shape(input_) for input_ in flat_input)
@@ -1161,7 +1163,7 @@ def raw_rnn(cell, loop_fn,
but needed for back prop from GPU to CPU. This allows training RNNs
which would typically not fit on a single GPU, with very minimal (or no)
performance penalty.
- scope: VariableScope for the created subgraph; defaults to "RNN".
+ scope: VariableScope for the created subgraph; defaults to "rnn".
Returns:
A tuple `(emit_ta, final_state, final_loop_state)` where:
@@ -1201,7 +1203,7 @@ def raw_rnn(cell, loop_fn,
# Create a new scope in which the caching device is either
# determined by the parent scope, or is set to place the cached
# Variable using the same placement as for the rest of the RNN.
- with vs.variable_scope(scope or "RNN") as varscope:
+ with vs.variable_scope(scope or "rnn") as varscope:
if varscope.caching_device is None:
varscope.set_caching_device(lambda op: op.device)
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index aa521eabb1..d620177e90 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -53,6 +53,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid
@@ -195,9 +196,10 @@ class BasicRNNCell(RNNCell):
return self._num_units
def __call__(self, inputs, state, scope=None):
- """Most basic RNN: output = new_state = activation(W * input + U * state + B)."""
- with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
- output = self._activation(_linear([inputs, state], self._num_units, True))
+ """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
+ with vs.variable_scope(scope or "basic_rnn_cell"):
+ output = self._activation(
+ _linear([inputs, state], self._num_units, True, scope=scope))
return output, output
@@ -220,15 +222,17 @@ class GRUCell(RNNCell):
def __call__(self, inputs, state, scope=None):
"""Gated recurrent unit (GRU) with nunits cells."""
- with vs.variable_scope(scope or type(self).__name__): # "GRUCell"
- with vs.variable_scope("Gates"): # Reset gate and update gate.
+ with vs.variable_scope(scope or "gru_cell"):
+ with vs.variable_scope("gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
- r, u = array_ops.split(1, 2, _linear([inputs, state],
- 2 * self._num_units, True, 1.0))
+ r, u = array_ops.split(
+ 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0,
+ scope=scope))
r, u = sigmoid(r), sigmoid(u)
- with vs.variable_scope("Candidate"):
+ with vs.variable_scope("candidate"):
c = self._activation(_linear([inputs, r * state],
- self._num_units, True))
+ self._num_units, True,
+ scope=scope))
new_h = u * state + (1 - u) * c
return new_h, new_h
@@ -302,13 +306,13 @@ class BasicLSTMCell(RNNCell):
def __call__(self, inputs, state, scope=None):
"""Long short-term memory cell (LSTM)."""
- with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
+ with vs.variable_scope(scope or "basic_lstm_cell"):
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(1, 2, state)
- concat = _linear([inputs, h], 4 * self._num_units, True)
+ concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(1, 4, concat)
@@ -324,42 +328,6 @@ class BasicLSTMCell(RNNCell):
return new_h, new_state
-def _get_concat_variable(name, shape, dtype, num_shards):
- """Get a sharded variable concatenated into one tensor."""
- sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
- if len(sharded_variable) == 1:
- return sharded_variable[0]
-
- concat_name = name + "/concat"
- concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
- for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
- if value.name == concat_full_name:
- return value
-
- concat_variable = array_ops.concat(0, sharded_variable, name=concat_name)
- ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
- concat_variable)
- return concat_variable
-
-
-def _get_sharded_variable(name, shape, dtype, num_shards):
- """Get a list of sharded variables with the given dtype."""
- if num_shards > shape[0]:
- raise ValueError("Too many shards: shape=%s, num_shards=%d" %
- (shape, num_shards))
- unit_shard_size = int(math.floor(shape[0] / num_shards))
- remaining_rows = shape[0] - unit_shard_size * num_shards
-
- shards = []
- for i in range(num_shards):
- current_size = unit_shard_size
- if i < remaining_rows:
- current_size += 1
- shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
- dtype=dtype))
- return shards
-
-
class LSTMCell(RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
@@ -385,7 +353,7 @@ class LSTMCell(RNNCell):
def __init__(self, num_units, input_size=None,
use_peepholes=False, cell_clip=None,
initializer=None, num_proj=None, proj_clip=None,
- num_unit_shards=1, num_proj_shards=1,
+ num_unit_shards=None, num_proj_shards=None,
forget_bias=1.0, state_is_tuple=True,
activation=tanh):
"""Initialize the parameters for an LSTM cell.
@@ -401,12 +369,12 @@ class LSTMCell(RNNCell):
num_proj: (optional) int, The output dimensionality for the projection
matrices. If None, no projection is performed.
proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
- provided, then the projected values are clipped elementwise to within
- `[-proj_clip, proj_clip]`.
- num_unit_shards: How to split the weight matrix. If >1, the weight
- matrix is stored across num_unit_shards.
- num_proj_shards: How to split the projection matrix. If >1, the
- projection matrix is stored across num_proj_shards.
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ num_unit_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
+ num_proj_shards: Deprecated, will be removed by Jan. 2017.
+ Use a variable_scope partitioner instead.
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
the training.
@@ -420,6 +388,12 @@ class LSTMCell(RNNCell):
"deprecated. Use state_is_tuple=True.", self)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
+ if num_unit_shards is not None or num_proj_shards is not None:
+ logging.warn(
+ "%s: The num_unit_shards and proj_unit_shards parameters are "
+ "deprecated and will be removed in Jan 2017. "
+ "Use a variable scope with a partitioner instead.", self)
+
self._num_units = num_units
self._use_peepholes = use_peepholes
self._cell_clip = cell_clip
@@ -460,7 +434,7 @@ class LSTMCell(RNNCell):
`2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
`m_state`.
- scope: VariableScope for the created subgraph; defaults to "LSTMCell".
+ scope: VariableScope for the created subgraph; defaults to "lstm_cell".
Returns:
A tuple containing:
@@ -489,29 +463,28 @@ class LSTMCell(RNNCell):
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(scope or type(self).__name__,
- initializer=self._initializer): # "LSTMCell"
- concat_w = _get_concat_variable(
- "W", [input_size.value + num_proj, 4 * self._num_units],
- dtype, self._num_unit_shards)
-
- b = vs.get_variable(
- "B", shape=[4 * self._num_units],
- initializer=init_ops.zeros_initializer, dtype=dtype)
-
+ with vs.variable_scope(scope or "lstm_cell",
+ initializer=self._initializer) as unit_scope:
+ if self._num_unit_shards is not None:
+ unit_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_unit_shards))
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
- cell_inputs = array_ops.concat(1, [inputs, m_prev])
- lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
+ lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
+ scope=scope)
i, j, f, o = array_ops.split(1, 4, lstm_matrix)
# Diagonal connections
if self._use_peepholes:
- w_f_diag = vs.get_variable(
- "W_F_diag", shape=[self._num_units], dtype=dtype)
- w_i_diag = vs.get_variable(
- "W_I_diag", shape=[self._num_units], dtype=dtype)
- w_o_diag = vs.get_variable(
- "W_O_diag", shape=[self._num_units], dtype=dtype)
+ with vs.variable_scope(unit_scope) as projection_scope:
+ if self._num_unit_shards is not None:
+ projection_scope.set_partitioner(None)
+ w_f_diag = vs.get_variable(
+ "w_f_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = vs.get_variable(
+ "w_i_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = vs.get_variable(
+ "w_o_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
@@ -531,11 +504,13 @@ class LSTMCell(RNNCell):
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
- concat_w_proj = _get_concat_variable(
- "W_P", [self._num_units, self._num_proj],
- dtype, self._num_proj_shards)
+ with vs.variable_scope("projection") as proj_scope:
+ if self._num_proj_shards is not None:
+ proj_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_proj_shards))
+ m = _linear(m, self._num_proj, bias=False, scope=scope)
- m = math_ops.matmul(m, concat_w_proj)
if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
@@ -585,8 +560,8 @@ class OutputProjectionWrapper(RNNCell):
"""Run the cell and output projection on inputs, starting from state."""
output, res_state = self._cell(inputs, state)
# Default scope: "OutputProjectionWrapper"
- with vs.variable_scope(scope or type(self).__name__):
- projected = _linear(output, self._output_size, True)
+ with vs.variable_scope(scope or "output_projection_wrapper"):
+ projected = _linear(output, self._output_size, True, scope=scope)
return projected, res_state
@@ -627,8 +602,8 @@ class InputProjectionWrapper(RNNCell):
def __call__(self, inputs, state, scope=None):
"""Run the input projection and then the cell."""
# Default scope: "InputProjectionWrapper"
- with vs.variable_scope(scope or type(self).__name__):
- projected = _linear(inputs, self._num_proj, True)
+ with vs.variable_scope(scope or "input_projection_wrapper"):
+ projected = _linear(inputs, self._num_proj, True, scope=scope)
return self._cell(projected, state)
@@ -731,7 +706,7 @@ class EmbeddingWrapper(RNNCell):
def __call__(self, inputs, state, scope=None):
"""Run the cell on embedded inputs."""
- with vs.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"
+ with vs.variable_scope(scope or "embedding_wrapper"): # "EmbeddingWrapper"
with ops.device("/cpu:0"):
if self._initializer:
initializer = self._initializer
@@ -796,12 +771,12 @@ class MultiRNNCell(RNNCell):
def __call__(self, inputs, state, scope=None):
"""Run this multi-layer cell on inputs, starting from state."""
- with vs.variable_scope(scope or type(self).__name__): # "MultiRNNCell"
+ with vs.variable_scope(scope or "multi_rnn_cell"):
cur_state_pos = 0
cur_inp = inputs
new_states = []
for i, cell in enumerate(self._cells):
- with vs.variable_scope("Cell%d" % i):
+ with vs.variable_scope("cell_%d" % i):
if self._state_is_tuple:
if not nest.is_sequence(state):
raise ValueError(
@@ -872,7 +847,7 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
output_size: int, second dimension of W[i].
bias: boolean, whether to add a bias term or not.
bias_start: starting value to initialize the bias; 0 by default.
- scope: VariableScope for the created subgraph; defaults to "Linear".
+ scope: (optional) Variable scope to create parameters in.
Returns:
A 2D Tensor with shape [batch x output_size] equal to
@@ -888,30 +863,33 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
# Calculate the total size of arguments on dimension 1.
total_arg_size = 0
- shapes = [a.get_shape().as_list() for a in args]
+ shapes = [a.get_shape() for a in args]
for shape in shapes:
- if len(shape) != 2:
- raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes))
- if not shape[1]:
- raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes))
+ if shape.ndims != 2:
+ raise ValueError("linear is expecting 2D arguments: %s" % shapes)
+ if shape[1].value is None:
+ raise ValueError("linear expects shape[1] to be provided for shape %s, "
+ "but saw %d" % (shape, shape[1]))
else:
- total_arg_size += shape[1]
+ total_arg_size += shape[1].value
dtype = [a.dtype for a in args][0]
# Now the computation.
- with vs.variable_scope(scope or "Linear"):
- matrix = vs.get_variable(
- "Matrix", [total_arg_size, output_size], dtype=dtype)
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope) as outer_scope:
+ weights = vs.get_variable(
+ "weights", [total_arg_size, output_size], dtype=dtype)
if len(args) == 1:
- res = math_ops.matmul(args[0], matrix)
+ res = math_ops.matmul(args[0], weights)
else:
- res = math_ops.matmul(array_ops.concat(1, args), matrix)
+ res = math_ops.matmul(array_ops.concat(1, args), weights)
if not bias:
return res
- bias_term = vs.get_variable(
- "Bias", [output_size],
- dtype=dtype,
- initializer=init_ops.constant_initializer(
- bias_start, dtype=dtype))
- return res + bias_term
+ with vs.variable_scope(outer_scope) as inner_scope:
+ inner_scope.set_partitioner(None)
+ biases = vs.get_variable(
+ "biases", [output_size],
+ dtype=dtype,
+ initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
+ return nn_ops.bias_add(res, biases)