Diffstat (limited to 'tensorflow/models/rnn/ptb/ptb_word_lm.py')
 tensorflow/models/rnn/ptb/ptb_word_lm.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index a9e8f8ddf3..3380a4fc92 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -106,12 +106,10 @@ class PTBModel(object):
 
     with tf.device("/cpu:0"):
       embedding = tf.get_variable("embedding", [vocab_size, size])
-      inputs = tf.split(
-          1, num_steps, tf.nn.embedding_lookup(embedding, self._input_data))
-      inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
+      inputs = tf.nn.embedding_lookup(embedding, self._input_data)
 
     if is_training and config.keep_prob < 1:
-      inputs = [tf.nn.dropout(input_, config.keep_prob) for input_ in inputs]
+      inputs = tf.nn.dropout(inputs, config.keep_prob)
 
     # Simplified version of tensorflow.models.rnn.rnn.py's rnn().
     # This builds an unrolled LSTM for tutorial purposes only.
@@ -120,14 +118,16 @@ class PTBModel(object):
     # The alternative version of the code below is:
     #
     # from tensorflow.models.rnn import rnn
+    # inputs = [tf.squeeze(input_, [1])
+    #           for input_ in tf.split(1, num_steps, inputs)]
     # outputs, states = rnn.rnn(cell, inputs, initial_state=self._initial_state)
     outputs = []
     states = []
     state = self._initial_state
     with tf.variable_scope("RNN"):
-      for time_step, input_ in enumerate(inputs):
+      for time_step in range(num_steps):
         if time_step > 0: tf.get_variable_scope().reuse_variables()
-        (cell_output, state) = cell(input_, state)
+        (cell_output, state) = cell(inputs[:, time_step, :], state)
         outputs.append(cell_output)
         states.append(state)
 
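The change keeps the embedded inputs as a single 3-D tensor of shape [batch_size, num_steps, size], applies dropout to it once, and slices out one time step per loop iteration instead of pre-splitting into a Python list of [batch_size, size] tensors. The following is a minimal NumPy sketch (not part of the commit; shapes and names are illustrative) showing that slicing the 3-D tensor per step yields the same per-step matrices as the old split-then-squeeze form:

# Illustrative only: demonstrates the equivalence of the two input layouts
# used before and after this commit, using NumPy in place of TensorFlow ops.
import numpy as np

batch_size, num_steps, size = 4, 3, 5
inputs = np.random.rand(batch_size, num_steps, size)   # like the embedding lookup output

# Old layout: split along the time axis, then squeeze the singleton dimension.
split_inputs = [np.squeeze(s, axis=1)
                for s in np.split(inputs, num_steps, axis=1)]

# New layout: slice the 3-D tensor directly inside the unrolled loop.
for time_step in range(num_steps):
    step_input = inputs[:, time_step, :]                # shape [batch_size, size]
    assert np.array_equal(step_input, split_inputs[time_step])

The old list form is still needed by rnn.rnn(), which is why the commented-out alternative in the diff reconstructs it with tf.split and tf.squeeze before calling that function.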