diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2016-11-23 12:27:29 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-23 12:45:10 -0800 |
commit | 92da8abfd35b93488ed7a55308b8f589ee23b622 (patch) | |
tree | 481798091b54e269a6ac7fb7a2c185d34a512a60 | |
parent | 75254c37f8289dd70739fe8c8c5826782b086563 (diff) |
Cleanup and consistency for variable handling in RNNCells.
In the run-up to TF 1.0, we are making RNNCells' variable names compatible
with those of tf layers.
This is a breaking change for those who wish to reload their old RNN model
checkpoints in newly created graphs. After this change is in, variables
created with RNNCells will have slightly different names than before;
loading old checkpoints to run with newly created graphs requires
renaming at load time.
Loading and executing old graphs with old checkpoints will continue to work
without any problems. Creating and loading new checkpoints with graphs
after this change is in will work without any problems. The only people
affected by this change are those who want to load old RNN model checkpoints
graphs created after this change is in.
Renaming on checkpoint load can be performed with
tf.contrib.framework.variables.assign_from_checkpoint. Example usage
is available here[1] if you use Saver and/or Supervisor, and [2] if you
are using the newer tf.learn classes.
Examples of renamed parameters:
LSTMCell without sharding:
my_scope/LSTMCell/W_0 -> my_scope/lstm_cell/weights
my_scope/LSTMCell/W_F_diag -> my_scope/lstm_cell/w_f_diag
my_scope/LSTMCell/B -> my_scope/lstm_cell/biases
LSTMCell with sharding:
my_scope/LSTMCell/W_0 -> my_scope/lstm_cell/weights/part_0
my_scope/LSTMCell/W_1 -> my_scope/lstm_cell/weights/part_1
my_scope/LSTMCell/W_2 -> my_scope/lstm_cell/weights/part_2
my_scope/LSTMCell/W_F_diag -> my_scope/lstm_cell/w_f_diag
my_scope/LSTMCell/B -> my_scope/lstm_cell/biases
BasicLSTMCell:
my_scope/BasicLSTMCell/Linear/Matrix -> my_scope/basic_lstm_cell/weights
my_scope/BasicLSTMCell/Linear/Bias -> my_scope/basic_lstm_cell/biases
MultiRNNCell:
my_scope/MultiRNNCell/Cell0/LSTMCell/W_0 -> my_scope/multi_rnn_cell/cell_0/lstm_cell/weights
my_scope/MultiRNNCell/Cell0/LSTMCell/W_F_diag -> my_scope/multi_rnn_cell/cell_0/lstm_cell/w_f_diag
my_scope/MultiRNNCell/Cell0/LSTMCell/B -> my_scope/multi_rnn_cell/cell_0/lstm_cell/biases
1.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/README.md
2. https://github.com/tensorflow/tensorflow/blob/86f5ab7474825da756838b34e1b4eac93f5fc68a/tensorflow/contrib/framework/python/ops/variables_test.py#L810
Change: 140060366
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 54 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 99 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_cell_test.py | 37 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/rnn_test.py | 103 | ||||
-rw-r--r-- | tensorflow/python/ops/rnn.py | 32 | ||||
-rw-r--r-- | tensorflow/python/ops/rnn_cell.py | 182 |
9 files changed, 255 insertions, 273 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 164ad822b8..77da735284 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -81,9 +81,9 @@ class LSTMBlockCellTest(tf.test.TestCase): basic_names = {v.name: v.get_shape() for v in tf.trainable_variables()} with self.test_session(use_gpu=self._use_gpu, graph=tf.Graph()): - cell = tf.contrib.rnn.LSTMBlockCell(10, use_compatible_names=True) + cell = tf.contrib.rnn.LSTMBlockCell(10) pcell = tf.contrib.rnn.LSTMBlockCell( - 10, use_peephole=True, use_compatible_names=True) + 10, use_peephole=True) inputs = [tf.zeros([4, 5])] * 6 tf.nn.rnn(cell, inputs, dtype=tf.float32, scope="basic") tf.nn.rnn(pcell, inputs, dtype=tf.float32, scope="peephole") @@ -93,8 +93,8 @@ class LSTMBlockCellTest(tf.test.TestCase): cell = tf.contrib.rnn.LSTMBlockFusedCell(10) pcell = tf.contrib.rnn.LSTMBlockFusedCell(10, use_peephole=True) inputs = [tf.zeros([4, 5])] * 6 - cell(inputs, dtype=tf.float32, scope="basic/LSTMCell") - pcell(inputs, dtype=tf.float32, scope="peephole/LSTMCell") + cell(inputs, dtype=tf.float32, scope="basic/lstm_cell") + pcell(inputs, dtype=tf.float32, scope="peephole/lstm_cell") fused_names = {v.name: v.get_shape() for v in tf.trainable_variables()} self.assertEqual(basic_names, block_names) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py index 89890c18b5..91b3d7f417 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_test.py @@ -383,7 +383,7 @@ class StackBidirectionalRNNTest(tf.test.TestCase): # check that all the variables names starts with the proper scope. tf.global_variables_initializer() all_vars = tf.all_variables() - prefix = prefix or "StackRNN" + prefix = prefix or "stack_bidirectional_rnn" scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] tf.logging.info("StackRNN with scope: %s (%s)" % (prefix, "scope" if use_outer_scope else "str")) diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 3e8998f117..3e30f24310 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -334,44 +334,31 @@ class LSTMBlockCell(rnn_cell.RNNCell): Unlike `rnn_cell.LSTMCell`, this is a monolithic op and should be much faster. The weight and bias matrixes should be compatible as long as the variable - scope matches, and you use `use_compatible_names=True`. + scope matches. """ def __init__(self, num_units, forget_bias=1.0, - use_peephole=False, - use_compatible_names=False): + use_peephole=False): """Initialize the basic LSTM cell. Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). use_peephole: Whether to use peephole connections or not. - use_compatible_names: If True, use the same variable naming as - rnn_cell.LSTMCell """ self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole - if use_compatible_names: - self._names = { - "W": "W_0", - "b": "B", - "wci": "W_I_diag", - "wco": "W_O_diag", - "wcf": "W_F_diag", - "scope": "LSTMCell" - } - else: - self._names = { - "W": "W", - "b": "b", - "wci": "wci", - "wco": "wco", - "wcf": "wcf", - "scope": "LSTMBlockCell" - } + self._names = { + "W": "weights", + "b": "biases", + "wci": "w_i_diag", + "wco": "w_o_diag", + "wcf": "w_f_diag", + "scope": "lstm_cell" + } @property def state_size(self): @@ -385,15 +372,15 @@ class LSTMBlockCell(rnn_cell.RNNCell): """Long short-term memory cell (LSTM).""" with vs.variable_scope(scope or self._names["scope"]): x_shape = x.get_shape().with_rank(2) - if not x_shape[1]: - raise ValueError("Expecting x_shape[1] to be sets: %s" % str(x_shape)) + if not x_shape[1].value: + raise ValueError("Expecting x_shape[1] to be set: %s" % str(x_shape)) if len(states_prev) != 2: raise ValueError("Expecting states_prev to be a tuple with length 2.") - input_size = x_shape[1] + input_size = x_shape[1].value w = vs.get_variable(self._names["W"], [input_size + self._num_units, self._num_units * 4]) b = vs.get_variable( - self._names["b"], [w.get_shape().with_rank(2)[1]], + self._names["b"], [w.get_shape().with_rank(2)[1].value], initializer=init_ops.constant_initializer(0.0)) if self._use_peephole: wci = vs.get_variable(self._names["wci"], [self._num_units]) @@ -490,7 +477,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell): Raises: ValueError: in case of shape mismatches """ - with vs.variable_scope(scope or type(self).__name__): + with vs.variable_scope(scope or "lstm_block_wrapper"): is_list = isinstance(inputs, list) if is_list: inputs = array_ops.pack(inputs) @@ -634,15 +621,16 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): time_len = array_ops.shape(inputs)[0] input_size = inputs_shape[2].value w = vs.get_variable( - "W_0", [input_size + self._num_units, self._num_units * 4], dtype=dtype) + "weights", + [input_size + self._num_units, self._num_units * 4], dtype=dtype) b = vs.get_variable( - "B", [w.get_shape().with_rank(2)[1]], + "biases", [w.get_shape().with_rank(2)[1]], initializer=init_ops.constant_initializer(0.0), dtype=dtype) if self._use_peephole: - wci = vs.get_variable("W_I_diag", [self._num_units], dtype=dtype) - wco = vs.get_variable("W_O_diag", [self._num_units], dtype=dtype) - wcf = vs.get_variable("W_F_diag", [self._num_units], dtype=dtype) + wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype) + wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) + wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype) else: wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype) diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py index d4df308042..aa9dd98fee 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn.py +++ b/tensorflow/contrib/rnn/python/ops/rnn.py @@ -95,7 +95,7 @@ def stack_bidirectional_rnn(cells_fw, states_bw = [] prev_layer = inputs - with vs.variable_scope(scope or "StackRNN"): + with vs.variable_scope(scope or "stack_bidirectional_rnn"): for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)): initial_state_fw = None initial_state_bw = None @@ -104,7 +104,7 @@ def stack_bidirectional_rnn(cells_fw, if initial_states_bw: initial_state_bw = initial_states_bw[i] - with vs.variable_scope("Layer%d" % i): + with vs.variable_scope("cell_%d" % i) as cell_scope: prev_layer, state_fw, state_bw = tf.nn.bidirectional_rnn( cell_fw, cell_bw, @@ -112,7 +112,8 @@ def stack_bidirectional_rnn(cells_fw, initial_state_fw=initial_state_fw, initial_state_bw=initial_state_bw, sequence_length=sequence_length, - dtype=dtype) + dtype=dtype, + scope=cell_scope) states_fw.append(state_fw) states_bw.append(state_bw) @@ -192,7 +193,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, states_bw = [] prev_layer = inputs - with vs.variable_scope(scope or "StackRNN"): + with vs.variable_scope(scope or "stack_bidirectional_rnn"): for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)): initial_state_fw = None initial_state_bw = None @@ -201,7 +202,7 @@ def stack_bidirectional_dynamic_rnn(cells_fw, if initial_states_bw: initial_state_bw = initial_states_bw[i] - with vs.variable_scope("Layer%d" % i): + with vs.variable_scope("cell_%d" % i): outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn( cell_fw, cell_bw, diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index c1c25ba094..9890e712c1 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -200,8 +200,8 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell): input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with vs.variable_scope(scope or type(self).__name__, - initializer=self._initializer): # "LSTMCell" + with vs.variable_scope(scope or "coupled_input_forget_gate_lstm_cell", + initializer=self._initializer): concat_w = _get_concat_variable( "W", [input_size.value + num_proj, 3 * self._num_units], dtype, self._num_unit_shards) @@ -328,7 +328,7 @@ class TimeFreqLSTMCell(rnn_cell.RNNCell): freq_inputs = self._make_tf_features(inputs) dtype = inputs.dtype actual_input_size = freq_inputs[0].get_shape().as_list()[1] - with vs.variable_scope(scope or type(self).__name__, + with vs.variable_scope(scope or "time_freq_lstm_cell", initializer=self._initializer): # "TimeFreqLSTMCell" concat_w = _get_concat_variable( "W", [actual_input_size + 2*self._num_units, 4 * self._num_units], @@ -546,7 +546,7 @@ class GridLSTMCell(rnn_cell.RNNCell): """ batch_size = int(inputs.get_shape()[0]) freq_inputs = self._make_tf_features(inputs) - with vs.variable_scope(scope or type(self).__name__, + with vs.variable_scope(scope or "grid_lstm_cell", initializer=self._initializer): # "GridLSTMCell" m_out_lst = [] state_out_lst = [] @@ -968,29 +968,29 @@ class BidirectionalGridLSTMCell(GridLSTMCell): bwd_inputs = fwd_inputs # Forward processing - with vs.variable_scope((scope or type(self).__name__) + "/fwd", - initializer=self._initializer): - fwd_m_out_lst = [] - fwd_state_out_lst = [] - for block in range(len(fwd_inputs)): - fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( - fwd_inputs[block], block, state, batch_size, - state_prefix="fwd_state", state_is_tuple=True) - fwd_m_out_lst.extend(fwd_m_out_lst_current) - fwd_state_out_lst.extend(fwd_state_out_lst_current) - # Backward processing - bwd_m_out_lst = [] - bwd_state_out_lst = [] - with vs.variable_scope((scope or type(self).__name__) + "/bwd", + with vs.variable_scope(scope or "bidirectional_grid_lstm_cell", initializer=self._initializer): - for block in range(len(bwd_inputs)): - # Reverse the blocks - bwd_inputs_reverse = bwd_inputs[block][::-1] - bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( - bwd_inputs_reverse, block, state, batch_size, - state_prefix="bwd_state", state_is_tuple=True) - bwd_m_out_lst.extend(bwd_m_out_lst_current) - bwd_state_out_lst.extend(bwd_state_out_lst_current) + with vs.variable_scope("fwd"): + fwd_m_out_lst = [] + fwd_state_out_lst = [] + for block in range(len(fwd_inputs)): + fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( + fwd_inputs[block], block, state, batch_size, + state_prefix="fwd_state", state_is_tuple=True) + fwd_m_out_lst.extend(fwd_m_out_lst_current) + fwd_state_out_lst.extend(fwd_state_out_lst_current) + # Backward processing + bwd_m_out_lst = [] + bwd_state_out_lst = [] + with vs.variable_scope("bwd"): + for block in range(len(bwd_inputs)): + # Reverse the blocks + bwd_inputs_reverse = bwd_inputs[block][::-1] + bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( + bwd_inputs_reverse, block, state, batch_size, + state_prefix="bwd_state", state_is_tuple=True) + bwd_m_out_lst.extend(bwd_m_out_lst_current) + bwd_state_out_lst.extend(bwd_state_out_lst_current) state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) # Outputs are always concated as it is never used separately. m_out = array_ops.concat(1, fwd_m_out_lst + bwd_m_out_lst) @@ -1071,7 +1071,7 @@ class AttentionCellWrapper(rnn_cell.RNNCell): def __call__(self, inputs, state, scope=None): """Long short-term memory cell with attention (LSTMA).""" - with vs.variable_scope(scope or type(self).__name__): + with vs.variable_scope(scope or "attention_cell_wrapper"): if self._state_is_tuple: state, attns, attn_states = state else: @@ -1094,7 +1094,7 @@ class AttentionCellWrapper(rnn_cell.RNNCell): else: new_state_cat = new_state new_attns, new_attn_states = self._attention(new_state_cat, attn_states) - with vs.variable_scope("AttnOutputProjection"): + with vs.variable_scope("attn_output_projection"): output = _linear([lstm_output, new_attns], self._attn_size, True) new_attn_states = array_ops.concat(1, [new_attn_states, array_ops.expand_dims(output, 1)]) @@ -1111,9 +1111,10 @@ class AttentionCellWrapper(rnn_cell.RNNCell): softmax = nn_ops.softmax tanh = math_ops.tanh - with vs.variable_scope("Attention"): - k = vs.get_variable("AttnW", [1, 1, self._attn_size, self._attn_vec_size]) - v = vs.get_variable("AttnV", [self._attn_vec_size]) + with vs.variable_scope("attention"): + k = vs.get_variable( + "attn_w", [1, 1, self._attn_size, self._attn_vec_size]) + v = vs.get_variable("attn_v", [self._attn_vec_size]) hidden = array_ops.reshape(attn_states, [-1, self._attn_length, 1, self._attn_size]) hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") @@ -1191,30 +1192,30 @@ class LayerNormBasicLSTMCell(rnn_cell.RNNCell): return self._num_units def _norm(self, inp, scope): - with vs.variable_scope(scope) as scope: - shape = inp.get_shape()[-1:] - gamma_init = init_ops.constant_initializer(self._g) - beta_init = init_ops.constant_initializer(self._b) - gamma = vs.get_variable("gamma", shape=shape, initializer=gamma_init) # pylint: disable=unused-variable - beta = vs.get_variable("beta", shape=shape, initializer=beta_init) # pylint: disable=unused-variable - normalized = layers.layer_norm(inp, reuse=True, scope=scope) - return normalized - - def _linear(self, args, scope="linear"): + shape = inp.get_shape()[-1:] + gamma_init = init_ops.constant_initializer(self._g) + beta_init = init_ops.constant_initializer(self._b) + with vs.variable_scope(scope): + # Initialize beta and gamma for use by layer_norm. + vs.get_variable("gamma", shape=shape, initializer=gamma_init) + vs.get_variable("beta", shape=shape, initializer=beta_init) + normalized = layers.layer_norm(inp, reuse=True, scope=scope) + return normalized + + def _linear(self, args): out_size = 4 * self._num_units proj_size = args.get_shape()[-1] - with vs.variable_scope(scope) as scope: - weights = vs.get_variable("weights", [proj_size, out_size]) - out = math_ops.matmul(args, weights) - if not self._layer_norm: - bias = vs.get_variable("b", [out_size]) - out += bias - return out + weights = vs.get_variable("weights", [proj_size, out_size]) + out = math_ops.matmul(args, weights) + if not self._layer_norm: + bias = vs.get_variable("biases", [out_size]) + out = nn_ops.bias_add(out, bias) + return out def __call__(self, inputs, state, scope=None): """LSTM cell with layer normalization and recurrent dropout.""" - with vs.variable_scope(scope or type(self).__name__) as scope: # LayerNormBasicLSTMCell # pylint: disable=unused-variables + with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"): c, h = state args = array_ops.concat(1, [inputs, h]) concat = self._linear(args) diff --git a/tensorflow/python/kernel_tests/rnn_cell_test.py b/tensorflow/python/kernel_tests/rnn_cell_test.py index 69eeb116a6..e4e239169a 100644 --- a/tensorflow/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/python/kernel_tests/rnn_cell_test.py @@ -95,6 +95,20 @@ class RNNCellTest(tf.test.TestCase): res = sess.run([g, out_m], {x.name: np.array([[1., 1.]]), m.name: 0.1 * np.ones([1, 8])}) self.assertEqual(len(res), 2) + variables = tf.global_variables() + self.assertEqual(4, len(variables)) + self.assertEquals( + variables[0].op.name, + "root/multi_rnn_cell/cell_0/basic_lstm_cell/weights") + self.assertEquals( + variables[1].op.name, + "root/multi_rnn_cell/cell_0/basic_lstm_cell/biases") + self.assertEquals( + variables[2].op.name, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/weights") + self.assertEquals( + variables[3].op.name, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/biases") # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) expected_mem = np.array([[0.68967271, 0.68967271, @@ -204,6 +218,26 @@ class RNNCellTest(tf.test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) + def testLSTMCellVariables(self): + with self.test_session(): + num_units = 8 + num_proj = 6 + state_size = num_units + num_proj + batch_size = 3 + input_size = 2 + with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): + x = tf.zeros([batch_size, input_size]) + m = tf.zeros([batch_size, state_size]) + cell = tf.nn.rnn_cell.LSTMCell( + num_units=num_units, num_proj=num_proj, forget_bias=1.0, + state_is_tuple=False) + cell(x, m) # Execute to create variables + variables = tf.global_variables() + self.assertEquals(variables[0].op.name, "root/lstm_cell/weights") + self.assertEquals(variables[1].op.name, "root/lstm_cell/biases") + self.assertEquals( + variables[2].op.name, "root/lstm_cell/projection/weights") + def testOutputProjectionWrapper(self): with self.test_session() as sess: with tf.variable_scope("root", initializer=tf.constant_initializer(0.5)): @@ -354,6 +388,7 @@ class SlimRNNCellTest(tf.test.TestCase): # pylint: enable=protected-access slim_outputs, slim_state = slim_cell(inputs, initial_state) rnn_cell = tf.nn.rnn_cell.BasicRNNCell(num_units) + tf.get_variable_scope().reuse_variables() outputs, state = rnn_cell(inputs, initial_state) self.assertEqual(slim_outputs.get_shape(), outputs.get_shape()) self.assertEqual(slim_state.get_shape(), state.get_shape()) @@ -377,7 +412,7 @@ def basic_rnn_cell(inputs, state, num_units, scope=None): init_state.set_shape([batch_size, num_units]) return init_output, init_state else: - with tf.variable_scope(scope, "BasicRNNCell", [inputs, state]): + with tf.variable_scope(scope, "basic_rnn_cell", [inputs, state]): output = tf.tanh(linear([inputs, state], num_units, True)) return output, output diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 5a74adad76..d3897afb92 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -260,8 +260,8 @@ class RNNTest(tf.test.TestCase): # check that all the variables names starts # with the proper scope. tf.global_variables_initializer() - all_vars = tf.all_variables() - prefix = prefix or "RNN" + all_vars = tf.global_variables() + prefix = prefix or "rnn" scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] tf.logging.info("RNN with scope: %s (%s)" % (prefix, "scope" if use_outer_scope else "str")) @@ -333,8 +333,8 @@ class GRUTest(tf.test.TestCase): # check that all the variables names starts # with the proper scope. - all_vars = tf.all_variables() - prefix = prefix or "RNN" + all_vars = tf.global_variables() + prefix = prefix or "rnn" scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] tf.logging.info("RNN with scope: %s (%s)" % (prefix, "scope" if use_outer_scope else "str")) @@ -567,13 +567,14 @@ class LSTMTest(tf.test.TestCase): cell_tuple = tf.nn.rnn_cell.LSTMCell( num_units, use_peepholes=True, num_proj=num_proj, initializer=initializer, state_is_tuple=True) - outputs_notuple, state_notuple = tf.nn.rnn( - cell_notuple, inputs, dtype=tf.float32, - sequence_length=sequence_length) - tf.get_variable_scope().reuse_variables() - outputs_tuple, state_tuple = tf.nn.rnn( - cell_tuple, inputs, dtype=tf.float32, - sequence_length=sequence_length) + with tf.variable_scope("root") as scope: + outputs_notuple, state_notuple = tf.nn.rnn( + cell_notuple, inputs, dtype=tf.float32, + sequence_length=sequence_length, scope=scope) + scope.reuse_variables() + outputs_tuple, state_tuple = tf.nn.rnn( + cell_tuple, inputs, dtype=tf.float32, + sequence_length=sequence_length, scope=scope) self.assertEqual(len(outputs_notuple), len(inputs)) self.assertEqual(len(outputs_tuple), len(inputs)) self.assertTrue(isinstance(state_tuple, tuple)) @@ -624,31 +625,6 @@ class LSTMTest(tf.test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) - def _testTooManyShards(self, use_gpu): - num_units = 3 - input_size = 5 - num_proj = 4 - num_proj_shards = 4 - num_unit_shards = 2 - max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=tf.Graph()): - initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed) - - inputs = max_length * [ - tf.placeholder(tf.float32, shape=(None, input_size))] - - cell = tf.nn.rnn_cell.LSTMCell( - num_units, - use_peepholes=True, - num_proj=num_proj, - num_unit_shards=num_unit_shards, - num_proj_shards=num_proj_shards, - initializer=initializer, - state_is_tuple=False) - - with self.assertRaises(ValueError): - tf.nn.rnn(cell, inputs, dtype=tf.float32) - def _testDoubleInput(self, use_gpu): num_units = 3 input_size = 5 @@ -871,13 +847,14 @@ class LSTMTest(tf.test.TestCase): cell = tf.nn.rnn_cell.LSTMCell( num_units, use_peepholes=True, num_proj=num_proj, initializer=initializer, state_is_tuple=True) - outputs_static, state_static = tf.nn.rnn( - cell, inputs, dtype=tf.float32, - sequence_length=sequence_length) - tf.get_variable_scope().reuse_variables() - outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn( - cell, inputs_c, dtype=tf.float32, time_major=True, - sequence_length=sequence_length) + with tf.variable_scope("root") as scope: + outputs_static, state_static = tf.nn.rnn( + cell, inputs, dtype=tf.float32, + sequence_length=sequence_length, scope=scope) + scope.reuse_variables() + outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn( + cell, inputs_c, dtype=tf.float32, time_major=True, + sequence_length=sequence_length, scope=scope) self.assertTrue(isinstance(state_static, tf.nn.rnn_cell.LSTMStateTuple)) self.assertTrue(isinstance(state_dynamic, tf.nn.rnn_cell.LSTMStateTuple)) self.assertEqual(state_static[0], state_static.c) @@ -932,13 +909,14 @@ class LSTMTest(tf.test.TestCase): self.assertEqual(test_zero[i][0].get_shape()[1], cell.state_size[i][0]) self.assertEqual(test_zero[i][1].get_shape()[1], cell.state_size[i][1]) - outputs_static, state_static = tf.nn.rnn( - cell, inputs, dtype=tf.float32, - sequence_length=sequence_length) - tf.get_variable_scope().reuse_variables() - outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn( - cell, inputs_c, dtype=tf.float32, time_major=True, - sequence_length=sequence_length) + with tf.variable_scope("root") as scope: + outputs_static, state_static = tf.nn.rnn( + cell, inputs, dtype=tf.float32, + sequence_length=sequence_length, scope=scope) + scope.reuse_variables() + outputs_dynamic, state_dynamic = tf.nn.dynamic_rnn( + cell, inputs_c, dtype=tf.float32, time_major=True, + sequence_length=sequence_length, scope=scope) tf.global_variables_initializer().run() @@ -1126,10 +1104,6 @@ class LSTMTest(tf.test.TestCase): self._testProjSharding(use_gpu=False) self._testProjSharding(use_gpu=True) - def testTooManyShards(self): - self._testTooManyShards(use_gpu=False) - self._testTooManyShards(use_gpu=True) - def testShardNoShardEquivalentOutput(self): self._testShardNoShardEquivalentOutput(use_gpu=False) self._testShardNoShardEquivalentOutput(use_gpu=True) @@ -1415,8 +1389,8 @@ class BidirectionalRNNTest(tf.test.TestCase): # check that all the variables names starts # with the proper scope. tf.global_variables_initializer() - all_vars = tf.all_variables() - prefix = prefix or "BiRNN" + all_vars = tf.global_variables() + prefix = prefix or "bidirectional_rnn" scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] tf.logging.info("BiRNN with scope: %s (%s)" % (prefix, "scope" if use_outer_scope else "str")) @@ -1667,13 +1641,16 @@ class RawRNNTest(tf.test.TestCase): lambda: inputs_ta.read(time_)) return (elements_finished, next_input, next_state, emit_output, None) - outputs_ta, final_state, _ = tf.nn.raw_rnn(cell, loop_fn) + reuse_scope = tf.get_variable_scope() + + outputs_ta, final_state, _ = tf.nn.raw_rnn( + cell, loop_fn, scope=reuse_scope) outputs = outputs_ta.pack() - tf.get_variable_scope().reuse_variables() + reuse_scope.reuse_variables() outputs_dynamic_rnn, final_state_dynamic_rnn = tf.nn.dynamic_rnn( cell, inputs, time_major=True, dtype=tf.float32, - sequence_length=sequence_length) + sequence_length=sequence_length, scope=reuse_scope) variables = tf.trainable_variables() gradients = tf.gradients([outputs, final_state], [inputs] + variables) @@ -1854,8 +1831,8 @@ class RawRNNTest(tf.test.TestCase): # check that all the variables names starts # with the proper scope. - all_vars = tf.all_variables() - prefix = prefix or "RNN" + all_vars = tf.global_variables() + prefix = prefix or "rnn" scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] tf.logging.info("RNN with scope: %s (%s)" % (prefix, "scope" if use_outer_scope else "str")) @@ -1917,8 +1894,8 @@ class StateSaverRNNTest(tf.test.TestCase): # check that all the variables names starts # with the proper scope. - all_vars = tf.all_variables() - prefix = prefix or "RNN" + all_vars = tf.global_variables() + prefix = prefix or "rnn" scope_vars = [v for v in all_vars if v.name.startswith(prefix + "/")] tf.logging.info("RNN with scope: %s (%s)" % (prefix, "scope" if use_outer_scope else "str")) diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index f67f4f35e8..b1270a1937 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -113,7 +113,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None, dtype. sequence_length: Specifies the length of each sequence in inputs. An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. - scope: VariableScope for the created subgraph; defaults to "RNN". + scope: VariableScope for the created subgraph; defaults to "rnn". Returns: A pair (outputs, state) where: @@ -139,7 +139,7 @@ def rnn(cell, inputs, initial_state=None, dtype=None, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - with vs.variable_scope(scope or "RNN") as varscope: + with vs.variable_scope(scope or "rnn") as varscope: if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) @@ -246,7 +246,7 @@ def state_saving_rnn(cell, inputs, state_saver, state_name, be a single string. sequence_length: (optional) An int32/int64 vector size [batch_size]. See the documentation for rnn() for more details about sequence_length. - scope: VariableScope for the created subgraph; defaults to "RNN". + scope: VariableScope for the created subgraph; defaults to "rnn". Returns: A pair (outputs, state) where: @@ -508,7 +508,8 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs, either of the initial states are not provided. sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, containing the actual lengths for each of the sequences. - scope: VariableScope for the created subgraph; defaults to "BiRNN" + scope: VariableScope for the created subgraph; defaults to + "bidirectional_rnn" Returns: A tuple (outputs, output_state_fw, output_state_bw) where: @@ -531,14 +532,14 @@ def bidirectional_rnn(cell_fw, cell_bw, inputs, if not inputs: raise ValueError("inputs must not be empty") - with vs.variable_scope(scope or "BiRNN"): + with vs.variable_scope(scope or "bidirectional_rnn"): # Forward direction - with vs.variable_scope("FW") as fw_scope: + with vs.variable_scope("fw") as fw_scope: output_fw, output_state_fw = rnn(cell_fw, inputs, initial_state_fw, dtype, sequence_length, scope=fw_scope) # Backward direction - with vs.variable_scope("BW") as bw_scope: + with vs.variable_scope("bw") as bw_scope: reversed_inputs = _reverse_seq(inputs, sequence_length) tmp, output_state_bw = rnn(cell_bw, reversed_inputs, initial_state_bw, dtype, sequence_length, scope=bw_scope) @@ -610,7 +611,8 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, accepts input and emits output in batch-major form. dtype: (optional) The data type for the initial state. Required if either of the initial states are not provided. - scope: VariableScope for the created subgraph; defaults to "BiRNN" + scope: VariableScope for the created subgraph; defaults to + "bidirectional_rnn" Returns: A tuple (outputs, output_states) where: @@ -642,9 +644,9 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, if not isinstance(cell_bw, rnn_cell.RNNCell): raise TypeError("cell_bw must be an instance of RNNCell") - with vs.variable_scope(scope or "BiRNN"): + with vs.variable_scope(scope or "bidirectional_rnn"): # Forward direction - with vs.variable_scope("FW") as fw_scope: + with vs.variable_scope("fw") as fw_scope: output_fw, output_state_fw = dynamic_rnn( cell=cell_fw, inputs=inputs, sequence_length=sequence_length, initial_state=initial_state_fw, dtype=dtype, @@ -659,7 +661,7 @@ def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, time_dim = 0 batch_dim = 1 - with vs.variable_scope("BW") as bw_scope: + with vs.variable_scope("bw") as bw_scope: inputs_reverse = array_ops.reverse_sequence( input=inputs, seq_lengths=sequence_length, seq_dim=time_dim, batch_dim=batch_dim) @@ -746,7 +748,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, transposes at the beginning and end of the RNN calculation. However, most TensorFlow data is batch-major, so by default this function accepts input and emits output in batch-major form. - scope: VariableScope for the created subgraph; defaults to "RNN". + scope: VariableScope for the created subgraph; defaults to "rnn". Returns: A pair (outputs, state) where: @@ -801,7 +803,7 @@ def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - with vs.variable_scope(scope or "RNN") as varscope: + with vs.variable_scope(scope or "rnn") as varscope: if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) input_shape = tuple(array_ops.shape(input_) for input_ in flat_input) @@ -1161,7 +1163,7 @@ def raw_rnn(cell, loop_fn, but needed for back prop from GPU to CPU. This allows training RNNs which would typically not fit on a single GPU, with very minimal (or no) performance penalty. - scope: VariableScope for the created subgraph; defaults to "RNN". + scope: VariableScope for the created subgraph; defaults to "rnn". Returns: A tuple `(emit_ta, final_state, final_loop_state)` where: @@ -1201,7 +1203,7 @@ def raw_rnn(cell, loop_fn, # Create a new scope in which the caching device is either # determined by the parent scope, or is set to place the cached # Variable using the same placement as for the rest of the RNN. - with vs.variable_scope(scope or "RNN") as varscope: + with vs.variable_scope(scope or "rnn") as varscope: if varscope.caching_device is None: varscope.set_caching_device(lambda op: op.device) diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py index aa521eabb1..d620177e90 100644 --- a/tensorflow/python/ops/rnn_cell.py +++ b/tensorflow/python/ops/rnn_cell.py @@ -53,6 +53,7 @@ from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops.math_ops import sigmoid @@ -195,9 +196,10 @@ class BasicRNNCell(RNNCell): return self._num_units def __call__(self, inputs, state, scope=None): - """Most basic RNN: output = new_state = activation(W * input + U * state + B).""" - with vs.variable_scope(scope or type(self).__name__): # "BasicRNNCell" - output = self._activation(_linear([inputs, state], self._num_units, True)) + """Most basic RNN: output = new_state = act(W * input + U * state + B).""" + with vs.variable_scope(scope or "basic_rnn_cell"): + output = self._activation( + _linear([inputs, state], self._num_units, True, scope=scope)) return output, output @@ -220,15 +222,17 @@ class GRUCell(RNNCell): def __call__(self, inputs, state, scope=None): """Gated recurrent unit (GRU) with nunits cells.""" - with vs.variable_scope(scope or type(self).__name__): # "GRUCell" - with vs.variable_scope("Gates"): # Reset gate and update gate. + with vs.variable_scope(scope or "gru_cell"): + with vs.variable_scope("gates"): # Reset gate and update gate. # We start with bias of 1.0 to not reset and not update. - r, u = array_ops.split(1, 2, _linear([inputs, state], - 2 * self._num_units, True, 1.0)) + r, u = array_ops.split( + 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0, + scope=scope)) r, u = sigmoid(r), sigmoid(u) - with vs.variable_scope("Candidate"): + with vs.variable_scope("candidate"): c = self._activation(_linear([inputs, r * state], - self._num_units, True)) + self._num_units, True, + scope=scope)) new_h = u * state + (1 - u) * c return new_h, new_h @@ -302,13 +306,13 @@ class BasicLSTMCell(RNNCell): def __call__(self, inputs, state, scope=None): """Long short-term memory cell (LSTM).""" - with vs.variable_scope(scope or type(self).__name__): # "BasicLSTMCell" + with vs.variable_scope(scope or "basic_lstm_cell"): # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: c, h = state else: c, h = array_ops.split(1, 2, state) - concat = _linear([inputs, h], 4 * self._num_units, True) + concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(1, 4, concat) @@ -324,42 +328,6 @@ class BasicLSTMCell(RNNCell): return new_h, new_state -def _get_concat_variable(name, shape, dtype, num_shards): - """Get a sharded variable concatenated into one tensor.""" - sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) - if len(sharded_variable) == 1: - return sharded_variable[0] - - concat_name = name + "/concat" - concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" - for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): - if value.name == concat_full_name: - return value - - concat_variable = array_ops.concat(0, sharded_variable, name=concat_name) - ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, - concat_variable) - return concat_variable - - -def _get_sharded_variable(name, shape, dtype, num_shards): - """Get a list of sharded variables with the given dtype.""" - if num_shards > shape[0]: - raise ValueError("Too many shards: shape=%s, num_shards=%d" % - (shape, num_shards)) - unit_shard_size = int(math.floor(shape[0] / num_shards)) - remaining_rows = shape[0] - unit_shard_size * num_shards - - shards = [] - for i in range(num_shards): - current_size = unit_shard_size - if i < remaining_rows: - current_size += 1 - shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:], - dtype=dtype)) - return shards - - class LSTMCell(RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. @@ -385,7 +353,7 @@ class LSTMCell(RNNCell): def __init__(self, num_units, input_size=None, use_peepholes=False, cell_clip=None, initializer=None, num_proj=None, proj_clip=None, - num_unit_shards=1, num_proj_shards=1, + num_unit_shards=None, num_proj_shards=None, forget_bias=1.0, state_is_tuple=True, activation=tanh): """Initialize the parameters for an LSTM cell. @@ -401,12 +369,12 @@ class LSTMCell(RNNCell): num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is - provided, then the projected values are clipped elementwise to within - `[-proj_clip, proj_clip]`. - num_unit_shards: How to split the weight matrix. If >1, the weight - matrix is stored across num_unit_shards. - num_proj_shards: How to split the projection matrix. If >1, the - projection matrix is stored across num_proj_shards. + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. + Use a variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. + Use a variable_scope partitioner instead. forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. @@ -420,6 +388,12 @@ class LSTMCell(RNNCell): "deprecated. Use state_is_tuple=True.", self) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) + if num_unit_shards is not None or num_proj_shards is not None: + logging.warn( + "%s: The num_unit_shards and proj_unit_shards parameters are " + "deprecated and will be removed in Jan 2017. " + "Use a variable scope with a partitioner instead.", self) + self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip @@ -460,7 +434,7 @@ class LSTMCell(RNNCell): `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. - scope: VariableScope for the created subgraph; defaults to "LSTMCell". + scope: VariableScope for the created subgraph; defaults to "lstm_cell". Returns: A tuple containing: @@ -489,29 +463,28 @@ class LSTMCell(RNNCell): input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with vs.variable_scope(scope or type(self).__name__, - initializer=self._initializer): # "LSTMCell" - concat_w = _get_concat_variable( - "W", [input_size.value + num_proj, 4 * self._num_units], - dtype, self._num_unit_shards) - - b = vs.get_variable( - "B", shape=[4 * self._num_units], - initializer=init_ops.zeros_initializer, dtype=dtype) - + with vs.variable_scope(scope or "lstm_cell", + initializer=self._initializer) as unit_scope: + if self._num_unit_shards is not None: + unit_scope.set_partitioner( + partitioned_variables.fixed_size_partitioner( + self._num_unit_shards)) # i = input_gate, j = new_input, f = forget_gate, o = output_gate - cell_inputs = array_ops.concat(1, [inputs, m_prev]) - lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) + lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True, + scope=scope) i, j, f, o = array_ops.split(1, 4, lstm_matrix) # Diagonal connections if self._use_peepholes: - w_f_diag = vs.get_variable( - "W_F_diag", shape=[self._num_units], dtype=dtype) - w_i_diag = vs.get_variable( - "W_I_diag", shape=[self._num_units], dtype=dtype) - w_o_diag = vs.get_variable( - "W_O_diag", shape=[self._num_units], dtype=dtype) + with vs.variable_scope(unit_scope) as projection_scope: + if self._num_unit_shards is not None: + projection_scope.set_partitioner(None) + w_f_diag = vs.get_variable( + "w_f_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "w_i_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "w_o_diag", shape=[self._num_units], dtype=dtype) if self._use_peepholes: c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + @@ -531,11 +504,13 @@ class LSTMCell(RNNCell): m = sigmoid(o) * self._activation(c) if self._num_proj is not None: - concat_w_proj = _get_concat_variable( - "W_P", [self._num_units, self._num_proj], - dtype, self._num_proj_shards) + with vs.variable_scope("projection") as proj_scope: + if self._num_proj_shards is not None: + proj_scope.set_partitioner( + partitioned_variables.fixed_size_partitioner( + self._num_proj_shards)) + m = _linear(m, self._num_proj, bias=False, scope=scope) - m = math_ops.matmul(m, concat_w_proj) if self._proj_clip is not None: # pylint: disable=invalid-unary-operand-type m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) @@ -585,8 +560,8 @@ class OutputProjectionWrapper(RNNCell): """Run the cell and output projection on inputs, starting from state.""" output, res_state = self._cell(inputs, state) # Default scope: "OutputProjectionWrapper" - with vs.variable_scope(scope or type(self).__name__): - projected = _linear(output, self._output_size, True) + with vs.variable_scope(scope or "output_projection_wrapper"): + projected = _linear(output, self._output_size, True, scope=scope) return projected, res_state @@ -627,8 +602,8 @@ class InputProjectionWrapper(RNNCell): def __call__(self, inputs, state, scope=None): """Run the input projection and then the cell.""" # Default scope: "InputProjectionWrapper" - with vs.variable_scope(scope or type(self).__name__): - projected = _linear(inputs, self._num_proj, True) + with vs.variable_scope(scope or "input_projection_wrapper"): + projected = _linear(inputs, self._num_proj, True, scope=scope) return self._cell(projected, state) @@ -731,7 +706,7 @@ class EmbeddingWrapper(RNNCell): def __call__(self, inputs, state, scope=None): """Run the cell on embedded inputs.""" - with vs.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper" + with vs.variable_scope(scope or "embedding_wrapper"): # "EmbeddingWrapper" with ops.device("/cpu:0"): if self._initializer: initializer = self._initializer @@ -796,12 +771,12 @@ class MultiRNNCell(RNNCell): def __call__(self, inputs, state, scope=None): """Run this multi-layer cell on inputs, starting from state.""" - with vs.variable_scope(scope or type(self).__name__): # "MultiRNNCell" + with vs.variable_scope(scope or "multi_rnn_cell"): cur_state_pos = 0 cur_inp = inputs new_states = [] for i, cell in enumerate(self._cells): - with vs.variable_scope("Cell%d" % i): + with vs.variable_scope("cell_%d" % i): if self._state_is_tuple: if not nest.is_sequence(state): raise ValueError( @@ -872,7 +847,7 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None): output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. bias_start: starting value to initialize the bias; 0 by default. - scope: VariableScope for the created subgraph; defaults to "Linear". + scope: (optional) Variable scope to create parameters in. Returns: A 2D Tensor with shape [batch x output_size] equal to @@ -888,30 +863,33 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None): # Calculate the total size of arguments on dimension 1. total_arg_size = 0 - shapes = [a.get_shape().as_list() for a in args] + shapes = [a.get_shape() for a in args] for shape in shapes: - if len(shape) != 2: - raise ValueError("Linear is expecting 2D arguments: %s" % str(shapes)) - if not shape[1]: - raise ValueError("Linear expects shape[1] of arguments: %s" % str(shapes)) + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %d" % (shape, shape[1])) else: - total_arg_size += shape[1] + total_arg_size += shape[1].value dtype = [a.dtype for a in args][0] # Now the computation. - with vs.variable_scope(scope or "Linear"): - matrix = vs.get_variable( - "Matrix", [total_arg_size, output_size], dtype=dtype) + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + weights = vs.get_variable( + "weights", [total_arg_size, output_size], dtype=dtype) if len(args) == 1: - res = math_ops.matmul(args[0], matrix) + res = math_ops.matmul(args[0], weights) else: - res = math_ops.matmul(array_ops.concat(1, args), matrix) + res = math_ops.matmul(array_ops.concat(1, args), weights) if not bias: return res - bias_term = vs.get_variable( - "Bias", [output_size], - dtype=dtype, - initializer=init_ops.constant_initializer( - bias_start, dtype=dtype)) - return res + bias_term + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + biases = vs.get_variable( + "biases", [output_size], + dtype=dtype, + initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) + return nn_ops.bias_add(res, biases) |