diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-27 16:33:00 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-27 16:37:09 -0700 |
commit | 50b999a8336d19400ab75aea66fe46eca2f5fe0b (patch) | |
tree | 7cba4f4af6b131c253b65ff9f2923e851184668c /tensorflow/examples/tutorials | |
parent | d6d58a3a1785785679af56c0f8f131e7312b8226 (diff) |
Merge changes from github.
PiperOrigin-RevId: 160344052
Diffstat (limited to 'tensorflow/examples/tutorials')
-rw-r--r-- | tensorflow/examples/tutorials/mnist/mnist.py | 2 | ||||
-rw-r--r-- | tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 16 |
2 files changed, 11 insertions, 7 deletions
diff --git a/tensorflow/examples/tutorials/mnist/mnist.py b/tensorflow/examples/tutorials/mnist/mnist.py index d533697976..3585043a2a 100644 --- a/tensorflow/examples/tutorials/mnist/mnist.py +++ b/tensorflow/examples/tutorials/mnist/mnist.py @@ -17,7 +17,7 @@ Implements the inference/loss/training pattern for model building. -1. inference() - Builds the model as far as is required for running the network +1. inference() - Builds the model as far as required for running the network forward to make predictions. 2. loss() - Adds to the inference model the layers required to generate loss. 3. training() - Adds to the loss model the Ops required to generate and diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py index 13e5717b0d..aee482fda5 100644 --- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py +++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py @@ -91,7 +91,6 @@ print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]]) data_index = 0 - # Step 3: Function to generate a training batch for the skip-gram model. def generate_batch(batch_size, num_skips, skip_window): global data_index @@ -101,9 +100,10 @@ def generate_batch(batch_size, num_skips, skip_window): labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32) span = 2 * skip_window + 1 # [ skip_window target skip_window ] buffer = collections.deque(maxlen=span) - for _ in range(span): - buffer.append(data[data_index]) - data_index = (data_index + 1) % len(data) + if data_index + span > len(data): + data_index = 0 + buffer.extend(data[data_index:data_index + span]) + data_index += span for i in range(batch_size // num_skips): target = skip_window # target label at the center of the buffer targets_to_avoid = [skip_window] @@ -113,8 +113,12 @@ def generate_batch(batch_size, num_skips, skip_window): targets_to_avoid.append(target) batch[i * num_skips + j] = buffer[skip_window] labels[i * num_skips + j, 0] = buffer[target] - buffer.append(data[data_index]) - data_index = (data_index + 1) % len(data) + if data_index == len(data): + buffer[:] = data[:span] + data_index = span + else: + buffer.append(data[data_index]) + data_index += 1 # Backtrack a little bit to avoid skipping words in the end of a batch data_index = (data_index + len(data) - span) % len(data) return batch, labels |