author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-06-26 14:00:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-06-26 14:04:35 -0700
commit    1fa73c53ab95693f070ce70e6be0c644d83c163a (patch)
tree      ffbedf825daf1f3453c695a433c8a9cdf93f6019 /tensorflow/examples/tutorials
parent    b13e96e21c1229a905a623111dd89d2bd0cba53b (diff)
Automated g4 rollback of changelist 160182040
PiperOrigin-RevId: 160190881
Diffstat (limited to 'tensorflow/examples/tutorials')
 tensorflow/examples/tutorials/mnist/mnist.py             |  2 +-
 tensorflow/examples/tutorials/word2vec/word2vec_basic.py | 16 ++++++----------
 2 files changed, 7 insertions(+), 11 deletions(-)
diff --git a/tensorflow/examples/tutorials/mnist/mnist.py b/tensorflow/examples/tutorials/mnist/mnist.py
index 3585043a2a..d533697976 100644
--- a/tensorflow/examples/tutorials/mnist/mnist.py
+++ b/tensorflow/examples/tutorials/mnist/mnist.py
@@ -17,7 +17,7 @@
Implements the inference/loss/training pattern for model building.
-1. inference() - Builds the model as far as required for running the network
+1. inference() - Builds the model as far as is required for running the network
forward to make predictions.
2. loss() - Adds to the inference model the layers required to generate loss.
3. training() - Adds to the loss model the Ops required to generate and
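
The docstring change above is small, but the pattern it names is the structural idea of the tutorial: three functions that each extend the graph built by the previous one. A minimal sketch of that layout, assuming a TF 1.x graph-mode API and an illustrative single-layer model (the layer shape, names, and optimizer here are assumptions, not the tutorial's actual code):

import tensorflow as tf

def inference(images):
  # Build the forward pass far enough to produce class logits.
  return tf.layers.dense(images, units=10, name='logits')

def loss(logits, labels):
  # Add to the inference graph the ops that produce a scalar loss.
  return tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

def training(loss_op, learning_rate):
  # Add to the loss graph the ops that compute and apply gradients.
  return tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_op)
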
diff --git a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
index aee482fda5..13e5717b0d 100644
--- a/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/examples/tutorials/word2vec/word2vec_basic.py
@@ -91,6 +91,7 @@ print('Sample data', data[:10], [reverse_dictionary[i] for i in data[:10]])
data_index = 0
+
# Step 3: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
global data_index
@@ -100,10 +101,9 @@ def generate_batch(batch_size, num_skips, skip_window):
labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
span = 2 * skip_window + 1 # [ skip_window target skip_window ]
buffer = collections.deque(maxlen=span)
- if data_index + span > len(data):
- data_index = 0
- buffer.extend(data[data_index:data_index + span])
- data_index += span
+ for _ in range(span):
+ buffer.append(data[data_index])
+ data_index = (data_index + 1) % len(data)
for i in range(batch_size // num_skips):
target = skip_window # target label at the center of the buffer
targets_to_avoid = [skip_window]
@@ -113,12 +113,8 @@ def generate_batch(batch_size, num_skips, skip_window):
targets_to_avoid.append(target)
batch[i * num_skips + j] = buffer[skip_window]
labels[i * num_skips + j, 0] = buffer[target]
- if data_index == len(data):
- buffer[:] = data[:span]
- data_index = span
- else:
- buffer.append(data[data_index])
- data_index += 1
+ buffer.append(data[data_index])
+ data_index = (data_index + 1) % len(data)
# Backtrack a little bit to avoid skipping words in the end of a batch
data_index = (data_index + len(data) - span) % len(data)
return batch, labels
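
The rewritten generate_batch drops both end-of-data special cases in favor of a single rule: data_index always advances as (data_index + 1) % len(data), so the sliding window wraps cleanly from the end of the corpus back to the start. A self-contained sketch of that wraparound, with a toy list of integer word ids standing in for the tutorial's data (the corpus and starting index below are illustrative assumptions):

import collections

data = list(range(10))   # toy corpus of word ids, assumed for illustration
span = 3                 # 2 * skip_window + 1, with skip_window = 1
data_index = 8           # start near the end to exercise the wrap

buffer = collections.deque(maxlen=span)
for _ in range(span):
  buffer.append(data[data_index])
  data_index = (data_index + 1) % len(data)

print(list(buffer))  # [8, 9, 0] -- the window wrapped past the end
print(data_index)    # 1

Because deque(maxlen=span) discards its oldest element on every append, one append plus one modular increment is all the bookkeeping the inner loop needs; the old reset-and-extend branches become unnecessary.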