path: root/tensorflow
author     Vijay Vasudevan <vrv@google.com>  2015-11-09 10:11:07 -0800
committer  Vijay Vasudevan <vrv@google.com>  2015-11-09 10:11:07 -0800
commit  61d3a958d6d83cb6037490d933b47621cc4009cc (patch)
tree    20630337ec30cbc6d974730d3bfdd22508f6e257 /tensorflow
parent  9f64983a8458700ba1aec613a755e8264b1608e0 (diff)
TensorFlow: Initial steps towards Python 3 support, plus some documentation bug fixes -- reindents some of the files to 2-space indentation to match our internal requirements. Thanks to Martin Andrews for the suggested basic_usage.md fix via Gerrit. Base CL: 107394029
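The bulk of the Python 3 work in this commit is mechanical: add `from __future__ import print_function` and convert `print` statements to `print()` calls. A minimal sketch of that pattern (illustrative only, not taken from any single file in the commit; the `report` helper is hypothetical):

```python
# The __future__ import makes `print` a function on Python 2 as well,
# so the same call syntax runs under both Python 2 and Python 3.
from __future__ import print_function


def report(step, loss_value, duration):
  # Python 2-only form being removed:
  #   print 'Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)
  # Portable form used throughout this commit:
  print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))


report(0, 2.31, 0.012)
```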
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/g3doc/get_started/basic_usage.md | 20
-rw-r--r--  tensorflow/g3doc/get_started/index.md | 2
-rw-r--r--  tensorflow/g3doc/how_tos/adding_an_op/fact_test.py | 3
-rw-r--r--  tensorflow/g3doc/how_tos/reading_data/convert_to_records.py | 95
-rw-r--r--  tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py | 189
-rw-r--r--  tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py | 213
-rw-r--r--  tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py | 251
-rw-r--r--  tensorflow/g3doc/tutorials/mnist/beginners/index.md | 4
-rw-r--r--  tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py | 321
-rw-r--r--  tensorflow/g3doc/tutorials/mnist/input_data.py | 259
-rw-r--r--  tensorflow/g3doc/tutorials/mnist/mnist_softmax.py | 3
-rw-r--r--  tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py | 25
-rw-r--r--  tensorflow/models/embedding/word2vec.py | 27
-rw-r--r--  tensorflow/models/embedding/word2vec_optimized.py | 28
-rw-r--r--  tensorflow/models/image/alexnet/alexnet_benchmark.py | 3
-rw-r--r--  tensorflow/models/image/cifar10/cifar10.py | 5
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_eval.py | 7
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py | 1
-rw-r--r--  tensorflow/models/image/cifar10/cifar10_train.py | 1
-rw-r--r--  tensorflow/models/image/mnist/convolutional.py | 416
-rw-r--r--  tensorflow/models/rnn/ptb/ptb_word_lm.py | 1
-rw-r--r--  tensorflow/models/rnn/seq2seq_test.py | 3
-rw-r--r--  tensorflow/models/rnn/translate/data_utils.py | 23
-rw-r--r--  tensorflow/models/rnn/translate/translate.py | 19
-rw-r--r--  tensorflow/python/client/notebook.py | 3
-rw-r--r--  tensorflow/python/framework/docs.py | 94
-rw-r--r--  tensorflow/python/framework/gen_docs_combined.py | 3
-rw-r--r--  tensorflow/python/framework/test_util.py | 17
-rw-r--r--  tensorflow/python/framework/test_util_test.py | 5
-rw-r--r--  tensorflow/python/framework/types_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/bias_op_test.py | 11
-rw-r--r--  tensorflow/python/kernel_tests/cast_op_test.py | 8
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py | 19
-rw-r--r--  tensorflow/python/kernel_tests/embedding_ops_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/lrn_op_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/matmul_op_test.py | 39
-rw-r--r--  tensorflow/python/kernel_tests/pooling_ops_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/random_ops_test.py | 21
-rw-r--r--  tensorflow/python/kernel_tests/relu_op_test.py | 13
-rw-r--r--  tensorflow/python/kernel_tests/reshape_op_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/reverse_sequence_op_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/shape_ops_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/softplus_op_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/sparse_matmul_op_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/xent_op_test.py | 3
-rw-r--r--  tensorflow/python/ops/nn_test.py | 22
-rw-r--r--  tensorflow/python/platform/__init__.py | 3
-rw-r--r--  tensorflow/python/platform/app.py | 3
-rw-r--r--  tensorflow/python/platform/default/_gfile.py | 18
-rw-r--r--  tensorflow/python/platform/default/_googletest.py | 4
-rw-r--r--  tensorflow/python/platform/default/_logging.py | 2
-rw-r--r--  tensorflow/python/platform/flags.py | 3
-rw-r--r--  tensorflow/python/platform/gfile.py | 3
-rw-r--r--  tensorflow/python/platform/googletest.py | 3
-rw-r--r--  tensorflow/python/platform/logging.py | 3
-rw-r--r--  tensorflow/python/platform/parameterized.py | 3
-rw-r--r--  tensorflow/python/platform/resource_loader.py | 3
-rw-r--r--  tensorflow/python/platform/status_bar.py | 3
-rw-r--r--  tensorflow/python/summary/impl/event_file_loader.py | 5
-rw-r--r--  tensorflow/python/training/coordinator_test.py | 2
-rw-r--r--  tensorflow/python/training/optimizer.py | 3
-rw-r--r--  tensorflow/python/training/queue_runner.py | 6
-rw-r--r--  tensorflow/python/training/saver.py | 4
-rw-r--r--  tensorflow/python/training/saver_test.py | 6
-rw-r--r--  tensorflow/python/util/protobuf/compare_test.py | 12
-rw-r--r--  tensorflow/tensorboard/tensorboard.py | 9
66 files changed, 1175 insertions, 1127 deletions
diff --git a/tensorflow/g3doc/get_started/basic_usage.md b/tensorflow/g3doc/get_started/basic_usage.md
index c29f6a4179..df4f769983 100644
--- a/tensorflow/g3doc/get_started/basic_usage.md
+++ b/tensorflow/g3doc/get_started/basic_usage.md
@@ -194,9 +194,9 @@ shows a variable serving as a simple counter. See
```python
# Create a Variable, that will be initialized to the scalar value 0.
-var = tf.Variable(0, name="counter")
+state = tf.Variable(0, name="counter")
-# Create an Op to add one to `var`.
+# Create an Op to add one to `state`.
one = tf.constant(1)
new_value = tf.add(state, one)
@@ -209,14 +209,14 @@ init_op = tf.initialize_all_variables()
# Launch the graph and run the ops.
with tf.Session() as sess:
- # Run the 'init' op
- sess.run(init_op)
- # Print the initial value of 'var'
- print sess.run(var)
- # Run the op that updates 'var' and print 'var'.
- for _ in range(3):
- sess.run(update)
- print sess.run(var)
+ # Run the 'init' op
+ sess.run(init_op)
+ # Print the initial value of 'state'
+ print sess.run(state)
+ # Run the op that updates 'state' and print 'state'.
+ for _ in range(3):
+ sess.run(update)
+ print sess.run(state)
# output:
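Pieced together, the counter example these two hunks touch reads roughly as follows after the `var` -> `state` rename. This is a sketch for context only: the `update = tf.assign(state, new_value)` line falls between the two hunks and is assumed here, and `print()` is used instead of the doc's Python 2 `print` statements.

```python
import tensorflow as tf

# Create a Variable that will be initialized to the scalar value 0.
state = tf.Variable(0, name="counter")

# Create an Op to add one to `state`.
one = tf.constant(1)
new_value = tf.add(state, one)
update = tf.assign(state, new_value)  # assumed: not shown in the hunks above

# Variables must be initialized after the graph is launched.
init_op = tf.initialize_all_variables()

# Launch the graph and run the ops.
with tf.Session() as sess:
  # Run the 'init' op
  sess.run(init_op)
  # Print the initial value of 'state'
  print(sess.run(state))
  # Run the op that updates 'state' and print 'state'.
  for _ in range(3):
    sess.run(update)
    print(sess.run(state))
# output: 0, 1, 2, 3
```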
diff --git a/tensorflow/g3doc/get_started/index.md b/tensorflow/g3doc/get_started/index.md
index 6e6b41a5e9..0c58b969c7 100644
--- a/tensorflow/g3doc/get_started/index.md
+++ b/tensorflow/g3doc/get_started/index.md
@@ -2,7 +2,7 @@
Let's get you up and running with TensorFlow!
-But before we even get started, let's give you a sneak peak at what TensorFlow
+But before we even get started, let's give you a sneak peek at what TensorFlow
code looks like in the Python API, just so you have a sense of where we're
headed.
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py b/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py
index 17a7028d98..e2f44db60b 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/fact_test.py
@@ -1,4 +1,5 @@
"""Test that user ops can be used as expected."""
+from __future__ import print_function
import tensorflow.python.platform
@@ -9,7 +10,7 @@ class FactTest(tf.test.TestCase):
def test(self):
with self.test_session():
- print tf.user_ops.my_fact().eval()
+ print(tf.user_ops.my_fact().eval())
if __name__ == '__main__':
diff --git a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
index 1d510cdfa9..6d77ed50e2 100644
--- a/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
+++ b/tensorflow/g3doc/how_tos/reading_data/convert_to_records.py
@@ -1,4 +1,5 @@
"""Converts MNIST data to TFRecords file format with Example protos."""
+from __future__ import print_function
import os
import tensorflow.python.platform
@@ -32,56 +33,56 @@ def _bytes_feature(value):
def convert_to(images, labels, name):
- num_examples = labels.shape[0]
- if images.shape[0] != num_examples:
- raise ValueError("Images size %d does not match label size %d." %
- (dat.shape[0], num_examples))
- rows = images.shape[1]
- cols = images.shape[2]
- depth = images.shape[3]
-
- filename = os.path.join(FLAGS.directory, name + '.tfrecords')
- print 'Writing', filename
- writer = tf.python_io.TFRecordWriter(filename)
- for index in range(num_examples):
- image_raw = images[index].tostring()
- example = tf.train.Example(features=tf.train.Features(feature={
- 'height':_int64_feature(rows),
- 'width':_int64_feature(cols),
- 'depth':_int64_feature(depth),
- 'label':_int64_feature(int(labels[index])),
- 'image_raw':_bytes_feature(image_raw)}))
- writer.write(example.SerializeToString())
+ num_examples = labels.shape[0]
+ if images.shape[0] != num_examples:
+ raise ValueError("Images size %d does not match label size %d." %
+ (dat.shape[0], num_examples))
+ rows = images.shape[1]
+ cols = images.shape[2]
+ depth = images.shape[3]
+
+ filename = os.path.join(FLAGS.directory, name + '.tfrecords')
+ print('Writing', filename)
+ writer = tf.python_io.TFRecordWriter(filename)
+ for index in range(num_examples):
+ image_raw = images[index].tostring()
+ example = tf.train.Example(features=tf.train.Features(feature={
+ 'height': _int64_feature(rows),
+ 'width': _int64_feature(cols),
+ 'depth': _int64_feature(depth),
+ 'label': _int64_feature(int(labels[index])),
+ 'image_raw': _bytes_feature(image_raw)}))
+ writer.write(example.SerializeToString())
def main(argv):
- # Get the data.
- train_images_filename = input_data.maybe_download(
- TRAIN_IMAGES, FLAGS.directory)
- train_labels_filename = input_data.maybe_download(
- TRAIN_LABELS, FLAGS.directory)
- test_images_filename = input_data.maybe_download(
- TEST_IMAGES, FLAGS.directory)
- test_labels_filename = input_data.maybe_download(
- TEST_LABELS, FLAGS.directory)
-
- # Extract it into numpy arrays.
- train_images = input_data.extract_images(train_images_filename)
- train_labels = input_data.extract_labels(train_labels_filename)
- test_images = input_data.extract_images(test_images_filename)
- test_labels = input_data.extract_labels(test_labels_filename)
-
- # Generate a validation set.
- validation_images = train_images[:FLAGS.validation_size, :, :, :]
- validation_labels = train_labels[:FLAGS.validation_size]
- train_images = train_images[FLAGS.validation_size:, :, :, :]
- train_labels = train_labels[FLAGS.validation_size:]
-
- # Convert to Examples and write the result to TFRecords.
- convert_to(train_images, train_labels, 'train')
- convert_to(validation_images, validation_labels, 'validation')
- convert_to(test_images, test_labels, 'test')
+ # Get the data.
+ train_images_filename = input_data.maybe_download(
+ TRAIN_IMAGES, FLAGS.directory)
+ train_labels_filename = input_data.maybe_download(
+ TRAIN_LABELS, FLAGS.directory)
+ test_images_filename = input_data.maybe_download(
+ TEST_IMAGES, FLAGS.directory)
+ test_labels_filename = input_data.maybe_download(
+ TEST_LABELS, FLAGS.directory)
+
+ # Extract it into numpy arrays.
+ train_images = input_data.extract_images(train_images_filename)
+ train_labels = input_data.extract_labels(train_labels_filename)
+ test_images = input_data.extract_images(test_images_filename)
+ test_labels = input_data.extract_labels(test_labels_filename)
+
+ # Generate a validation set.
+ validation_images = train_images[:FLAGS.validation_size, :, :, :]
+ validation_labels = train_labels[:FLAGS.validation_size]
+ train_images = train_images[FLAGS.validation_size:, :, :, :]
+ train_labels = train_labels[FLAGS.validation_size:]
+
+ # Convert to Examples and write the result to TFRecords.
+ convert_to(train_images, train_labels, 'train')
+ convert_to(validation_images, validation_labels, 'validation')
+ convert_to(test_images, test_labels, 'test')
if __name__ == '__main__':
- tf.app.run()
+ tf.app.run()
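One issue carried over unchanged by the reindent: both the old and new versions of `convert_to` format the error message with an undefined name `dat`, so the size check would raise a `NameError` rather than the intended `ValueError` if it ever fired. A corrected sketch of that check (an editorial suggestion, not part of this commit):

```python
num_examples = labels.shape[0]
if images.shape[0] != num_examples:
  # Use `images`, not the undefined name `dat`, in the message.
  raise ValueError('Images size %d does not match label size %d.' %
                   (images.shape[0], num_examples))
```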
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
index b2436cd2ab..7e9d8355a9 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded.py
@@ -5,6 +5,7 @@ Command to run this py_binary target:
bazel run -c opt \
<...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded
"""
+from __future__ import print_function
import os.path
import time
@@ -31,104 +32,102 @@ flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
def run_training():
- """Train MNIST for a number of epochs."""
- # Get the sets of images and labels for training, validation, and
- # test on MNIST.
- data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
-
- # Tell TensorFlow that the model will be built into the default Graph.
- with tf.Graph().as_default():
- with tf.name_scope('input'):
- # Input data
- input_images = tf.constant(data_sets.train.images)
- input_labels = tf.constant(data_sets.train.labels)
-
- image, label = tf.train.slice_input_producer(
- [input_images, input_labels], num_epochs=FLAGS.num_epochs)
- label = tf.cast(label, tf.int32)
- images, labels = tf.train.batch(
- [image, label], batch_size=FLAGS.batch_size)
-
- # Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
-
- # Add to the Graph the Ops for loss calculation.
- loss = mnist.loss(logits, labels)
-
- # Add to the Graph the Ops that calculate and apply gradients.
- train_op = mnist.training(loss, FLAGS.learning_rate)
-
- # Add the Op to compare the logits to the labels during evaluation.
- eval_correct = mnist.evaluation(logits, labels)
-
- # Build the summary operation based on the TF collection of Summaries.
- summary_op = tf.merge_all_summaries()
-
- # Create a saver for writing training checkpoints.
- saver = tf.train.Saver()
-
- # Create the op for initializing variables.
- init_op = tf.initialize_all_variables()
-
- # Create a session for running Ops on the Graph.
- sess = tf.Session()
-
- # Run the Op to initialize the variables.
- sess.run(init_op)
-
- # Instantiate a SummaryWriter to output summaries and the Graph.
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
- graph_def=sess.graph_def)
-
- # Start input enqueue threads.
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
-
- # And then after everything is built, start the training loop.
- try:
- step = 0
- while not coord.should_stop():
- start_time = time.time()
-
- # Run one step of the model.
- _, loss_value = sess.run([train_op, loss])
-
- duration = time.time() - start_time
-
- # Write the summaries and print an overview fairly often.
- if step % 100 == 0:
- # Print status to stdout.
- print 'Step %d: loss = %.2f (%.3f sec)' % (step,
- loss_value,
- duration)
- # Update the events file.
- summary_str = sess.run(summary_op)
- summary_writer.add_summary(summary_str, step)
- step += 1
-
- # Save a checkpoint periodically.
- if (step + 1) % 1000 == 0:
- print 'Saving'
- saver.save(sess, FLAGS.train_dir, global_step=step)
-
- step += 1
- except tf.errors.OutOfRangeError:
- print 'Saving'
- saver.save(sess, FLAGS.train_dir, global_step=step)
- print 'Done training for %d epochs, %d steps.' % (
- FLAGS.num_epochs, step)
- finally:
- # When done, ask the threads to stop.
- coord.request_stop()
-
- # Wait for threads to finish.
- coord.join(threads)
- sess.close()
+ """Train MNIST for a number of epochs."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ with tf.name_scope('input'):
+ # Input data
+ input_images = tf.constant(data_sets.train.images)
+ input_labels = tf.constant(data_sets.train.labels)
+
+ image, label = tf.train.slice_input_producer(
+ [input_images, input_labels], num_epochs=FLAGS.num_epochs)
+ label = tf.cast(label, tf.int32)
+ images, labels = tf.train.batch(
+ [image, label], batch_size=FLAGS.batch_size)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Create a saver for writing training checkpoints.
+ saver = tf.train.Saver()
+
+ # Create the op for initializing variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ # Run the Op to initialize the variables.
+ sess.run(init_op)
+
+ # Instantiate a SummaryWriter to output summaries and the Graph.
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ # And then after everything is built, start the training loop.
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Write the summaries and print an overview fairly often.
+ if step % 100 == 0:
+ # Print status to stdout.
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ duration))
+ # Update the events file.
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+ step += 1
+
+ # Save a checkpoint periodically.
+ if (step + 1) % 1000 == 0:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
def main(_):
- run_training()
+ run_training()
if __name__ == '__main__':
- tf.app.run()
+ tf.app.run()
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
index 89abd60d0e..ef62242d1e 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -5,6 +5,7 @@ Command to run this py_binary target:
bazel run -c opt \
<...>/tensorflow/g3doc/how_tos/reading_data:fully_connected_preloaded_var
"""
+from __future__ import print_function
import os.path
import time
@@ -31,116 +32,114 @@ flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
def run_training():
- """Train MNIST for a number of epochs."""
- # Get the sets of images and labels for training, validation, and
- # test on MNIST.
- data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
-
- # Tell TensorFlow that the model will be built into the default Graph.
- with tf.Graph().as_default():
- with tf.name_scope('input'):
- # Input data
- images_initializer = tf.placeholder(
- dtype=data_sets.train.images.dtype,
- shape=data_sets.train.images.shape)
- labels_initializer = tf.placeholder(
- dtype=data_sets.train.labels.dtype,
- shape=data_sets.train.labels.shape)
- input_images = tf.Variable(
- images_initializer, trainable=False, collections=[])
- input_labels = tf.Variable(
- labels_initializer, trainable=False, collections=[])
-
- image, label = tf.train.slice_input_producer(
- [input_images, input_labels], num_epochs=FLAGS.num_epochs)
- label = tf.cast(label, tf.int32)
- images, labels = tf.train.batch(
- [image, label], batch_size=FLAGS.batch_size)
-
- # Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
-
- # Add to the Graph the Ops for loss calculation.
- loss = mnist.loss(logits, labels)
-
- # Add to the Graph the Ops that calculate and apply gradients.
- train_op = mnist.training(loss, FLAGS.learning_rate)
-
- # Add the Op to compare the logits to the labels during evaluation.
- eval_correct = mnist.evaluation(logits, labels)
-
- # Build the summary operation based on the TF collection of Summaries.
- summary_op = tf.merge_all_summaries()
-
- # Create a saver for writing training checkpoints.
- saver = tf.train.Saver()
-
- # Create the op for initializing variables.
- init_op = tf.initialize_all_variables()
-
- # Create a session for running Ops on the Graph.
- sess = tf.Session()
-
- # Run the Op to initialize the variables.
- sess.run(init_op)
- sess.run(input_images.initializer,
- feed_dict={images_initializer: data_sets.train.images})
- sess.run(input_labels.initializer,
- feed_dict={labels_initializer: data_sets.train.labels})
-
- # Instantiate a SummaryWriter to output summaries and the Graph.
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
- graph_def=sess.graph_def)
-
- # Start input enqueue threads.
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
-
- # And then after everything is built, start the training loop.
- try:
- step = 0
- while not coord.should_stop():
- start_time = time.time()
-
- # Run one step of the model.
- _, loss_value = sess.run([train_op, loss])
-
- duration = time.time() - start_time
-
- # Write the summaries and print an overview fairly often.
- if step % 100 == 0:
- # Print status to stdout.
- print 'Step %d: loss = %.2f (%.3f sec)' % (step,
- loss_value,
- duration)
- # Update the events file.
- summary_str = sess.run(summary_op)
- summary_writer.add_summary(summary_str, step)
- step += 1
-
- # Save a checkpoint periodically.
- if (step + 1) % 1000 == 0:
- print 'Saving'
- saver.save(sess, FLAGS.train_dir, global_step=step)
-
- step += 1
- except tf.errors.OutOfRangeError:
- print 'Saving'
- saver.save(sess, FLAGS.train_dir, global_step=step)
- print 'Done training for %d epochs, %d steps.' % (
- FLAGS.num_epochs, step)
- finally:
- # When done, ask the threads to stop.
- coord.request_stop()
-
- # Wait for threads to finish.
- coord.join(threads)
- sess.close()
+ """Train MNIST for a number of epochs."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ with tf.name_scope('input'):
+ # Input data
+ images_initializer = tf.placeholder(
+ dtype=data_sets.train.images.dtype,
+ shape=data_sets.train.images.shape)
+ labels_initializer = tf.placeholder(
+ dtype=data_sets.train.labels.dtype,
+ shape=data_sets.train.labels.shape)
+ input_images = tf.Variable(
+ images_initializer, trainable=False, collections=[])
+ input_labels = tf.Variable(
+ labels_initializer, trainable=False, collections=[])
+
+ image, label = tf.train.slice_input_producer(
+ [input_images, input_labels], num_epochs=FLAGS.num_epochs)
+ label = tf.cast(label, tf.int32)
+ images, labels = tf.train.batch(
+ [image, label], batch_size=FLAGS.batch_size)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Create a saver for writing training checkpoints.
+ saver = tf.train.Saver()
+
+ # Create the op for initializing variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ # Run the Op to initialize the variables.
+ sess.run(init_op)
+ sess.run(input_images.initializer,
+ feed_dict={images_initializer: data_sets.train.images})
+ sess.run(input_labels.initializer,
+ feed_dict={labels_initializer: data_sets.train.labels})
+
+ # Instantiate a SummaryWriter to output summaries and the Graph.
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ # And then after everything is built, start the training loop.
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Write the summaries and print an overview fairly often.
+ if step % 100 == 0:
+ # Print status to stdout.
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ duration))
+ # Update the events file.
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+ step += 1
+
+ # Save a checkpoint periodically.
+ if (step + 1) % 1000 == 0:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
def main(_):
- run_training()
+ run_training()
if __name__ == '__main__':
- tf.app.run()
+ tf.app.run()
diff --git a/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py b/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
index f1e10ca34e..e467535ffe 100644
--- a/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/g3doc/how_tos/reading_data/fully_connected_reader.py
@@ -8,6 +8,7 @@ for context.
YOU MUST run convert_to_records before running this (but you only need to
run it once).
"""
+from __future__ import print_function
import os.path
import time
@@ -35,146 +36,144 @@ VALIDATION_FILE = 'validation.tfrecords'
def read_and_decode(filename_queue):
- reader = tf.TFRecordReader()
- _, serialized_example = reader.read(filename_queue)
- features = tf.parse_single_example(
- serialized_example,
- dense_keys=['image_raw', 'label'],
- # Defaults are not specified since both keys are required.
- dense_types=[tf.string, tf.int64])
+ reader = tf.TFRecordReader()
+ _, serialized_example = reader.read(filename_queue)
+ features = tf.parse_single_example(
+ serialized_example,
+ dense_keys=['image_raw', 'label'],
+ # Defaults are not specified since both keys are required.
+ dense_types=[tf.string, tf.int64])
- # Convert from a scalar string tensor (whose single string has
- # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
- # [mnist.IMAGE_PIXELS].
- image = tf.decode_raw(features['image_raw'], tf.uint8)
- image.set_shape([mnist.IMAGE_PIXELS])
+ # Convert from a scalar string tensor (whose single string has
+ # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
+ # [mnist.IMAGE_PIXELS].
+ image = tf.decode_raw(features['image_raw'], tf.uint8)
+ image.set_shape([mnist.IMAGE_PIXELS])
- # OPTIONAL: Could reshape into a 28x28 image and apply distortions
- # here. Since we are not applying any distortions in this
- # example, and the next step expects the image to be flattened
- # into a vector, we don't bother.
+ # OPTIONAL: Could reshape into a 28x28 image and apply distortions
+ # here. Since we are not applying any distortions in this
+ # example, and the next step expects the image to be flattened
+ # into a vector, we don't bother.
- # Convert from [0, 255] -> [-0.5, 0.5] floats.
- image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
+ # Convert from [0, 255] -> [-0.5, 0.5] floats.
+ image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
- # Convert label from a scalar uint8 tensor to an int32 scalar.
- label = tf.cast(features['label'], tf.int32)
+ # Convert label from a scalar uint8 tensor to an int32 scalar.
+ label = tf.cast(features['label'], tf.int32)
- return image, label
+ return image, label
def inputs(train, batch_size, num_epochs):
- """Reads input data num_epochs times.
-
- Args:
- train: Selects between the training (True) and validation (False) data.
- batch_size: Number of examples per returned batch.
- num_epochs: Number of times to read the input data, or 0/None to
- train forever.
-
- Returns:
- A tuple (images, labels), where:
- * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
- in the range [-0.5, 0.5].
- * labels is an int32 tensor with shape [batch_size] with the true label,
- a number in the range [0, mnist.NUM_CLASSES).
- Note that an tf.train.QueueRunner is added to the graph, which
- must be run using e.g. tf.train.start_queue_runners().
- """
- if not num_epochs: num_epochs = None
- filename = os.path.join(FLAGS.train_dir,
- TRAIN_FILE if train else VALIDATION_FILE)
-
- with tf.name_scope('input'):
- filename_queue = tf.train.string_input_producer(
- [filename], num_epochs=num_epochs)
-
- # Even when reading in multiple threads, share the filename
- # queue.
- image, label = read_and_decode(filename_queue)
-
- # Shuffle the examples and collect them into batch_size batches.
- # (Internally uses a RandomShuffleQueue.)
- # We run this in two threads to avoid being a bottleneck.
- images, sparse_labels = tf.train.shuffle_batch(
- [image, label], batch_size=batch_size, num_threads=2,
- capacity=1000 + 3 * batch_size,
- # Ensures a minimum amount of shuffling of examples.
- min_after_dequeue=1000)
-
- return images, sparse_labels
+ """Reads input data num_epochs times.
+
+ Args:
+ train: Selects between the training (True) and validation (False) data.
+ batch_size: Number of examples per returned batch.
+ num_epochs: Number of times to read the input data, or 0/None to
+ train forever.
+
+ Returns:
+ A tuple (images, labels), where:
+ * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
+ in the range [-0.5, 0.5].
+ * labels is an int32 tensor with shape [batch_size] with the true label,
+ a number in the range [0, mnist.NUM_CLASSES).
+ Note that an tf.train.QueueRunner is added to the graph, which
+ must be run using e.g. tf.train.start_queue_runners().
+ """
+ if not num_epochs: num_epochs = None
+ filename = os.path.join(FLAGS.train_dir,
+ TRAIN_FILE if train else VALIDATION_FILE)
+
+ with tf.name_scope('input'):
+ filename_queue = tf.train.string_input_producer(
+ [filename], num_epochs=num_epochs)
+
+ # Even when reading in multiple threads, share the filename
+ # queue.
+ image, label = read_and_decode(filename_queue)
+
+ # Shuffle the examples and collect them into batch_size batches.
+ # (Internally uses a RandomShuffleQueue.)
+ # We run this in two threads to avoid being a bottleneck.
+ images, sparse_labels = tf.train.shuffle_batch(
+ [image, label], batch_size=batch_size, num_threads=2,
+ capacity=1000 + 3 * batch_size,
+ # Ensures a minimum amount of shuffling of examples.
+ min_after_dequeue=1000)
+
+ return images, sparse_labels
def run_training():
- """Train MNIST for a number of steps."""
-
- # Tell TensorFlow that the model will be built into the default Graph.
- with tf.Graph().as_default():
- # Input images and labels.
- images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
- num_epochs=FLAGS.num_epochs)
-
- # Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(images,
- FLAGS.hidden1,
- FLAGS.hidden2)
-
- # Add to the Graph the loss calculation.
- loss = mnist.loss(logits, labels)
-
- # Add to the Graph operations that train the model.
- train_op = mnist.training(loss, FLAGS.learning_rate)
-
- # The op for initializing the variables.
- init_op = tf.initialize_all_variables();
-
- # Create a session for running operations in the Graph.
- sess = tf.Session()
-
- # Initialize the variables (the trained variables and the
- # epoch counter).
- sess.run(init_op)
-
- # Start input enqueue threads.
- coord = tf.train.Coordinator()
- threads = tf.train.start_queue_runners(sess=sess, coord=coord)
-
- try:
- step = 0
- while not coord.should_stop():
- start_time = time.time()
-
- # Run one step of the model. The return values are
- # the activations from the `train_op` (which is
- # discarded) and the `loss` op. To inspect the values
- # of your ops or variables, you may include them in
- # the list passed to sess.run() and the value tensors
- # will be returned in the tuple from the call.
- _, loss_value = sess.run([train_op, loss])
-
- duration = time.time() - start_time
-
- # Print an overview fairly often.
- if step % 100 == 0:
- print 'Step %d: loss = %.2f (%.3f sec)' % (step,
- loss_value,
- duration)
- step += 1
- except tf.errors.OutOfRangeError:
- print 'Done training for %d epochs, %d steps.' % (
- FLAGS.num_epochs, step)
- finally:
- # When done, ask the threads to stop.
- coord.request_stop()
-
- # Wait for threads to finish.
- coord.join(threads)
- sess.close()
+ """Train MNIST for a number of steps."""
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ # Input images and labels.
+ images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
+ num_epochs=FLAGS.num_epochs)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images,
+ FLAGS.hidden1,
+ FLAGS.hidden2)
+
+ # Add to the Graph the loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph operations that train the model.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # The op for initializing the variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running operations in the Graph.
+ sess = tf.Session()
+
+ # Initialize the variables (the trained variables and the
+ # epoch counter).
+ sess.run(init_op)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model. The return values are
+ # the activations from the `train_op` (which is
+ # discarded) and the `loss` op. To inspect the values
+ # of your ops or variables, you may include them in
+ # the list passed to sess.run() and the value tensors
+ # will be returned in the tuple from the call.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Print an overview fairly often.
+ if step % 100 == 0:
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ duration))
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
def main(_):
- run_training()
+ run_training()
if __name__ == '__main__':
- tf.app.run()
+ tf.app.run()
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index f53531537b..f48bc3d18a 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -87,7 +87,7 @@ pixel intensity between 0 and 1, for a particular pixel in a particular image.
The corresponding labels in MNIST are numbers between 0 and 9, describing
which digit a given image is of.
-For the purposes of this tutorial, we're going to want our labels as
+For the purposes of this tutorial, we're going to want our labels
as "one-hot vectors". A one-hot vector is a vector which is 0 in most
dimensions, and 1 in a single dimension. In this case, the \\(n\\)th digit will be
represented as a vector which is 1 in the \\(n\\)th dimensions. For example, 0
@@ -319,7 +319,7 @@ data point.)
Now that we know what we want our model to do, it's very easy to have TensorFlow
train it to do so.
-Because TensorFlow know the entire graph of your computations, it
+Because TensorFlow knows the entire graph of your computations, it
can automatically use the [backpropagation
algorithm](http://colah.github.io/posts/2015-08-Backprop/)
to efficiently determine how your variables affect the cost you ask it minimize.
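The one-hot encoding described in the first hunk above can be made concrete with a short NumPy sketch. This mirrors what `input_data.dense_to_one_hot` (touched later in this diff) does, but is written independently here since that helper's body is not shown:

```python
import numpy as np

def one_hot(labels_dense, num_classes=10):
  """Illustrative only: map class indices to one-hot row vectors."""
  result = np.zeros((labels_dense.shape[0], num_classes), dtype=np.float32)
  result[np.arange(labels_dense.shape[0]), labels_dense] = 1.0
  return result

# The digit 3 becomes a vector that is 1 only in dimension 3:
print(one_hot(np.array([3]))[0])  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```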
diff --git a/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py b/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
index 618c8f47cb..df974ce715 100644
--- a/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
+++ b/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py
@@ -7,6 +7,7 @@ MNIST tutorial:
https://tensorflow.org/tutorials/mnist/tf/index.html
"""
+from __future__ import print_function
# pylint: disable=missing-docstring
import os.path
import time
@@ -34,53 +35,53 @@ flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
def placeholder_inputs(batch_size):
- """Generate placeholder variables to represent the the input tensors.
+ """Generate placeholder variables to represent the the input tensors.
- These placeholders are used as inputs by the rest of the model building
- code and will be fed from the downloaded data in the .run() loop, below.
+ These placeholders are used as inputs by the rest of the model building
+ code and will be fed from the downloaded data in the .run() loop, below.
- Args:
- batch_size: The batch size will be baked into both placeholders.
+ Args:
+ batch_size: The batch size will be baked into both placeholders.
- Returns:
- images_placeholder: Images placeholder.
- labels_placeholder: Labels placeholder.
- """
- # Note that the shapes of the placeholders match the shapes of the full
- # image and label tensors, except the first dimension is now batch_size
- # rather than the full size of the train or test data sets.
- images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
- mnist.IMAGE_PIXELS))
- labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
- return images_placeholder, labels_placeholder
+ Returns:
+ images_placeholder: Images placeholder.
+ labels_placeholder: Labels placeholder.
+ """
+ # Note that the shapes of the placeholders match the shapes of the full
+ # image and label tensors, except the first dimension is now batch_size
+ # rather than the full size of the train or test data sets.
+ images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
+ mnist.IMAGE_PIXELS))
+ labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
+ return images_placeholder, labels_placeholder
def fill_feed_dict(data_set, images_pl, labels_pl):
- """Fills the feed_dict for training the given step.
-
- A feed_dict takes the form of:
- feed_dict = {
- <placeholder>: <tensor of values to be passed for placeholder>,
- ....
- }
-
- Args:
- data_set: The set of images and labels, from input_data.read_data_sets()
- images_pl: The images placeholder, from placeholder_inputs().
- labels_pl: The labels placeholder, from placeholder_inputs().
-
- Returns:
- feed_dict: The feed dictionary mapping from placeholders to values.
- """
- # Create the feed_dict for the placeholders filled with the next
- # `batch size ` examples.
- images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
- FLAGS.fake_data)
- feed_dict = {
- images_pl: images_feed,
- labels_pl: labels_feed,
- }
- return feed_dict
+ """Fills the feed_dict for training the given step.
+
+ A feed_dict takes the form of:
+ feed_dict = {
+ <placeholder>: <tensor of values to be passed for placeholder>,
+ ....
+ }
+
+ Args:
+ data_set: The set of images and labels, from input_data.read_data_sets()
+ images_pl: The images placeholder, from placeholder_inputs().
+ labels_pl: The labels placeholder, from placeholder_inputs().
+
+ Returns:
+ feed_dict: The feed dictionary mapping from placeholders to values.
+ """
+ # Create the feed_dict for the placeholders filled with the next
+ # `batch size ` examples.
+ images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
+ FLAGS.fake_data)
+ feed_dict = {
+ images_pl: images_feed,
+ labels_pl: labels_feed,
+ }
+ return feed_dict
def do_eval(sess,
@@ -88,132 +89,130 @@ def do_eval(sess,
images_placeholder,
labels_placeholder,
data_set):
- """Runs one evaluation against the full epoch of data.
-
- Args:
- sess: The session in which the model has been trained.
- eval_correct: The Tensor that returns the number of correct predictions.
- images_placeholder: The images placeholder.
- labels_placeholder: The labels placeholder.
- data_set: The set of images and labels to evaluate, from
- input_data.read_data_sets().
- """
- # And run one epoch of eval.
- true_count = 0 # Counts the number of correct predictions.
- steps_per_epoch = int(data_set.num_examples / FLAGS.batch_size)
- num_examples = steps_per_epoch * FLAGS.batch_size
- for step in xrange(steps_per_epoch):
- feed_dict = fill_feed_dict(data_set,
- images_placeholder,
- labels_placeholder)
- true_count += sess.run(eval_correct, feed_dict=feed_dict)
- precision = float(true_count) / float(num_examples)
- print ' Num examples: %d Num correct: %d Precision @ 1: %0.04f' % (
- num_examples, true_count, precision)
+ """Runs one evaluation against the full epoch of data.
+
+ Args:
+ sess: The session in which the model has been trained.
+ eval_correct: The Tensor that returns the number of correct predictions.
+ images_placeholder: The images placeholder.
+ labels_placeholder: The labels placeholder.
+ data_set: The set of images and labels to evaluate, from
+ input_data.read_data_sets().
+ """
+ # And run one epoch of eval.
+ true_count = 0 # Counts the number of correct predictions.
+ steps_per_epoch = int(data_set.num_examples / FLAGS.batch_size)
+ num_examples = steps_per_epoch * FLAGS.batch_size
+ for step in xrange(steps_per_epoch):
+ feed_dict = fill_feed_dict(data_set,
+ images_placeholder,
+ labels_placeholder)
+ true_count += sess.run(eval_correct, feed_dict=feed_dict)
+ precision = float(true_count) / float(num_examples)
+ print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
+ (num_examples, true_count, precision))
def run_training():
- """Train MNIST for a number of steps."""
- # Get the sets of images and labels for training, validation, and
- # test on MNIST.
- data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
-
- # Tell TensorFlow that the model will be built into the default Graph.
- with tf.Graph().as_default():
- # Generate placeholders for the images and labels.
- images_placeholder, labels_placeholder = placeholder_inputs(
- FLAGS.batch_size)
-
- # Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(images_placeholder,
- FLAGS.hidden1,
- FLAGS.hidden2)
-
- # Add to the Graph the Ops for loss calculation.
- loss = mnist.loss(logits, labels_placeholder)
-
- # Add to the Graph the Ops that calculate and apply gradients.
- train_op = mnist.training(loss, FLAGS.learning_rate)
-
- # Add the Op to compare the logits to the labels during evaluation.
- eval_correct = mnist.evaluation(logits, labels_placeholder)
-
- # Build the summary operation based on the TF collection of Summaries.
- summary_op = tf.merge_all_summaries()
-
- # Create a saver for writing training checkpoints.
- saver = tf.train.Saver()
-
- # Create a session for running Ops on the Graph.
- sess = tf.Session()
-
- # Run the Op to initialize the variables.
- init = tf.initialize_all_variables()
- sess.run(init)
-
- # Instantiate a SummaryWriter to output summaries and the Graph.
- summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
- graph_def=sess.graph_def)
-
- # And then after everything is built, start the training loop.
- for step in xrange(FLAGS.max_steps):
- start_time = time.time()
-
- # Fill a feed dictionary with the actual set of images and labels
- # for this particular training step.
- feed_dict = fill_feed_dict(data_sets.train,
- images_placeholder,
- labels_placeholder)
-
- # Run one step of the model. The return values are the activations
- # from the `train_op` (which is discarded) and the `loss` Op. To
- # inspect the values of your Ops or variables, you may include them
- # in the list passed to sess.run() and the value tensors will be
- # returned in the tuple from the call.
- _, loss_value = sess.run([train_op, loss],
- feed_dict=feed_dict)
-
- duration = time.time() - start_time
-
- # Write the summaries and print an overview fairly often.
- if step % 100 == 0:
- # Print status to stdout.
- print 'Step %d: loss = %.2f (%.3f sec)' % (step,
- loss_value,
- duration)
- # Update the events file.
- summary_str = sess.run(summary_op, feed_dict=feed_dict)
- summary_writer.add_summary(summary_str, step)
-
- # Save a checkpoint and evaluate the model periodically.
- if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
- saver.save(sess, FLAGS.train_dir, global_step=step)
- # Evaluate against the training set.
- print 'Training Data Eval:'
- do_eval(sess,
- eval_correct,
- images_placeholder,
- labels_placeholder,
- data_sets.train)
- # Evaluate against the validation set.
- print 'Validation Data Eval:'
- do_eval(sess,
- eval_correct,
- images_placeholder,
- labels_placeholder,
- data_sets.validation)
- # Evaluate against the test set.
- print 'Test Data Eval:'
- do_eval(sess,
- eval_correct,
- images_placeholder,
- labels_placeholder,
- data_sets.test)
+ """Train MNIST for a number of steps."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ # Generate placeholders for the images and labels.
+ images_placeholder, labels_placeholder = placeholder_inputs(
+ FLAGS.batch_size)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images_placeholder,
+ FLAGS.hidden1,
+ FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels_placeholder)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels_placeholder)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Create a saver for writing training checkpoints.
+ saver = tf.train.Saver()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ # Run the Op to initialize the variables.
+ init = tf.initialize_all_variables()
+ sess.run(init)
+
+ # Instantiate a SummaryWriter to output summaries and the Graph.
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ # And then after everything is built, start the training loop.
+ for step in xrange(FLAGS.max_steps):
+ start_time = time.time()
+
+ # Fill a feed dictionary with the actual set of images and labels
+ # for this particular training step.
+ feed_dict = fill_feed_dict(data_sets.train,
+ images_placeholder,
+ labels_placeholder)
+
+ # Run one step of the model. The return values are the activations
+ # from the `train_op` (which is discarded) and the `loss` Op. To
+ # inspect the values of your Ops or variables, you may include them
+ # in the list passed to sess.run() and the value tensors will be
+ # returned in the tuple from the call.
+ _, loss_value = sess.run([train_op, loss],
+ feed_dict=feed_dict)
+
+ duration = time.time() - start_time
+
+ # Write the summaries and print an overview fairly often.
+ if step % 100 == 0:
+ # Print status to stdout.
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
+ # Update the events file.
+ summary_str = sess.run(summary_op, feed_dict=feed_dict)
+ summary_writer.add_summary(summary_str, step)
+
+ # Save a checkpoint and evaluate the model periodically.
+ if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+ # Evaluate against the training set.
+ print('Training Data Eval:')
+ do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_sets.train)
+ # Evaluate against the validation set.
+ print('Validation Data Eval:')
+ do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_sets.validation)
+ # Evaluate against the test set.
+ print('Test Data Eval:')
+ do_eval(sess,
+ eval_correct,
+ images_placeholder,
+ labels_placeholder,
+ data_sets.test)
def main(_):
- run_training()
+ run_training()
if __name__ == '__main__':
- tf.app.run()
+ tf.app.run()
diff --git a/tensorflow/g3doc/tutorials/mnist/input_data.py b/tensorflow/g3doc/tutorials/mnist/input_data.py
index 88892027ff..e700680aa4 100644
--- a/tensorflow/g3doc/tutorials/mnist/input_data.py
+++ b/tensorflow/g3doc/tutorials/mnist/input_data.py
@@ -1,4 +1,5 @@
"""Functions for downloading and reading MNIST data."""
+from __future__ import print_function
import gzip
import os
import urllib
@@ -9,15 +10,15 @@ SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
- """Download the data from Yann's website, unless it's already here."""
- if not os.path.exists(work_directory):
- os.mkdir(work_directory)
- filepath = os.path.join(work_directory, filename)
- if not os.path.exists(filepath):
- filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
- statinfo = os.stat(filepath)
- print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.'
- return filepath
+ """Download the data from Yann's website, unless it's already here."""
+ if not os.path.exists(work_directory):
+ os.mkdir(work_directory)
+ filepath = os.path.join(work_directory, filename)
+ if not os.path.exists(filepath):
+ filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+ statinfo = os.stat(filepath)
+ print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
+ return filepath
def _read32(bytestream):
@@ -26,21 +27,21 @@ def _read32(bytestream):
def extract_images(filename):
- """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
- print 'Extracting', filename
- with gzip.open(filename) as bytestream:
- magic = _read32(bytestream)
- if magic != 2051:
- raise ValueError(
- 'Invalid magic number %d in MNIST image file: %s' %
- (magic, filename))
- num_images = _read32(bytestream)
- rows = _read32(bytestream)
- cols = _read32(bytestream)
- buf = bytestream.read(rows * cols * num_images)
- data = numpy.frombuffer(buf, dtype=numpy.uint8)
- data = data.reshape(num_images, rows, cols, 1)
- return data
+ """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
+ print('Extracting', filename)
+ with gzip.open(filename) as bytestream:
+ magic = _read32(bytestream)
+ if magic != 2051:
+ raise ValueError(
+ 'Invalid magic number %d in MNIST image file: %s' %
+ (magic, filename))
+ num_images = _read32(bytestream)
+ rows = _read32(bytestream)
+ cols = _read32(bytestream)
+ buf = bytestream.read(rows * cols * num_images)
+ data = numpy.frombuffer(buf, dtype=numpy.uint8)
+ data = data.reshape(num_images, rows, cols, 1)
+ return data
def dense_to_one_hot(labels_dense, num_classes=10):
@@ -53,123 +54,123 @@ def dense_to_one_hot(labels_dense, num_classes=10):
def extract_labels(filename, one_hot=False):
- """Extract the labels into a 1D uint8 numpy array [index]."""
- print 'Extracting', filename
- with gzip.open(filename) as bytestream:
- magic = _read32(bytestream)
- if magic != 2049:
- raise ValueError(
- 'Invalid magic number %d in MNIST label file: %s' %
- (magic, filename))
- num_items = _read32(bytestream)
- buf = bytestream.read(num_items)
- labels = numpy.frombuffer(buf, dtype=numpy.uint8)
- if one_hot:
- return dense_to_one_hot(labels)
- return labels
+ """Extract the labels into a 1D uint8 numpy array [index]."""
+ print('Extracting', filename)
+ with gzip.open(filename) as bytestream:
+ magic = _read32(bytestream)
+ if magic != 2049:
+ raise ValueError(
+ 'Invalid magic number %d in MNIST label file: %s' %
+ (magic, filename))
+ num_items = _read32(bytestream)
+ buf = bytestream.read(num_items)
+ labels = numpy.frombuffer(buf, dtype=numpy.uint8)
+ if one_hot:
+ return dense_to_one_hot(labels)
+ return labels
class DataSet(object):
- def __init__(self, images, labels, fake_data=False):
- if fake_data:
- self._num_examples = 10000
- else:
- assert images.shape[0] == labels.shape[0], (
- "images.shape: %s labels.shape: %s" % (images.shape,
- labels.shape))
- self._num_examples = images.shape[0]
-
- # Convert shape from [num examples, rows, columns, depth]
- # to [num examples, rows*columns] (assuming depth == 1)
- assert images.shape[3] == 1
- images = images.reshape(images.shape[0],
- images.shape[1] * images.shape[2])
- # Convert from [0, 255] -> [0.0, 1.0].
- images = images.astype(numpy.float32)
- images = numpy.multiply(images, 1.0 / 255.0)
- self._images = images
- self._labels = labels
- self._epochs_completed = 0
- self._index_in_epoch = 0
-
- @property
- def images(self):
- return self._images
-
- @property
- def labels(self):
- return self._labels
-
- @property
- def num_examples(self):
- return self._num_examples
-
- @property
- def epochs_completed(self):
- return self._epochs_completed
-
- def next_batch(self, batch_size, fake_data=False):
- """Return the next `batch_size` examples from this data set."""
- if fake_data:
- fake_image = [1.0 for _ in xrange(784)]
- fake_label = 0
- return [fake_image for _ in xrange(batch_size)], [
- fake_label for _ in xrange(batch_size)]
- start = self._index_in_epoch
- self._index_in_epoch += batch_size
- if self._index_in_epoch > self._num_examples:
- # Finished epoch
- self._epochs_completed += 1
- # Shuffle the data
- perm = numpy.arange(self._num_examples)
- numpy.random.shuffle(perm)
- self._images = self._images[perm]
- self._labels = self._labels[perm]
- # Start next epoch
- start = 0
- self._index_in_epoch = batch_size
- assert batch_size <= self._num_examples
- end = self._index_in_epoch
- return self._images[start:end], self._labels[start:end]
+ def __init__(self, images, labels, fake_data=False):
+ if fake_data:
+ self._num_examples = 10000
+ else:
+ assert images.shape[0] == labels.shape[0], (
+ "images.shape: %s labels.shape: %s" % (images.shape,
+ labels.shape))
+ self._num_examples = images.shape[0]
+
+ # Convert shape from [num examples, rows, columns, depth]
+ # to [num examples, rows*columns] (assuming depth == 1)
+ assert images.shape[3] == 1
+ images = images.reshape(images.shape[0],
+ images.shape[1] * images.shape[2])
+ # Convert from [0, 255] -> [0.0, 1.0].
+ images = images.astype(numpy.float32)
+ images = numpy.multiply(images, 1.0 / 255.0)
+ self._images = images
+ self._labels = labels
+ self._epochs_completed = 0
+ self._index_in_epoch = 0
+
+ @property
+ def images(self):
+ return self._images
+
+ @property
+ def labels(self):
+ return self._labels
+
+ @property
+ def num_examples(self):
+ return self._num_examples
+
+ @property
+ def epochs_completed(self):
+ return self._epochs_completed
+
+ def next_batch(self, batch_size, fake_data=False):
+ """Return the next `batch_size` examples from this data set."""
+ if fake_data:
+ fake_image = [1.0 for _ in xrange(784)]
+ fake_label = 0
+ return [fake_image for _ in xrange(batch_size)], [
+ fake_label for _ in xrange(batch_size)]
+ start = self._index_in_epoch
+ self._index_in_epoch += batch_size
+ if self._index_in_epoch > self._num_examples:
+ # Finished epoch
+ self._epochs_completed += 1
+ # Shuffle the data
+ perm = numpy.arange(self._num_examples)
+ numpy.random.shuffle(perm)
+ self._images = self._images[perm]
+ self._labels = self._labels[perm]
+ # Start next epoch
+ start = 0
+ self._index_in_epoch = batch_size
+ assert batch_size <= self._num_examples
+ end = self._index_in_epoch
+ return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False):
- class DataSets(object):
- pass
- data_sets = DataSets()
-
- if fake_data:
- data_sets.train = DataSet([], [], fake_data=True)
- data_sets.validation = DataSet([], [], fake_data=True)
- data_sets.test = DataSet([], [], fake_data=True)
- return data_sets
+ class DataSets(object):
+ pass
+ data_sets = DataSets()
+
+ if fake_data:
+ data_sets.train = DataSet([], [], fake_data=True)
+ data_sets.validation = DataSet([], [], fake_data=True)
+ data_sets.test = DataSet([], [], fake_data=True)
+ return data_sets
- TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
- TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
- TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
- TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
- VALIDATION_SIZE = 5000
+ TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
+ TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
+ TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
+ TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
+ VALIDATION_SIZE = 5000
- local_file = maybe_download(TRAIN_IMAGES, train_dir)
- train_images = extract_images(local_file)
+ local_file = maybe_download(TRAIN_IMAGES, train_dir)
+ train_images = extract_images(local_file)
- local_file = maybe_download(TRAIN_LABELS, train_dir)
- train_labels = extract_labels(local_file, one_hot=one_hot)
+ local_file = maybe_download(TRAIN_LABELS, train_dir)
+ train_labels = extract_labels(local_file, one_hot=one_hot)
- local_file = maybe_download(TEST_IMAGES, train_dir)
- test_images = extract_images(local_file)
+ local_file = maybe_download(TEST_IMAGES, train_dir)
+ test_images = extract_images(local_file)
- local_file = maybe_download(TEST_LABELS, train_dir)
- test_labels = extract_labels(local_file, one_hot=one_hot)
+ local_file = maybe_download(TEST_LABELS, train_dir)
+ test_labels = extract_labels(local_file, one_hot=one_hot)
- validation_images = train_images[:VALIDATION_SIZE]
- validation_labels = train_labels[:VALIDATION_SIZE]
- train_images = train_images[VALIDATION_SIZE:]
- train_labels = train_labels[VALIDATION_SIZE:]
+ validation_images = train_images[:VALIDATION_SIZE]
+ validation_labels = train_labels[:VALIDATION_SIZE]
+ train_images = train_images[VALIDATION_SIZE:]
+ train_labels = train_labels[VALIDATION_SIZE:]
- data_sets.train = DataSet(train_images, train_labels)
- data_sets.validation = DataSet(validation_images, validation_labels)
- data_sets.test = DataSet(test_images, test_labels)
+ data_sets.train = DataSet(train_images, train_labels)
+ data_sets.validation = DataSet(validation_images, validation_labels)
+ data_sets.test = DataSet(test_images, test_labels)
- return data_sets
+ return data_sets
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py b/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
index 640ea29dac..a90f1ea685 100644
--- a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
+++ b/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py
@@ -2,6 +2,7 @@
See extensive documentation at ??????? (insert public URL)
"""
+from __future__ import print_function
# Import data
import input_data
@@ -30,4 +31,4 @@ for i in range(1000):
# Test trained model
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
-print accuracy.eval({x: mnist.test.images, y_: mnist.test.labels})
+print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
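The hunk above is the pattern repeated throughout this commit: add `from __future__ import print_function` at the top of the module, then rewrite Python 2 print statements as function calls, which keeps the file importable under both interpreters. A minimal sketch of the idiom (the accuracy value is a placeholder, not taken from the tutorial):

    from __future__ import print_function

    accuracy = 0.92  # placeholder value for illustration only
    # Python 2 statement form:  print accuracy
    # Portable function form, valid on Python 2 (with the import) and Python 3:
    print(accuracy)
    print("accuracy =", accuracy)  # multiple arguments are space-separated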
diff --git a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
index 632ce1f3a9..173944950c 100644
--- a/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
+++ b/tensorflow/g3doc/tutorials/word2vec/word2vec_basic.py
@@ -1,3 +1,4 @@
+from __future__ import print_function
import tensorflow.python.platform
import collections
@@ -18,15 +19,16 @@ def maybe_download(filename, expected_bytes):
filename, _ = urllib.urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
- print 'Found and verified', filename
+ print('Found and verified', filename)
else:
- print statinfo.st_size
+ print(statinfo.st_size)
raise Exception(
'Failed to verify ' + filename + '. Can you get to it with a browser?')
return filename
filename = maybe_download('text8.zip', 31344016)
+
# Read the data into a string.
def read_data(filename):
f = zipfile.ZipFile(filename)
@@ -35,7 +37,7 @@ def read_data(filename):
f.close()
words = read_data(filename)
-print 'Data size', len(words)
+print('Data size', len(words))
# Step 2: Build the dictionary and replace rare words with UNK token.
vocabulary_size = 50000
@@ -61,11 +63,12 @@ def build_dataset(words):
data, count, dictionary, reverse_dictionary = build_dataset(words)
del words # Hint to reduce memory.
-print 'Most common words (+UNK)', count[:5]
-print 'Sample data', data[:10]
+print('Most common words (+UNK)', count[:5])
+print('Sample data', data[:10])
data_index = 0
+
# Step 4: Function to generate a training batch for the skip-gram model.
def generate_batch(batch_size, num_skips, skip_window):
global data_index
@@ -93,8 +96,8 @@ def generate_batch(batch_size, num_skips, skip_window):
batch, labels = generate_batch(batch_size=8, num_skips=2, skip_window=1)
for i in range(8):
- print batch[i], '->', labels[i, 0]
- print reverse_dictionary[batch[i]], '->', reverse_dictionary[labels[i, 0]]
+ print(batch[i], '->', labels[i, 0])
+ print(reverse_dictionary[batch[i]], '->', reverse_dictionary[labels[i, 0]])
# Step 5: Build and train a skip-gram model.
@@ -155,7 +158,7 @@ num_steps = 100001
with tf.Session(graph=graph) as session:
# We must initialize all variables before we use them.
tf.initialize_all_variables().run()
- print "Initialized"
+ print("Initialized")
average_loss = 0
for step in xrange(num_steps):
@@ -172,7 +175,7 @@ with tf.Session(graph=graph) as session:
if step > 0:
average_loss = average_loss / 2000
# The average loss is an estimate of the loss over the last 2000 batches.
- print "Average loss at step ", step, ": ", average_loss
+ print("Average loss at step ", step, ": ", average_loss)
average_loss = 0
# note that this is expensive (~20% slowdown if computed every 500 steps)
@@ -186,7 +189,7 @@ with tf.Session(graph=graph) as session:
for k in xrange(top_k):
close_word = reverse_dictionary[nearest[k]]
log_str = "%s %s," % (log_str, close_word)
- print log_str
+ print(log_str)
final_embeddings = normalized_embeddings.eval()
# Step 7: Visualize the embeddings.
@@ -217,4 +220,4 @@ try:
plot_with_labels(low_dim_embs, labels)
except ImportError:
- print "Please install sklearn and matplotlib to visualize embeddings."
+ print("Please install sklearn and matplotlib to visualize embeddings.")
diff --git a/tensorflow/models/embedding/word2vec.py b/tensorflow/models/embedding/word2vec.py
index 12fb994d98..c417b086d6 100644
--- a/tensorflow/models/embedding/word2vec.py
+++ b/tensorflow/models/embedding/word2vec.py
@@ -13,6 +13,7 @@ The key ops used are:
* GradientDescentOptimizer for optimizing the loss.
* skipgram custom op that does input processing.
"""
+from __future__ import print_function
import os
import sys
@@ -168,9 +169,9 @@ class Word2Vec(object):
questions_skipped += 1
else:
questions.append(np.array(ids))
- print "Eval analogy file: ", self._options.eval_data
- print "Questions: ", len(questions)
- print "Skipped: ", questions_skipped
+ print("Eval analogy file: ", self._options.eval_data)
+ print("Questions: ", len(questions))
+ print("Skipped: ", questions_skipped)
self._analogy_questions = np.array(questions, dtype=np.int32)
def forward(self, examples, labels):
@@ -336,9 +337,9 @@ class Word2Vec(object):
(opts.vocab_words, opts.vocab_counts,
opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
opts.vocab_size = len(opts.vocab_words)
- print "Data file: ", opts.train_data
- print "Vocab size: ", opts.vocab_size - 1, " + UNK"
- print "Words per epoch: ", opts.words_per_epoch
+ print("Data file: ", opts.train_data)
+ print("Vocab size: ", opts.vocab_size - 1, " + UNK")
+ print("Words per epoch: ", opts.words_per_epoch)
self._examples = examples
self._labels = labels
self._id2word = opts.vocab_words
@@ -394,7 +395,7 @@ class Word2Vec(object):
last_words, last_time, rate = words, now, (words - last_words) / (
now - last_time)
print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" %
- (epoch, step, lr, loss, rate)),
+ (epoch, step, lr, loss, rate), end="")
sys.stdout.flush()
if now - last_summary_time > opts.summary_interval:
summary_str = self._session.run(summary_op)
@@ -447,9 +448,9 @@ class Word2Vec(object):
else:
# The correct label is not the precision@1
break
- print
- print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
- correct * 100.0 / total)
+ print()
+ print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+ correct * 100.0 / total))
def analogy(self, w0, w1, w2):
"""Predict word w3 as in w0:w1 vs w2:w3."""
@@ -466,9 +467,9 @@ class Word2Vec(object):
vals, idx = self._session.run(
[self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
for i in xrange(len(words)):
- print "\n%s\n=====================================" % (words[i])
+ print("\n%s\n=====================================" % (words[i]))
for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
- print "%-20s %6.4f" % (self._id2word[neighbor], distance)
+ print("%-20s %6.4f" % (self._id2word[neighbor], distance))
def _start_shell(local_ns=None):
@@ -484,7 +485,7 @@ def _start_shell(local_ns=None):
def main(_):
"""Train a word2vec model."""
if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
- print "--train_data --eval_data and --save_path must be specified."
+ print("--train_data --eval_data and --save_path must be specified.")
sys.exit(1)
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
diff --git a/tensorflow/models/embedding/word2vec_optimized.py b/tensorflow/models/embedding/word2vec_optimized.py
index 38fac1651d..4d5b3fe58d 100644
--- a/tensorflow/models/embedding/word2vec_optimized.py
+++ b/tensorflow/models/embedding/word2vec_optimized.py
@@ -12,6 +12,7 @@ The key ops used are:
* neg_train custom op that efficiently calculates and applies the gradient using
true SGD.
"""
+from __future__ import print_function
import os
import sys
@@ -148,9 +149,9 @@ class Word2Vec(object):
questions_skipped += 1
else:
questions.append(np.array(ids))
- print "Eval analogy file: ", self._options.eval_data
- print "Questions: ", len(questions)
- print "Skipped: ", questions_skipped
+ print("Eval analogy file: ", self._options.eval_data)
+ print("Questions: ", len(questions))
+ print("Skipped: ", questions_skipped)
self._analogy_questions = np.array(questions, dtype=np.int32)
def build_graph(self):
@@ -167,9 +168,9 @@ class Word2Vec(object):
(opts.vocab_words, opts.vocab_counts,
opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
opts.vocab_size = len(opts.vocab_words)
- print "Data file: ", opts.train_data
- print "Vocab size: ", opts.vocab_size - 1, " + UNK"
- print "Words per epoch: ", opts.words_per_epoch
+ print("Data file: ", opts.train_data)
+ print("Vocab size: ", opts.vocab_size - 1, " + UNK")
+ print("Words per epoch: ", opts.words_per_epoch)
self._id2word = opts.vocab_words
for i, w in enumerate(self._id2word):
@@ -308,8 +309,9 @@ class Word2Vec(object):
now = time.time()
last_words, last_time, rate = words, now, (words - last_words) / (
now - last_time)
- print "Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step,
+ print("Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step,
lr, rate),
+ end="")
sys.stdout.flush()
if epoch != initial_epoch:
break
@@ -351,9 +353,9 @@ class Word2Vec(object):
else:
# The correct label is not the precision@1
break
- print
- print "Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
- correct * 100.0 / total)
+ print()
+ print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
+ correct * 100.0 / total))
def analogy(self, w0, w1, w2):
"""Predict word w3 as in w0:w1 vs w2:w3."""
@@ -370,9 +372,9 @@ class Word2Vec(object):
vals, idx = self._session.run(
[self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
for i in xrange(len(words)):
- print "\n%s\n=====================================" % (words[i])
+ print("\n%s\n=====================================" % (words[i]))
for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
- print "%-20s %6.4f" % (self._id2word[neighbor], distance)
+ print("%-20s %6.4f" % (self._id2word[neighbor], distance))
def _start_shell(local_ns=None):
@@ -388,7 +390,7 @@ def _start_shell(local_ns=None):
def main(_):
"""Train a word2vec model."""
if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
- print "--train_data --eval_data and --save_path must be specified."
+ print("--train_data --eval_data and --save_path must be specified.")
sys.exit(1)
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
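The progress-logging hunks above replace two Python-2-only print idioms: a trailing comma that suppressed the newline becomes the `end=""` keyword argument, and a bare `print` used to emit a newline becomes `print()`. A small sketch of the same pattern, assuming a hypothetical progress loop:

    from __future__ import print_function
    import sys
    import time

    for step in range(3):
        # Overwrite the same console line; end="" replaces the old trailing comma.
        print("Step %d: working...\r" % step, end="")
        sys.stdout.flush()
        time.sleep(0.1)
    print()  # emit the final newline, replacing the bare Python 2 `print`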
diff --git a/tensorflow/models/image/alexnet/alexnet_benchmark.py b/tensorflow/models/image/alexnet/alexnet_benchmark.py
index 130948c4bf..e4be47ff38 100644
--- a/tensorflow/models/image/alexnet/alexnet_benchmark.py
+++ b/tensorflow/models/image/alexnet/alexnet_benchmark.py
@@ -14,6 +14,7 @@ Forward-backward pass:
Run on Tesla K40c: 480 +/- 48 ms / batch
Run on Titan X: 244 +/- 30 ms / batch
"""
+from __future__ import print_function
from datetime import datetime
import math
import time
@@ -31,7 +32,7 @@ tf.app.flags.DEFINE_integer('num_batches', 100,
def print_activations(t):
- print t.op.name, ' ', t.get_shape().as_list()
+ print(t.op.name, ' ', t.get_shape().as_list())
def inference(images):
diff --git a/tensorflow/models/image/cifar10/cifar10.py b/tensorflow/models/image/cifar10/cifar10.py
index 7870080820..6d79029dc8 100644
--- a/tensorflow/models/image/cifar10/cifar10.py
+++ b/tensorflow/models/image/cifar10/cifar10.py
@@ -15,6 +15,7 @@ Summary of available functions:
# Create a graph to run one step of training with respect to the loss.
train_op = train(loss, global_step)
"""
+from __future__ import print_function
# pylint: disable=missing-docstring
import gzip
import os
@@ -474,7 +475,7 @@ def maybe_download_and_extract():
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()
filepath, _ = urllib.urlretrieve(DATA_URL, filepath, reporthook=_progress)
- print
+ print()
statinfo = os.stat(filepath)
- print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.'
+ print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
diff --git a/tensorflow/models/image/cifar10/cifar10_eval.py b/tensorflow/models/image/cifar10/cifar10_eval.py
index 73c224191d..c8e6ec067f 100644
--- a/tensorflow/models/image/cifar10/cifar10_eval.py
+++ b/tensorflow/models/image/cifar10/cifar10_eval.py
@@ -15,6 +15,7 @@ data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
+from __future__ import print_function
from datetime import datetime
import math
import time
@@ -61,7 +62,7 @@ def eval_once(saver, summary_writer, top_k_op, summary_op):
# extract global_step from it.
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
else:
- print 'No checkpoint file found'
+ print('No checkpoint file found')
return
# Start the queue runners.
@@ -83,13 +84,13 @@ def eval_once(saver, summary_writer, top_k_op, summary_op):
# Compute precision @ 1.
precision = float(true_count) / float(total_sample_count)
- print '%s: precision @ 1 = %.3f' % (datetime.now(), precision)
+ print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
summary = tf.Summary()
summary.ParseFromString(sess.run(summary_op))
summary.value.add(tag='Precision @ 1', simple_value=precision)
summary_writer.add_summary(summary, global_step)
- except Exception, e: # pylint: disable=broad-except
+ except Exception as e: # pylint: disable=broad-except
coord.request_stop(e)
coord.request_stop()
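Besides the print conversions, the hunk above also fixes the exception syntax: `except Exception, e` is a syntax error on Python 3, while `except Exception as e` is accepted by Python 2.6+ and 3.x alike. A minimal sketch of the portable form:

    from __future__ import print_function

    try:
        raise ValueError("example failure")  # stand-in for a failed eval step
    except Exception as e:  # pylint: disable=broad-except
        print("caught:", e)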
diff --git a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
index 54bc41f444..e0396ea782 100644
--- a/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
@@ -20,6 +20,7 @@ data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
+from __future__ import print_function
from datetime import datetime
import os.path
import re
diff --git a/tensorflow/models/image/cifar10/cifar10_train.py b/tensorflow/models/image/cifar10/cifar10_train.py
index bcb6eeae58..1a70b39d57 100644
--- a/tensorflow/models/image/cifar10/cifar10_train.py
+++ b/tensorflow/models/image/cifar10/cifar10_train.py
@@ -17,6 +17,7 @@ data set, compile the program and train the model.
http://tensorflow.org/tutorials/deep_cnn/
"""
+from __future__ import print_function
from datetime import datetime
import os.path
import time
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index 8fb0e4dfb4..f9453654ef 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -4,6 +4,7 @@ This should achieve a test error of 0.8%. Please keep this model as simple and
linear as possible, it is meant as a tutorial for simple convolutional models.
Run with --self_test on the command line to execute a short self-test.
"""
+from __future__ import print_function
import gzip
import os
import sys
@@ -31,240 +32,239 @@ FLAGS = tf.app.flags.FLAGS
def maybe_download(filename):
- """Download the data from Yann's website, unless it's already here."""
- if not os.path.exists(WORK_DIRECTORY):
- os.mkdir(WORK_DIRECTORY)
- filepath = os.path.join(WORK_DIRECTORY, filename)
- if not os.path.exists(filepath):
- filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
- statinfo = os.stat(filepath)
- print 'Succesfully downloaded', filename, statinfo.st_size, 'bytes.'
- return filepath
+ """Download the data from Yann's website, unless it's already here."""
+ if not os.path.exists(WORK_DIRECTORY):
+ os.mkdir(WORK_DIRECTORY)
+ filepath = os.path.join(WORK_DIRECTORY, filename)
+ if not os.path.exists(filepath):
+ filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
+ statinfo = os.stat(filepath)
+ print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+ return filepath
def extract_data(filename, num_images):
- """Extract the images into a 4D tensor [image index, y, x, channels].
+ """Extract the images into a 4D tensor [image index, y, x, channels].
- Values are rescaled from [0, 255] down to [-0.5, 0.5].
- """
- print 'Extracting', filename
- with gzip.open(filename) as bytestream:
- bytestream.read(16)
- buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images)
- data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
- data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
- data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1)
- return data
+ Values are rescaled from [0, 255] down to [-0.5, 0.5].
+ """
+ print('Extracting', filename)
+ with gzip.open(filename) as bytestream:
+ bytestream.read(16)
+ buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images)
+ data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
+ data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
+ data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1)
+ return data
def extract_labels(filename, num_images):
- """Extract the labels into a 1-hot matrix [image index, label index]."""
- print 'Extracting', filename
- with gzip.open(filename) as bytestream:
- bytestream.read(8)
- buf = bytestream.read(1 * num_images)
- labels = numpy.frombuffer(buf, dtype=numpy.uint8)
- # Convert to dense 1-hot representation.
- return (numpy.arange(NUM_LABELS) == labels[:, None]).astype(numpy.float32)
+ """Extract the labels into a 1-hot matrix [image index, label index]."""
+ print('Extracting', filename)
+ with gzip.open(filename) as bytestream:
+ bytestream.read(8)
+ buf = bytestream.read(1 * num_images)
+ labels = numpy.frombuffer(buf, dtype=numpy.uint8)
+ # Convert to dense 1-hot representation.
+ return (numpy.arange(NUM_LABELS) == labels[:, None]).astype(numpy.float32)
def fake_data(num_images):
- """Generate a fake dataset that matches the dimensions of MNIST."""
- data = numpy.ndarray(
- shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
- dtype=numpy.float32)
- labels = numpy.zeros(shape=(num_images, NUM_LABELS), dtype=numpy.float32)
- for image in xrange(num_images):
- label = image % 2
- data[image, :, :, 0] = label - 0.5
- labels[image, label] = 1.0
- return data, labels
+ """Generate a fake dataset that matches the dimensions of MNIST."""
+ data = numpy.ndarray(
+ shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
+ dtype=numpy.float32)
+ labels = numpy.zeros(shape=(num_images, NUM_LABELS), dtype=numpy.float32)
+ for image in xrange(num_images):
+ label = image % 2
+ data[image, :, :, 0] = label - 0.5
+ labels[image, label] = 1.0
+ return data, labels
def error_rate(predictions, labels):
- """Return the error rate based on dense predictions and 1-hot labels."""
- return 100.0 - (
- 100.0 *
- numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(labels, 1)) /
- predictions.shape[0])
+ """Return the error rate based on dense predictions and 1-hot labels."""
+ return 100.0 - (
+ 100.0 *
+ numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(labels, 1)) /
+ predictions.shape[0])
def main(argv=None): # pylint: disable=unused-argument
- if FLAGS.self_test:
- print 'Running self-test.'
- train_data, train_labels = fake_data(256)
- validation_data, validation_labels = fake_data(16)
- test_data, test_labels = fake_data(256)
- num_epochs = 1
- else:
- # Get the data.
- train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
- train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
- test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
- test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
+ if FLAGS.self_test:
+ print('Running self-test.')
+ train_data, train_labels = fake_data(256)
+ validation_data, validation_labels = fake_data(16)
+ test_data, test_labels = fake_data(256)
+ num_epochs = 1
+ else:
+ # Get the data.
+ train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
+ train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
+ test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
+ test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
- # Extract it into numpy arrays.
- train_data = extract_data(train_data_filename, 60000)
- train_labels = extract_labels(train_labels_filename, 60000)
- test_data = extract_data(test_data_filename, 10000)
- test_labels = extract_labels(test_labels_filename, 10000)
+ # Extract it into numpy arrays.
+ train_data = extract_data(train_data_filename, 60000)
+ train_labels = extract_labels(train_labels_filename, 60000)
+ test_data = extract_data(test_data_filename, 10000)
+ test_labels = extract_labels(test_labels_filename, 10000)
- # Generate a validation set.
- validation_data = train_data[:VALIDATION_SIZE, :, :, :]
- validation_labels = train_labels[:VALIDATION_SIZE]
- train_data = train_data[VALIDATION_SIZE:, :, :, :]
- train_labels = train_labels[VALIDATION_SIZE:]
- num_epochs = NUM_EPOCHS
- train_size = train_labels.shape[0]
+ # Generate a validation set.
+ validation_data = train_data[:VALIDATION_SIZE, :, :, :]
+ validation_labels = train_labels[:VALIDATION_SIZE]
+ train_data = train_data[VALIDATION_SIZE:, :, :, :]
+ train_labels = train_labels[VALIDATION_SIZE:]
+ num_epochs = NUM_EPOCHS
+ train_size = train_labels.shape[0]
- # This is where training samples and labels are fed to the graph.
- # These placeholder nodes will be fed a batch of training data at each
- # training step using the {feed_dict} argument to the Run() call below.
- train_data_node = tf.placeholder(
- tf.float32,
- shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
- train_labels_node = tf.placeholder(tf.float32,
- shape=(BATCH_SIZE, NUM_LABELS))
- # For the validation and test data, we'll just hold the entire dataset in
- # one constant node.
- validation_data_node = tf.constant(validation_data)
- test_data_node = tf.constant(test_data)
+ # This is where training samples and labels are fed to the graph.
+ # These placeholder nodes will be fed a batch of training data at each
+ # training step using the {feed_dict} argument to the Run() call below.
+ train_data_node = tf.placeholder(
+ tf.float32,
+ shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
+ train_labels_node = tf.placeholder(tf.float32,
+ shape=(BATCH_SIZE, NUM_LABELS))
+ # For the validation and test data, we'll just hold the entire dataset in
+ # one constant node.
+ validation_data_node = tf.constant(validation_data)
+ test_data_node = tf.constant(test_data)
- # The variables below hold all the trainable weights. They are passed an
- # initial value which will be assigned when when we call:
- # {tf.initialize_all_variables().run()}
- conv1_weights = tf.Variable(
- tf.truncated_normal([5, 5, NUM_CHANNELS, 32], # 5x5 filter, depth 32.
- stddev=0.1,
- seed=SEED))
- conv1_biases = tf.Variable(tf.zeros([32]))
- conv2_weights = tf.Variable(
- tf.truncated_normal([5, 5, 32, 64],
- stddev=0.1,
- seed=SEED))
- conv2_biases = tf.Variable(tf.constant(0.1, shape=[64]))
- fc1_weights = tf.Variable( # fully connected, depth 512.
- tf.truncated_normal([IMAGE_SIZE / 4 * IMAGE_SIZE / 4 * 64, 512],
- stddev=0.1,
- seed=SEED))
- fc1_biases = tf.Variable(tf.constant(0.1, shape=[512]))
- fc2_weights = tf.Variable(
- tf.truncated_normal([512, NUM_LABELS],
- stddev=0.1,
- seed=SEED))
- fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS]))
+ # The variables below hold all the trainable weights. They are passed an
+ # initial value which will be assigned when we call:
+ # {tf.initialize_all_variables().run()}
+ conv1_weights = tf.Variable(
+ tf.truncated_normal([5, 5, NUM_CHANNELS, 32], # 5x5 filter, depth 32.
+ stddev=0.1,
+ seed=SEED))
+ conv1_biases = tf.Variable(tf.zeros([32]))
+ conv2_weights = tf.Variable(
+ tf.truncated_normal([5, 5, 32, 64],
+ stddev=0.1,
+ seed=SEED))
+ conv2_biases = tf.Variable(tf.constant(0.1, shape=[64]))
+ fc1_weights = tf.Variable( # fully connected, depth 512.
+ tf.truncated_normal([IMAGE_SIZE / 4 * IMAGE_SIZE / 4 * 64, 512],
+ stddev=0.1,
+ seed=SEED))
+ fc1_biases = tf.Variable(tf.constant(0.1, shape=[512]))
+ fc2_weights = tf.Variable(
+ tf.truncated_normal([512, NUM_LABELS],
+ stddev=0.1,
+ seed=SEED))
+ fc2_biases = tf.Variable(tf.constant(0.1, shape=[NUM_LABELS]))
- # We will replicate the model structure for the training subgraph, as well
- # as the evaluation subgraphs, while sharing the trainable parameters.
- def model(data, train=False):
- """The Model definition."""
- # 2D convolution, with 'SAME' padding (i.e. the output feature map has
- # the same size as the input). Note that {strides} is a 4D array whose
- # shape matches the data layout: [image index, y, x, depth].
- conv = tf.nn.conv2d(data,
- conv1_weights,
- strides=[1, 1, 1, 1],
- padding='SAME')
- # Bias and rectified linear non-linearity.
- relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases))
- # Max pooling. The kernel size spec {ksize} also follows the layout of
- # the data. Here we have a pooling window of 2, and a stride of 2.
- pool = tf.nn.max_pool(relu,
- ksize=[1, 2, 2, 1],
- strides=[1, 2, 2, 1],
- padding='SAME')
- conv = tf.nn.conv2d(pool,
- conv2_weights,
- strides=[1, 1, 1, 1],
- padding='SAME')
- relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases))
- pool = tf.nn.max_pool(relu,
- ksize=[1, 2, 2, 1],
- strides=[1, 2, 2, 1],
- padding='SAME')
- # Reshape the feature map cuboid into a 2D matrix to feed it to the
- # fully connected layers.
- pool_shape = pool.get_shape().as_list()
- reshape = tf.reshape(
- pool,
- [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]])
- # Fully connected layer. Note that the '+' operation automatically
- # broadcasts the biases.
- hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases)
- # Add a 50% dropout during training only. Dropout also scales
- # activations such that no rescaling is needed at evaluation time.
- if train:
- hidden = tf.nn.dropout(hidden, 0.5, seed=SEED)
- return tf.matmul(hidden, fc2_weights) + fc2_biases
+ # We will replicate the model structure for the training subgraph, as well
+ # as the evaluation subgraphs, while sharing the trainable parameters.
+ def model(data, train=False):
+ """The Model definition."""
+ # 2D convolution, with 'SAME' padding (i.e. the output feature map has
+ # the same size as the input). Note that {strides} is a 4D array whose
+ # shape matches the data layout: [image index, y, x, depth].
+ conv = tf.nn.conv2d(data,
+ conv1_weights,
+ strides=[1, 1, 1, 1],
+ padding='SAME')
+ # Bias and rectified linear non-linearity.
+ relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases))
+ # Max pooling. The kernel size spec {ksize} also follows the layout of
+ # the data. Here we have a pooling window of 2, and a stride of 2.
+ pool = tf.nn.max_pool(relu,
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding='SAME')
+ conv = tf.nn.conv2d(pool,
+ conv2_weights,
+ strides=[1, 1, 1, 1],
+ padding='SAME')
+ relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases))
+ pool = tf.nn.max_pool(relu,
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding='SAME')
+ # Reshape the feature map cuboid into a 2D matrix to feed it to the
+ # fully connected layers.
+ pool_shape = pool.get_shape().as_list()
+ reshape = tf.reshape(
+ pool,
+ [pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]])
+ # Fully connected layer. Note that the '+' operation automatically
+ # broadcasts the biases.
+ hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases)
+ # Add a 50% dropout during training only. Dropout also scales
+ # activations such that no rescaling is needed at evaluation time.
+ if train:
+ hidden = tf.nn.dropout(hidden, 0.5, seed=SEED)
+ return tf.matmul(hidden, fc2_weights) + fc2_biases
- # Training computation: logits + cross-entropy loss.
- logits = model(train_data_node, True)
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
- logits, train_labels_node))
+ # Training computation: logits + cross-entropy loss.
+ logits = model(train_data_node, True)
+ loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
+ logits, train_labels_node))
- # L2 regularization for the fully connected parameters.
- regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
- tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases))
- # Add the regularization term to the loss.
- loss += 5e-4 * regularizers
+ # L2 regularization for the fully connected parameters.
+ regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
+ tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases))
+ # Add the regularization term to the loss.
+ loss += 5e-4 * regularizers
- # Optimizer: set up a variable that's incremented once per batch and
- # controls the learning rate decay.
- batch = tf.Variable(0)
- # Decay once per epoch, using an exponential schedule starting at 0.01.
- learning_rate = tf.train.exponential_decay(
- 0.01, # Base learning rate.
- batch * BATCH_SIZE, # Current index into the dataset.
- train_size, # Decay step.
- 0.95, # Decay rate.
- staircase=True)
- # Use simple momentum for the optimization.
- optimizer = tf.train.MomentumOptimizer(learning_rate,
- 0.9).minimize(loss,
- global_step=batch)
+ # Optimizer: set up a variable that's incremented once per batch and
+ # controls the learning rate decay.
+ batch = tf.Variable(0)
+ # Decay once per epoch, using an exponential schedule starting at 0.01.
+ learning_rate = tf.train.exponential_decay(
+ 0.01, # Base learning rate.
+ batch * BATCH_SIZE, # Current index into the dataset.
+ train_size, # Decay step.
+ 0.95, # Decay rate.
+ staircase=True)
+ # Use simple momentum for the optimization.
+ optimizer = tf.train.MomentumOptimizer(learning_rate,
+ 0.9).minimize(loss,
+ global_step=batch)
- # Predictions for the minibatch, validation set and test set.
- train_prediction = tf.nn.softmax(logits)
- # We'll compute them only once in a while by calling their {eval()} method.
- validation_prediction = tf.nn.softmax(model(validation_data_node))
- test_prediction = tf.nn.softmax(model(test_data_node))
+ # Predictions for the minibatch, validation set and test set.
+ train_prediction = tf.nn.softmax(logits)
+ # We'll compute them only once in a while by calling their {eval()} method.
+ validation_prediction = tf.nn.softmax(model(validation_data_node))
+ test_prediction = tf.nn.softmax(model(test_data_node))
- # Create a local session to run this computation.
- with tf.Session() as s:
- # Run all the initializers to prepare the trainable parameters.
- tf.initialize_all_variables().run()
- print 'Initialized!'
- # Loop through training steps.
- for step in xrange(int(num_epochs * train_size / BATCH_SIZE)):
- # Compute the offset of the current minibatch in the data.
- # Note that we could use better randomization across epochs.
- offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
- batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :]
- batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
- # This dictionary maps the batch data (as a numpy array) to the
- # node in the graph is should be fed to.
- feed_dict = {train_data_node: batch_data,
- train_labels_node: batch_labels}
- # Run the graph and fetch some of the nodes.
- _, l, lr, predictions = s.run(
- [optimizer, loss, learning_rate, train_prediction],
- feed_dict=feed_dict)
- if step % 100 == 0:
- print 'Epoch %.2f' % (float(step) * BATCH_SIZE / train_size)
- print 'Minibatch loss: %.3f, learning rate: %.6f' % (l, lr)
- print 'Minibatch error: %.1f%%' % error_rate(predictions,
- batch_labels)
- print 'Validation error: %.1f%%' % error_rate(
- validation_prediction.eval(), validation_labels)
- sys.stdout.flush()
- # Finally print the result!
- test_error = error_rate(test_prediction.eval(), test_labels)
- print 'Test error: %.1f%%' % test_error
- if FLAGS.self_test:
- print 'test_error', test_error
- assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % (
- test_error,)
+ # Create a local session to run this computation.
+ with tf.Session() as s:
+ # Run all the initializers to prepare the trainable parameters.
+ tf.initialize_all_variables().run()
+ print('Initialized!')
+ # Loop through training steps.
+ for step in xrange(int(num_epochs * train_size / BATCH_SIZE)):
+ # Compute the offset of the current minibatch in the data.
+ # Note that we could use better randomization across epochs.
+ offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
+ batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :]
+ batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
+ # This dictionary maps the batch data (as a numpy array) to the
+ # node in the graph it should be fed to.
+ feed_dict = {train_data_node: batch_data,
+ train_labels_node: batch_labels}
+ # Run the graph and fetch some of the nodes.
+ _, l, lr, predictions = s.run(
+ [optimizer, loss, learning_rate, train_prediction],
+ feed_dict=feed_dict)
+ if step % 100 == 0:
+ print('Epoch %.2f' % (float(step) * BATCH_SIZE / train_size))
+ print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
+ print('Minibatch error: %.1f%%' % error_rate(predictions, batch_labels))
+ print('Validation error: %.1f%%' %
+ error_rate(validation_prediction.eval(), validation_labels))
+ sys.stdout.flush()
+ # Finally print the result!
+ test_error = error_rate(test_prediction.eval(), test_labels)
+ print('Test error: %.1f%%' % test_error)
+ if FLAGS.self_test:
+ print('test_error', test_error)
+ assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % (
+ test_error,)
if __name__ == '__main__':
- tf.app.run()
+ tf.app.run()
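One Python 2 idiom this diff leaves in place is the division in the fully connected weight shape, `IMAGE_SIZE / 4 * IMAGE_SIZE / 4 * 64`; under Python 3's true division that expression evaluates to a float, so a later pass would presumably switch it to floor division. A hedged sketch of the arithmetic, assuming the tutorial's 28x28 MNIST images:

    IMAGE_SIZE = 28  # MNIST image side length, as defined in convolutional.py

    # Python 2: 28 / 4 * 28 / 4 * 64 == 3136 (an int, usable as a shape entry)
    # Python 3: the same expression yields 3136.0; // keeps it an integer on both.
    flat_size = IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64
    assert flat_size == 3136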
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index e28d3bf78c..146d05a9ee 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -41,6 +41,7 @@ To run:
--data_path=/tmp/simple-examples/data/ --alsologtostderr
"""
+from __future__ import print_function
import time
diff --git a/tensorflow/models/rnn/seq2seq_test.py b/tensorflow/models/rnn/seq2seq_test.py
index c5125acc21..d2949ecae2 100644
--- a/tensorflow/models/rnn/seq2seq_test.py
+++ b/tensorflow/models/rnn/seq2seq_test.py
@@ -1,4 +1,5 @@
"""Tests for functional style sequence-to-sequence models."""
+from __future__ import print_function
import math
import random
@@ -377,7 +378,7 @@ class Seq2SeqTest(tf.test.TestCase):
res = sess.run([updates[bucket], losses[bucket]], feed)
log_perp += float(res[1])
perp = math.exp(log_perp / 100)
- print "step %d avg. perp %f" % ((ep + 1)*50, perp)
+ print("step %d avg. perp %f" % ((ep + 1) * 50, perp))
self.assertLess(perp, 2.5)
if __name__ == "__main__":
diff --git a/tensorflow/models/rnn/translate/data_utils.py b/tensorflow/models/rnn/translate/data_utils.py
index 28bc54354c..628227ba8f 100644
--- a/tensorflow/models/rnn/translate/data_utils.py
+++ b/tensorflow/models/rnn/translate/data_utils.py
@@ -1,4 +1,5 @@
"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
+from __future__ import print_function
import gzip
import os
@@ -32,20 +33,20 @@ _WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"
def maybe_download(directory, filename, url):
"""Download filename from url unless it's already in directory."""
if not os.path.exists(directory):
- print "Creating directory %s" % directory
+ print("Creating directory %s" % directory)
os.mkdir(directory)
filepath = os.path.join(directory, filename)
if not os.path.exists(filepath):
- print "Downloading %s to %s" % (url, filepath)
+ print("Downloading %s to %s" % (url, filepath))
filepath, _ = urllib.urlretrieve(url, filepath)
statinfo = os.stat(filepath)
- print "Succesfully downloaded", filename, statinfo.st_size, "bytes"
+ print("Succesfully downloaded", filename, statinfo.st_size, "bytes")
return filepath
def gunzip_file(gz_path, new_path):
"""Unzips from gz_path into new_path."""
- print "Unpacking %s to %s" % (gz_path, new_path)
+ print("Unpacking %s to %s" % (gz_path, new_path))
with gzip.open(gz_path, "rb") as gz_file:
with open(new_path, "w") as new_file:
for line in gz_file:
@@ -58,7 +59,7 @@ def get_wmt_enfr_train_set(directory):
if not (gfile.Exists(train_path +".fr") and gfile.Exists(train_path +".en")):
corpus_file = maybe_download(directory, "training-giga-fren.tar",
_WMT_ENFR_TRAIN_URL)
- print "Extracting tar file %s" % corpus_file
+ print("Extracting tar file %s" % corpus_file)
with tarfile.open(corpus_file, "r") as corpus_tar:
corpus_tar.extractall(directory)
gunzip_file(train_path + ".fr.gz", train_path + ".fr")
@@ -72,7 +73,7 @@ def get_wmt_enfr_dev_set(directory):
dev_path = os.path.join(directory, dev_name)
if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
- print "Extracting tgz file %s" % dev_file
+ print("Extracting tgz file %s" % dev_file)
with tarfile.open(dev_file, "r:gz") as dev_tar:
fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
@@ -110,13 +111,14 @@ def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
normalize_digits: Boolean; if true, all digits are replaced by 0s.
"""
if not gfile.Exists(vocabulary_path):
- print "Creating vocabulary %s from data %s" % (vocabulary_path, data_path)
+ print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
vocab = {}
with gfile.GFile(data_path, mode="r") as f:
counter = 0
for line in f:
counter += 1
- if counter % 100000 == 0: print " processing line %d" % counter
+ if counter % 100000 == 0:
+ print(" processing line %d" % counter)
tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
for w in tokens:
word = re.sub(_DIGIT_RE, "0", w) if normalize_digits else w
@@ -207,14 +209,15 @@ def data_to_token_ids(data_path, target_path, vocabulary_path,
normalize_digits: Boolean; if true, all digits are replaced by 0s.
"""
if not gfile.Exists(target_path):
- print "Tokenizing data in %s" % data_path
+ print("Tokenizing data in %s" % data_path)
vocab, _ = initialize_vocabulary(vocabulary_path)
with gfile.GFile(data_path, mode="r") as data_file:
with gfile.GFile(target_path, mode="w") as tokens_file:
counter = 0
for line in data_file:
counter += 1
- if counter % 100000 == 0: print " tokenizing line %d" % counter
+ if counter % 100000 == 0:
+ print(" tokenizing line %d" % counter)
token_ids = sentence_to_token_ids(line, vocab, tokenizer,
normalize_digits)
tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
diff --git a/tensorflow/models/rnn/translate/translate.py b/tensorflow/models/rnn/translate/translate.py
index abf4c7c57b..ec408eb127 100644
--- a/tensorflow/models/rnn/translate/translate.py
+++ b/tensorflow/models/rnn/translate/translate.py
@@ -12,6 +12,7 @@ See the following papers for more information on neural translation models.
* http://arxiv.org/abs/1409.0473
* http://arxiv.org/pdf/1412.2007v2.pdf
"""
+from __future__ import print_function
import math
import os
@@ -83,7 +84,7 @@ def read_data(source_path, target_path, max_size=None):
while source and target and (not max_size or counter < max_size):
counter += 1
if counter % 100000 == 0:
- print " reading data line %d" % counter
+ print(" reading data line %d" % counter)
sys.stdout.flush()
source_ids = [int(x) for x in source.split()]
target_ids = [int(x) for x in target.split()]
@@ -105,10 +106,10 @@ def create_model(session, forward_only):
forward_only=forward_only)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
- print "Reading model parameters from %s" % ckpt.model_checkpoint_path
+ print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
model.saver.restore(session, ckpt.model_checkpoint_path)
else:
- print "Created model with fresh parameters."
+ print("Created model with fresh parameters.")
session.run(tf.variables.initialize_all_variables())
return model
@@ -116,13 +117,13 @@ def create_model(session, forward_only):
def train():
"""Train a en->fr translation model using WMT data."""
# Prepare WMT data.
- print "Preparing WMT data in %s" % FLAGS.data_dir
+ print("Preparing WMT data in %s" % FLAGS.data_dir)
en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)
with tf.Session() as sess:
# Create model.
- print "Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size)
+ print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
model = create_model(sess, False)
# Read data into buckets and compute their sizes.
@@ -182,7 +183,7 @@ def train():
_, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
target_weights, bucket_id, True)
eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')
- print " eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)
+ print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
sys.stdout.flush()
@@ -222,8 +223,8 @@ def decode():
if data_utils.EOS_ID in outputs:
outputs = outputs[:outputs.index(data_utils.EOS_ID)]
# Print out French sentence corresponding to outputs.
- print " ".join([rev_fr_vocab[output] for output in outputs])
- print "> ",
+ print(" ".join([rev_fr_vocab[output] for output in outputs]))
+ print("> ", end="")
sys.stdout.flush()
sentence = sys.stdin.readline()
@@ -231,7 +232,7 @@ def decode():
def self_test():
"""Test the translation model."""
with tf.Session() as sess:
- print "Self-test for neural translation model."
+ print("Self-test for neural translation model.")
# Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
5.0, 32, 0.3, 0.99, num_samples=8)
diff --git a/tensorflow/python/client/notebook.py b/tensorflow/python/client/notebook.py
index 1871fbc632..585c4e0f8f 100644
--- a/tensorflow/python/client/notebook.py
+++ b/tensorflow/python/client/notebook.py
@@ -12,6 +12,7 @@ Press "a" in command mode to insert cell above or "b" to insert cell below.
Your root notebooks directory is FLAGS.notebook_dir
"""
+from __future__ import print_function
import os
@@ -70,7 +71,7 @@ def main(unused_argv):
proto = "https" if notebookapp.certfile else "http"
url = "%s://%s:%d%s" % (proto, socket.gethostname(), notebookapp.port,
notebookapp.base_project_url)
- print "\nNotebook server will be publicly available at: %s\n" % url
+ print("\nNotebook server will be publicly available at: %s\n" % url)
notebookapp.start()
return
diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py
index 7e9770683c..d468d0fb68 100644
--- a/tensorflow/python/framework/docs.py
+++ b/tensorflow/python/framework/docs.py
@@ -3,6 +3,7 @@
Both updates the files in the file-system and executes g4 commands to
make sure any changes are ready to be submitted.
"""
+from __future__ import print_function
import inspect
import os
@@ -56,10 +57,10 @@ class Index(Document):
Args:
f: The output file.
"""
- print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->"
- print >>f, ""
- print >>f, "# TensorFlow Python reference documentation"
- print >>f, ""
+ print("<!-- This file is machine generated: DO NOT EDIT! -->", file=f)
+ print("", file=f)
+ print("# TensorFlow Python reference documentation", file=f)
+ print("", file=f)
fullname_f = lambda name: self._members[name][0]
anchor_f = lambda name: _get_anchor(self._module_to_name, fullname_f(name))
@@ -72,16 +73,16 @@ class Index(Document):
links = ["[`%s`](%s#%s)" % (name, full_filename, anchor_f(name))
for name in member_names]
if links:
- print >>f, "* **[%s](%s)**:" % (library.title, full_filename)
+ print("* **[%s](%s)**:" % (library.title, full_filename), file=f)
for link in links:
- print >>f, " * %s" % link
- print >>f, ""
+ print(" * %s" % link, file=f)
+ print("", file=f)
# actually include the files right here
- print >>f, '<div class="sections-order" style="display: none;">\n<!--'
+ print('<div class="sections-order" style="display: none;">\n<!--', file=f)
for filename, _ in self._filename_to_library_map:
- print >>f, "<!-- %s -->" % filename
- print >>f, "-->\n</div>"
+ print("<!-- %s -->" % filename, file=f)
+ print("-->\n</div>", file=f)
def collect_members(module_to_name):
"""Collect all symbols from a list of modules.
@@ -241,7 +242,7 @@ class Library(Document):
# functions that have the @contextlib.contextmanager decorator.
# We should do something better.
if argspec.varargs == "args" and argspec.keywords == "kwds":
- original_func = func.func_closure[0].cell_contents
+ original_func = func.__closure__[0].cell_contents
return self._generate_signature_for_function(original_func)
if argspec.defaults:
@@ -309,10 +310,10 @@ class Library(Document):
section_header = _at_start_of_section()
if section_header:
if i == 0 or lines[i-1]:
- print >>f, ""
+ print("", file=f)
# Use at least H4 to keep these out of the TOC.
- print >>f, "##### " + section_header + ":"
- print >>f, ""
+ print("##### " + section_header + ":", file=f)
+ print("", file=f)
i += 1
outputting_list = False
while i < len(lines):
@@ -325,18 +326,18 @@ class Library(Document):
if not outputting_list:
# We need to start a list. In Markdown, a blank line needs to
# precede a list.
- print >>f, ""
+ print("", file=f)
outputting_list = True
suffix = l[len(match.group()):].lstrip()
- print >>f, "* <b>`" + match.group(1) + "`</b>: " + suffix
+ print("* <b>`" + match.group(1) + "`</b>: " + suffix, file=f)
else:
# For lines that don't start with _arg_re, continue the list if it
# has enough indentation.
outputting_list &= l.startswith(" ")
- print >>f, l
+ print(l, file=f)
i += 1
else:
- print >>f, l
+ print(l, file=f)
i += 1
def _print_function(self, f, prefix, fullname, func):
@@ -345,35 +346,36 @@ class Library(Document):
if not isinstance(func, property):
heading += self._generate_signature_for_function(func)
heading += "` {#%s}" % _get_anchor(self._module_to_name, fullname)
- print >>f, heading
- print >>f, ""
+ print(heading, file=f)
+ print("", file=f)
self._print_formatted_docstring(inspect.getdoc(func), f)
- print >>f, ""
+ print("", file=f)
def _write_member_markdown_to_file(self, f, name, member):
"""Print `member` to `f`."""
if inspect.isfunction(member):
- print >>f, "- - -"
- print >>f, ""
+ print("- - -", file=f)
+ print("", file=f)
self._print_function(f, "###", name, member)
- print >>f, ""
+ print("", file=f)
elif inspect.ismethod(member):
- print >>f, "- - -"
- print >>f, ""
+ print("- - -", file=f)
+ print("", file=f)
self._print_function(f, "####", name, member)
- print >>f, ""
+ print("", file=f)
elif isinstance(member, property):
- print >>f, "- - -"
- print >>f, ""
+ print("- - -", file=f)
+ print("", file=f)
self._print_function(f, "####", name, member)
elif inspect.isclass(member):
- print >>f, "- - -"
- print >>f, ""
- print >>f, "### `class %s` {#%s}" % (
- name, _get_anchor(self._module_to_name, name))
- print >>f, ""
+ print("- - -", file=f)
+ print("", file=f)
+ print("### `class %s` {#%s}" % (name,
+ _get_anchor(self._module_to_name, name)),
+ file=f)
+ print("", file=f)
self._write_class_markdown_to_file(f, name, member)
- print >>f, ""
+ print("", file=f)
else:
raise RuntimeError("Member %s has unknown type %s" % (name, type(member)))
@@ -391,7 +393,7 @@ class Library(Document):
else:
raise ValueError("%s: unknown member `%s`" % (self._title, name))
else:
- print >>f, l
+ print(l, file=f)
def _write_class_markdown_to_file(self, f, name, cls):
"""Write the class doc to 'f'.
@@ -420,7 +422,7 @@ class Library(Document):
other_methods = {n: m for n, m in methods.iteritems()
if n in cls.__dict__}
if other_methods:
- print >>f, "\n#### Other Methods"
+ print("\n#### Other Methods", file=f)
else:
other_methods = methods
for name in sorted(other_methods):
@@ -440,15 +442,15 @@ class Library(Document):
Returns:
Dictionary of documented members.
"""
- print >>f, "<!-- This file is machine generated: DO NOT EDIT! -->"
- print >>f, ""
+ print("<!-- This file is machine generated: DO NOT EDIT! -->", file=f)
+ print("", file=f)
# TODO(touts): Do not insert these. Let the doc writer put them in
# the module docstring explicitly.
- print >>f, "#", self._title
+ print("#", self._title, file=f)
if self._prefix:
- print >>f, self._prefix
- print >>f, "[TOC]"
- print >>f, ""
+ print(self._prefix, file=f)
+ print("[TOC]", file=f)
+ print("", file=f)
if self._module is not None:
self._write_module_markdown_to_file(f, self._module)
@@ -469,10 +471,10 @@ class Library(Document):
if name in self._members and name not in self._documented:
leftovers.append(name)
if leftovers:
- print "%s: undocumented members: %d" % (self._title, len(leftovers))
- print >>f, "\n## Other Functions and Classes"
+ print("%s: undocumented members: %d" % (self._title, len(leftovers)))
+ print("\n## Other Functions and Classes", file=f)
for name in sorted(leftovers):
- print " %s" % name
+ print(" %s" % name)
self._documented.add(name)
self._mentioned.add(name)
self._write_member_markdown_to_file(f, *self._members[name])
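docs.py is the heaviest user of the Python 2 `print >>f, ...` redirection; with the print function the destination file becomes the `file=` keyword argument, and the `func_closure` attribute read by the signature generator is spelled `__closure__` on Python 3, as the hunk above shows. A minimal sketch of the redirection pattern, using a hypothetical output path:

    from __future__ import print_function

    with open("/tmp/example.md", "w") as f:
        # Python 2 form:  print >>f, "# Title"
        print("# Title", file=f)
        print("", file=f)  # blank line, replacing `print >>f, ""`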
diff --git a/tensorflow/python/framework/gen_docs_combined.py b/tensorflow/python/framework/gen_docs_combined.py
index 873e48d85d..5f62311f12 100644
--- a/tensorflow/python/framework/gen_docs_combined.py
+++ b/tensorflow/python/framework/gen_docs_combined.py
@@ -1,4 +1,5 @@
"""Updates generated docs from Python doc comments."""
+from __future__ import print_function
import os.path
@@ -109,7 +110,7 @@ def main(unused_argv):
hidden = set(_hidden_symbols)
for _, lib in libraries:
hidden.update(lib.exclude_symbols)
- print r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden))
+ print(r"hidden symbols regex = r'\b(%s)\b'" % "|".join(sorted(hidden)))
# Verify that all symbols are mentioned in some library doc.
catch_all = docs.Library(title="Catch All", module=None,
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 597a5ad829..2b683c5123 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1,5 +1,6 @@
# pylint: disable=invalid-name
"""Test utils for tensorflow."""
+from __future__ import print_function
import contextlib
import math
import re
@@ -337,14 +338,14 @@ class TensorFlowTestCase(googletest.TestCase):
if a.ndim:
x = a[np.where(cond)]
y = b[np.where(cond)]
- print "not close where = ", np.where(cond)
+ print("not close where = ", np.where(cond))
else:
# np.where is broken for scalars
x, y = a, b
- print "not close lhs = ", x
- print "not close rhs = ", y
- print "not close dif = ", np.abs(x - y)
- print "not close tol = ", atol + rtol * np.abs(y)
+ print("not close lhs = ", x)
+ print("not close rhs = ", y)
+ print("not close dif = ", np.abs(x - y))
+ print("not close tol = ", atol + rtol * np.abs(y))
np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
def assertAllEqual(self, a, b):
@@ -369,12 +370,12 @@ class TensorFlowTestCase(googletest.TestCase):
if a.ndim:
x = a[np.where(diff)]
y = b[np.where(diff)]
- print "not equal where = ", np.where(diff)
+ print("not equal where = ", np.where(diff))
else:
# np.where is broken for scalars
x, y = a, b
- print "not equal lhs = ", x
- print "not equal rhs = ", y
+ print("not equal lhs = ", x)
+ print("not equal rhs = ", y)
np.testing.assert_array_equal(a, b)
# pylint: disable=g-doc-return-or-yield
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index e0618cfea4..38b8be3c54 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.test_util."""
+from __future__ import print_function
import threading
import tensorflow.python.platform
@@ -20,9 +21,9 @@ class TestUtilTest(test_util.TensorFlowTestCase):
# The test doesn't assert anything. It ensures the py wrapper
# function is generated correctly.
if test_util.IsGoogleCudaEnabled():
- print "GoogleCuda is enabled"
+ print("GoogleCuda is enabled")
else:
- print "GoogleCuda is disabled"
+ print("GoogleCuda is disabled")
def testAssertProtoEqualsStr(self):
diff --git a/tensorflow/python/framework/types_test.py b/tensorflow/python/framework/types_test.py
index d933a28598..5c50080db3 100644
--- a/tensorflow/python/framework/types_test.py
+++ b/tensorflow/python/framework/types_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.python.framework.importer."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -130,7 +131,7 @@ class TypesTest(test_util.TensorFlowTestCase):
dtype.base_dtype == types.complex64):
continue
- print "%s: %s - %s" % (dtype, dtype.min, dtype.max)
+ print("%s: %s - %s" % (dtype, dtype.min, dtype.max))
# check some values that are known
if numpy_dtype == np.bool_:
diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py
index f3a26e2490..da706e8259 100644
--- a/tensorflow/python/kernel_tests/bias_op_test.py
+++ b/tensorflow/python/kernel_tests/bias_op_test.py
@@ -1,4 +1,5 @@
"""Functional tests for BiasAdd."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -11,8 +12,8 @@ class BiasAddTest(tf.test.TestCase):
def _npBias(self, inputs, bias):
assert len(bias.shape) == 1
- print inputs.shape
- print bias.shape
+ print(inputs.shape)
+ print(bias.shape)
assert inputs.shape[-1] == bias.shape[0]
return inputs + bias.reshape(([1] * (len(inputs.shape) - 1))
+ [bias.shape[0]])
@@ -64,7 +65,7 @@ class BiasAddTest(tf.test.TestCase):
b = tf.constant([1.3, 2.4], dtype=tf.float64)
bo = tf.nn.bias_add(t, b)
err = gradient_checker.ComputeGradientError(t, [3, 2], bo, [3, 2])
- print "bias add tensor gradient err = ", err
+ print("bias add tensor gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradientBias(self):
@@ -74,7 +75,7 @@ class BiasAddTest(tf.test.TestCase):
b = tf.constant([1.3, 2.4], dtype=tf.float64)
bo = tf.nn.bias_add(t, b)
err = gradient_checker.ComputeGradientError(b, [2], bo, [3, 2])
- print "bias add bias gradient err = ", err
+ print("bias add bias gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradientTensor4D(self):
@@ -85,7 +86,7 @@ class BiasAddTest(tf.test.TestCase):
b = tf.constant([1.3, 2.4], dtype=tf.float32)
bo = tf.nn.bias_add(t, b)
err = gradient_checker.ComputeGradientError(t, s, bo, s, x_init_value=x)
- print "bias add tensor gradient err = ", err
+ print("bias add tensor gradient err = ", err)
self.assertLess(err, 1e-3)
diff --git a/tensorflow/python/kernel_tests/cast_op_test.py b/tensorflow/python/kernel_tests/cast_op_test.py
index 21e8f71198..e67b0694c4 100644
--- a/tensorflow/python/kernel_tests/cast_op_test.py
+++ b/tensorflow/python/kernel_tests/cast_op_test.py
@@ -149,16 +149,16 @@ class CastOpTest(tf.test.TestCase):
class SparseTensorCastTest(tf.test.TestCase):
def testCast(self):
- indices = tf.constant([[0L], [1L], [2L]])
+ indices = tf.constant([[0], [1], [2]])
values = tf.constant(np.array([1, 2, 3], np.int64))
- shape = tf.constant([3L])
+ shape = tf.constant([3])
st = tf.SparseTensor(indices, values, shape)
st_cast = tf.cast(st, tf.float32)
with self.test_session():
- self.assertAllEqual(st_cast.indices.eval(), [[0L], [1L], [2L]])
+ self.assertAllEqual(st_cast.indices.eval(), [[0], [1], [2]])
self.assertAllEqual(st_cast.values.eval(),
np.array([1, 2, 3], np.float32))
- self.assertAllEqual(st_cast.shape.eval(), [3L])
+ self.assertAllEqual(st_cast.shape.eval(), [3])
if __name__ == "__main__":
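Dropping the trailing L suffixes above is safe because Python 3 has a single arbitrary-precision int type and rejects the L spelling outright, while Python 2 promotes plain literals to long automatically once they exceed the machine word. A short illustrative sketch, independent of the commit:

# "3L" is a SyntaxError on Python 3; the unsuffixed literal loses nothing.
indices = [[0], [1], [2]]
big = 9999999999999999
assert big * big == 99999999999999980000000000000001  # ints stay exact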
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 7f5d419c98..88f37ef952 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -1,4 +1,5 @@
"""Functional tests for convolutional operations."""
+from __future__ import print_function
import math
import tensorflow.python.platform
@@ -167,8 +168,8 @@ class Conv2DTest(tf.test.TestCase):
for i in range(len(tensors)):
conv = tensors[i]
value = values[i]
- print "expected = ", expected
- print "actual = ", value
+ print("expected = ", expected)
+ print("actual = ", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
@@ -235,8 +236,8 @@ class Conv2DTest(tf.test.TestCase):
# "values" consists of two tensors for two backprops
value = sess.run(conv)
self.assertShapeEqual(value, conv)
- print "expected = ", expected
- print "actual = ", value
+ print("expected = ", expected)
+ print("actual = ", value)
self.assertArrayNear(expected, value.flatten(), 1e-5)
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
@@ -311,8 +312,8 @@ class Conv2DTest(tf.test.TestCase):
padding=padding)
value = sess.run(conv)
self.assertShapeEqual(value, conv)
- print "expected = ", expected
- print "actual = ", value
+ print("expected = ", expected)
+ print("actual = ", value)
self.assertArrayNear(expected, value.flatten(), 1e-5)
def _CompareBackFilter(self, input_sizes, filter_sizes, output_sizes,
@@ -410,7 +411,7 @@ class Conv2DTest(tf.test.TestCase):
else:
err = gc.ComputeGradientError(filter_tensor, filter_shape,
conv, output_shape)
- print "conv_2d gradient error = ", err
+ print("conv_2d gradient error = ", err)
self.assertLess(err, tolerance)
def testInputGradientValidPaddingStrideOne(self):
@@ -838,7 +839,7 @@ class DepthwiseConv2DTest(tf.test.TestCase):
conv = tf.nn.depthwise_conv2d(t1, t2, strides=[1, stride, stride, 1],
padding=padding)
value = sess.run(conv)
- print "value = ", value
+ print("value = ", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
@@ -935,7 +936,7 @@ class SeparableConv2DTest(tf.test.TestCase):
conv = tf.nn.separable_conv2d(t1, f1, f2, strides=[1, stride, stride, 1],
padding=padding)
value = sess.run(conv)
- print "value = ", value
+ print("value = ", value)
self.assertArrayNear(expected, np.ravel(value), 1e-5)
self.assertShapeEqual(value, conv)
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 99aa2453dc..50755b6c46 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -1,4 +1,5 @@
"""Functional tests for ops used with embeddings."""
+from __future__ import print_function
import itertools
import tensorflow.python.platform
@@ -160,7 +161,7 @@ class EmbeddingLookupTest(tf.test.TestCase):
id_vals = np.array([0, 0])
ids = tf.constant(list(id_vals), dtype=tf.int32)
- print "Construct ids", ids.get_shape()
+ print("Construct ids", ids.get_shape())
embedding = tf.nn.embedding_lookup(p, ids)
tf_result = embedding.eval(feed_dict=feed_dict)
diff --git a/tensorflow/python/kernel_tests/lrn_op_test.py b/tensorflow/python/kernel_tests/lrn_op_test.py
index 7a3bb67938..85ef65b653 100644
--- a/tensorflow/python/kernel_tests/lrn_op_test.py
+++ b/tensorflow/python/kernel_tests/lrn_op_test.py
@@ -1,4 +1,5 @@
"""Tests for local response normalization."""
+from __future__ import print_function
import copy
import tensorflow.python.platform
@@ -89,7 +90,7 @@ class LRNOpTest(tf.test.TestCase):
inp, name="lrn", depth_radius=lrn_depth_radius, bias=bias,
alpha=alpha, beta=beta)
err = ComputeGradientError(inp, shape, lrn_op, shape)
- print "LRN Gradient error ", err
+ print("LRN Gradient error ", err)
self.assertLess(err, 1e-4)
def testGradients(self):
diff --git a/tensorflow/python/kernel_tests/matmul_op_test.py b/tensorflow/python/kernel_tests/matmul_op_test.py
index 5aeb736b9b..c38a2a91f1 100644
--- a/tensorflow/python/kernel_tests/matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/matmul_op_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.math_ops.matmul."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -111,25 +112,25 @@ class MatMulTest(tf.test.TestCase):
self._testCpuMatmul(x, y, True, True)
def testMatMul_OutEmpty_A(self):
- n, k, m = 0, 8, 3
- x = self._randMatrix(n, k, np.float32)
- y = self._randMatrix(k, m, np.float32)
- self._testCpuMatmul(x, y)
- self._testGpuMatmul(x, y)
+ n, k, m = 0, 8, 3
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
def testMatMul_OutEmpty_B(self):
- n, k, m = 3, 8, 0
- x = self._randMatrix(n, k, np.float32)
- y = self._randMatrix(k, m, np.float32)
- self._testCpuMatmul(x, y)
- self._testGpuMatmul(x, y)
+ n, k, m = 3, 8, 0
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
def testMatMul_Inputs_Empty(self):
- n, k, m = 3, 0, 4
- x = self._randMatrix(n, k, np.float32)
- y = self._randMatrix(k, m, np.float32)
- self._testCpuMatmul(x, y)
- self._testGpuMatmul(x, y)
+ n, k, m = 3, 0, 4
+ x = self._randMatrix(n, k, np.float32)
+ y = self._randMatrix(k, m, np.float32)
+ self._testCpuMatmul(x, y)
+ self._testGpuMatmul(x, y)
# TODO(zhifengc): Figures out how to test matmul gradients on GPU.
@@ -143,7 +144,7 @@ class MatMulGradientTest(tf.test.TestCase):
shape=[2, 4], dtype=tf.float64, name="y")
m = tf.matmul(x, y, name="matmul")
err = gc.ComputeGradientError(x, [3, 2], m, [3, 4])
- print "matmul input0 gradient err = ", err
+ print("matmul input0 gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradientInput1(self):
@@ -154,7 +155,7 @@ class MatMulGradientTest(tf.test.TestCase):
shape=[2, 4], dtype=tf.float64, name="y")
m = tf.matmul(x, y, name="matmul")
err = gc.ComputeGradientError(y, [2, 4], m, [3, 4])
- print "matmul input1 gradient err = ", err
+ print("matmul input1 gradient err = ", err)
self.assertLess(err, 1e-10)
def _VerifyInput0(self, transpose_a, transpose_b):
@@ -171,7 +172,7 @@ class MatMulGradientTest(tf.test.TestCase):
shape=shape_y, dtype=tf.float64, name="y")
m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul")
err = gc.ComputeGradientError(x, shape_x, m, [3, 4])
- print "matmul input0 gradient err = ", err
+ print("matmul input0 gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradientInput0WithTranspose(self):
@@ -193,7 +194,7 @@ class MatMulGradientTest(tf.test.TestCase):
shape=shape_y, dtype=tf.float64, name="y")
m = tf.matmul(x, y, transpose_a, transpose_b, name="matmul")
err = gc.ComputeGradientError(y, shape_y, m, [3, 4])
- print "matmul input1 gradient err = ", err
+ print("matmul input1 gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradientInput1WithTranspose(self):
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index b9a65726ee..5a35fd17fc 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -1,4 +1,5 @@
"""Functional tests for pooling operations."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -420,7 +421,7 @@ class PoolingTest(tf.test.TestCase):
err = gc.ComputeGradientError(
input_tensor, input_sizes, t, output_sizes,
x_init_value=x_init_value, delta=1e-2)
- print "%s gradient error = " % func_name, err
+ print("%s gradient error = " % func_name, err)
self.assertLess(err, err_margin)
def _testMaxPoolGradValidPadding1_1(self, use_gpu):
diff --git a/tensorflow/python/kernel_tests/random_ops_test.py b/tensorflow/python/kernel_tests/random_ops_test.py
index 311f0e3e5e..aa107a22de 100644
--- a/tensorflow/python/kernel_tests/random_ops_test.py
+++ b/tensorflow/python/kernel_tests/random_ops_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.random_ops."""
+from __future__ import print_function
import tensorflow.python.platform
@@ -31,9 +32,9 @@ class RandomNormalTest(tf.test.TestCase):
# Number of different samples.
count = (x == y).sum()
if count >= 10:
- print "x = ", x
- print "y = ", y
- print "count = ", count
+ print("x = ", x)
+ print("y = ", y)
+ print("count = ", count)
self.assertTrue(count < 10)
# Checks that the CPU and GPU implementation returns the same results,
@@ -89,9 +90,9 @@ class TruncatedNormalTest(tf.test.TestCase):
# Number of different samples.
count = (x == y).sum()
if count >= 10:
- print "x = ", x
- print "y = ", y
- print "count = ", count
+ print("x = ", x)
+ print("y = ", y)
+ print("count = ", count)
self.assertTrue(count < 10)
# Checks that the CPU and GPU implementation returns the same results,
@@ -122,7 +123,7 @@ class TruncatedNormalTest(tf.test.TestCase):
stddev = 3.0
sampler = self._Sampler(100000, 0.0, stddev, dt, use_gpu=use_gpu)
x = sampler()
- print "std(x)", np.std(x), abs(np.std(x) / stddev - 0.85)
+ print("std(x)", np.std(x), abs(np.std(x) / stddev - 0.85))
self.assertTrue(abs(np.std(x) / stddev - 0.85) < 0.04)
def testNoCSE(self):
@@ -167,9 +168,9 @@ class RandomUniformTest(tf.test.TestCase):
y = sampler()
count = (x == y).sum()
if count >= 10:
- print "x = ", x
- print "y = ", y
- print "count = ", count
+ print("x = ", x)
+ print("y = ", y)
+ print("count = ", count)
self.assertTrue(count < 10)
# Checks that the CPU and GPU implementation returns the same results,
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index a4b353f253..be79dd40ac 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -1,4 +1,5 @@
"""Tests for Relu and ReluGrad."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -49,7 +50,7 @@ class ReluTest(tf.test.TestCase):
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
- print "relu (float) gradient err = ", err
+ print("relu (float) gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradientNaN(self):
@@ -80,7 +81,7 @@ class ReluTest(tf.test.TestCase):
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float64, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
- print "relu (double) gradient err = ", err
+ print("relu (double) gradient err = ", err)
self.assertLess(err, 1e-10)
def testGradGradFloat(self):
@@ -95,7 +96,7 @@ class ReluTest(tf.test.TestCase):
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5],
x_init_value=x_init)
- print "relu (float) gradient of gradient err = ", err
+ print("relu (float) gradient of gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradGradDouble(self):
@@ -110,7 +111,7 @@ class ReluTest(tf.test.TestCase):
dtype=np.float64, order="F")
err = gc.ComputeGradientError(x, [2, 5], z[0], [2, 5],
x_init_value=x_init)
- print "relu (double) gradient of gradient err = ", err
+ print("relu (double) gradient of gradient err = ", err)
self.assertLess(err, 1e-10)
@@ -160,7 +161,7 @@ class Relu6Test(tf.test.TestCase):
[[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
- print "relu6 (float) gradient err = ", err
+ print("relu6 (float) gradient err = ", err)
self.assertLess(err, 1e-4)
def testGradientDouble(self):
@@ -173,7 +174,7 @@ class Relu6Test(tf.test.TestCase):
[[-0.9, -0.7, -0.5, -0.3, -0.1], [6.1, 6.3, 6.5, 6.7, 6.9]],
dtype=np.float64, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
- print "relu6 (double) gradient err = ", err
+ print("relu6 (double) gradient err = ", err)
self.assertLess(err, 1e-10)
diff --git a/tensorflow/python/kernel_tests/reshape_op_test.py b/tensorflow/python/kernel_tests/reshape_op_test.py
index 65b0e6d4bf..3c91db1221 100644
--- a/tensorflow/python/kernel_tests/reshape_op_test.py
+++ b/tensorflow/python/kernel_tests/reshape_op_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.reshape_op."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -64,7 +65,7 @@ class ReshapeTest(tf.test.TestCase):
reshape_out = tf.reshape(input_tensor, [1, 8, 3])
err = gc.ComputeGradientError(input_tensor, s,
reshape_out, s, x_init_value=x)
- print "Reshape gradient error = " % err
+      print("Reshape gradient error = ", err)
self.assertLess(err, 1e-3)
def testFloatEmpty(self):
diff --git a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
index 7cfbcd7946..e7d8e70ae8 100644
--- a/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
+++ b/tensorflow/python/kernel_tests/reverse_sequence_op_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.reverse_sequence_op."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -86,7 +87,7 @@ class ReverseSequenceTest(tf.test.TestCase):
reverse_sequence_out,
x.shape,
x_init_value=x)
- print "ReverseSequence gradient error = %g" % err
+ print("ReverseSequence gradient error = %g" % err)
self.assertLess(err, 1e-8)
def testShapeFunctionEdgeCases(self):
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index ac97180dbe..ad5425e6b5 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -1,4 +1,5 @@
"""Tests for various tensorflow.ops.tf."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -351,7 +352,7 @@ class TileTest(tf.test.TestCase):
grad_shape = list(np.array(multiples) * np.array(inp.shape))
err = gc.ComputeGradientError(a, list(input_shape), tiled, grad_shape,
x_init_value=inp)
- print "tile(float) error = ", err
+ print("tile(float) error = ", err)
self.assertLess(err, 1e-3)
def testGradientRandom(self):
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index 25b68aa659..216362340c 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -1,4 +1,5 @@
"""Tests for Softplus and SoftplusGrad."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -39,7 +40,7 @@ class SoftplusTest(tf.test.TestCase):
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float32, order="F")
err = gc.ComputeGradientError(x, [2, 5], y, [2, 5], x_init_value=x_init)
- print "softplus (float) gradient err = ", err
+ print("softplus (float) gradient err = ", err)
self.assertLess(err, 1e-4)
diff --git a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
index d87d15cae9..4529be21fc 100644
--- a/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
+++ b/tensorflow/python/kernel_tests/sparse_matmul_op_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.tf.matmul."""
+from __future__ import print_function
import tensorflow.python.platform
@@ -67,7 +68,7 @@ class MatMulGradientTest(tf.test.TestCase):
b_is_sparse=sp_b)
err = (gc.ComputeGradientError(a, [2, 3] if tr_a else [3, 2], m, [3, 4]) +
gc.ComputeGradientError(b, [4, 2] if tr_b else [2, 4], m, [3, 4]))
- print "sparse_matmul gradient err = ", err
+ print("sparse_matmul gradient err = ", err)
self.assertLess(err, 1e-3)
def testGradientInput(self):
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 4e44472c0d..c6ecaff799 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -1,4 +1,5 @@
"""Tests for SoftmaxCrossEntropyWithLogits op."""
+from __future__ import print_function
import tensorflow.python.platform
import numpy as np
@@ -102,7 +103,7 @@ class XentTest(tf.test.TestCase):
dtype=tf.float64, name="f")
x = tf.nn.softmax_cross_entropy_with_logits(f, l, name="xent")
err = gc.ComputeGradientError(f, [3, 4], x, [3])
- print "cross entropy gradient err = ", err
+ print("cross entropy gradient err = ", err)
self.assertLess(err, 5e-8)
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 11ce56e359..0bb9e787a5 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1,4 +1,5 @@
"""Tests for tensorflow.ops.nn."""
+from __future__ import print_function
import math
import tensorflow.python.platform
@@ -71,7 +72,7 @@ class SigmoidCrossEntropyWithLogitsTest(test_util.TensorFlowTestCase):
logits, targets, _ = self._Inputs(sizes=sizes)
loss = nn.sigmoid_cross_entropy_with_logits(logits, targets)
err = gc.ComputeGradientError(logits, sizes, loss, sizes)
- print "logistic loss gradient err = ", err
+ print("logistic loss gradient err = ", err)
self.assertLess(err, 1e-7)
@@ -264,7 +265,7 @@ class DeConv2DTest(test_util.TensorFlowTestCase):
f = constant_op.constant(f_val, name="f", dtype=types.float32)
output = nn.deconv2d(x, f, y_shape, strides=strides, padding="SAME")
err = gc.ComputeGradientError([x, f], [x_shape, f_shape], output, y_shape)
- print "DeConv gradient err = %g " % err
+ print("DeConv gradient err = %g " % err)
err_tolerance = 0.0005
self.assertLess(err, err_tolerance)
@@ -286,7 +287,7 @@ class L2LossTest(test_util.TensorFlowTestCase):
x = constant_op.constant(x_val, name="x")
output = nn.l2_loss(x)
err = gc.ComputeGradientError(x, x_shape, output, [1])
- print "L2Loss gradient err = %g " % err
+ print("L2Loss gradient err = %g " % err)
err_tolerance = 1e-11
self.assertLess(err, err_tolerance)
@@ -317,7 +318,7 @@ class L2NormalizeTest(test_util.TensorFlowTestCase):
x_tf = constant_op.constant(x_np, name="x")
y_tf = nn.l2_normalize(x_tf, dim)
err = gc.ComputeGradientError(x_tf, x_shape, y_tf, x_shape)
- print "L2Normalize gradient err = %g " % err
+ print("L2Normalize gradient err = %g " % err)
self.assertLess(err, 1e-4)
@@ -348,7 +349,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
# Check that we are in the 15% error range
expected_count = x_dim * y_dim * keep_prob * num_iter
rel_error = math.fabs(final_count - expected_count) / expected_count
- print rel_error
+ print(rel_error)
self.assertTrue(rel_error < 0.15)
def testShapedDropout(self):
@@ -377,7 +378,7 @@ class DropoutTest(test_util.TensorFlowTestCase):
# Check that we are in the 15% error range
expected_count = x_dim * y_dim * keep_prob * num_iter
rel_error = math.fabs(final_count - expected_count) / expected_count
- print rel_error
+ print(rel_error)
self.assertTrue(rel_error < 0.15)
def testShapedDropoutCorrelation(self):
@@ -494,9 +495,8 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
err = gc.ComputeGradientError(all_params[param_index],
all_shapes[param_index], output, x_shape)
- print "Batch normalization %s gradient %s scale err = " % (
- tag, "with" if scale_after_normalization else "without"
- ), err
+ print("Batch normalization %s gradient %s scale err = " %
+ (tag, "with" if scale_after_normalization else "without"), err)
self.assertLess(err, err_tolerance)
def testBatchNormInputGradient(self):
@@ -554,7 +554,7 @@ class BatchNormWithGlobalNormalizationTest(test_util.TensorFlowTestCase):
all_grads = sess.run([dx, dm, dv, db, odx, odm, odv, odb])
to_check = ["dx", "dm", "dv", "db"]
for i, n in enumerate(to_check):
- print n
+ print(n)
self.assertAllClose(
all_grads[i + len(to_check)], all_grads[i], atol=0.000001)
@@ -606,7 +606,7 @@ class MomentsTest(test_util.TensorFlowTestCase):
elif from_y == "var":
y = out_var
err = gc.ComputeGradientError(x, x_shape, y, y_shape)
- print "Moments %s gradient err = %g" % (from_y, err)
+ print("Moments %s gradient err = %g" % (from_y, err))
self.assertLess(err, 1e-11)
def testMeanGlobalGradient(self):
diff --git a/tensorflow/python/platform/__init__.py b/tensorflow/python/platform/__init__.py
index b545bac907..10b12f4abc 100644
--- a/tensorflow/python/platform/__init__.py
+++ b/tensorflow/python/platform/__init__.py
@@ -1,5 +1,6 @@
"""Setup system-specific platform environment for TensorFlow."""
-import control_imports
+from __future__ import absolute_import
+from . import control_imports
if control_imports.USE_OSS:
from tensorflow.python.platform.default._init import *
else:
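Each of the platform wrapper modules below gets the same two-line treatment: once absolute_import is in effect, a bare "import control_imports" no longer resolves to the sibling module, so the relationship is spelled out with an explicit relative import. A hedged sketch of the pattern, written as the body of a hypothetical module pkg/consumer.py sitting next to pkg/control_imports.py:

from __future__ import absolute_import

# With absolute_import active, "import control_imports" would search only
# sys.path; the leading dot pins the lookup to the containing package.
from . import control_imports

if control_imports.USE_OSS:
    pass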
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
index 3d51bc74b2..7186d6e0b5 100644
--- a/tensorflow/python/platform/app.py
+++ b/tensorflow/python/platform/app.py
@@ -1,9 +1,10 @@
"""Switch between depending on pyglib.app or an OSS replacement."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS and control_imports.OSS_APP:
from tensorflow.python.platform.default._app import *
else:
diff --git a/tensorflow/python/platform/default/_gfile.py b/tensorflow/python/platform/default/_gfile.py
index cfd25bdf90..b3c4b8f9b9 100644
--- a/tensorflow/python/platform/default/_gfile.py
+++ b/tensorflow/python/platform/default/_gfile.py
@@ -26,13 +26,13 @@ class _GFileBase(object):
def wrap(self, *args, **kwargs):
try:
return fn(self, *args, **kwargs)
- except ValueError, e:
+ except ValueError as e:
# Sometimes a ValueError is raised, e.g., a read() on a closed file.
raise FileError(errno.EIO, e.message, self._name)
- except IOError, e:
+ except IOError as e:
e.filename = self._name
raise FileError(e)
- except OSError, e:
+ except OSError as e:
raise GOSError(e)
return wrap
@@ -187,7 +187,7 @@ class _GFileBase(object):
# read a file's lines by consuming the iterator with a list
with open("filename", "r") as fp: lines = list(fp)
"""
- return self._fp.next()
+ return next(self._fp)
@_error_wrapper
@_synchronized
@@ -271,11 +271,11 @@ def _func_error_wrapper(fn):
def wrap(*args, **kwargs):
try:
return fn(*args, **kwargs)
- except ValueError, e:
+ except ValueError as e:
raise FileError(errno.EIO, e.message)
- except IOError, e:
+ except IOError as e:
raise FileError(e)
- except OSError, e:
+ except OSError as e:
raise GOSError(e)
return wrap
@@ -299,7 +299,7 @@ def Glob(glob): # pylint: disable=invalid-name
@_func_error_wrapper
-def MkDir(path, mode=0755): # pylint: disable=invalid-name
+def MkDir(path, mode=0o755): # pylint: disable=invalid-name
"""Create the directory "path" with the given mode.
Args:
@@ -316,7 +316,7 @@ def MkDir(path, mode=0755): # pylint: disable=invalid-name
@_func_error_wrapper
-def MakeDirs(path, mode=0755): # pylint: disable=invalid-name
+def MakeDirs(path, mode=0o755): # pylint: disable=invalid-name
"""Recursively create the directory "path" with the given mode.
Args:
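The remaining _gfile.py changes are pure syntax portability fixes: "except E, e" and the bare octal 0755 parse only on Python 2, while "except E as e" and 0o755 parse on both, and the file iterator's .next() method becomes the next() builtin. A small self-contained sketch of the combined pattern, using a throwaway directory name of my own choosing:

import errno
import os
import tempfile

def make_dir(path, mode=0o755):          # 0o755 is the py2/py3-portable octal literal
    try:
        os.mkdir(path, mode)
    except OSError as e:                 # "as e" replaces the py2-only "except OSError, e"
        if e.errno != errno.EEXIST:
            raise

make_dir(os.path.join(tempfile.gettempdir(), "_gfile_example_dir"))  # hypothetical path

it = iter(["line 1\n", "line 2\n"])
first = next(it)                         # next(obj) replaces the removed obj.next()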
diff --git a/tensorflow/python/platform/default/_googletest.py b/tensorflow/python/platform/default/_googletest.py
index d2686565a0..42e0eac18a 100644
--- a/tensorflow/python/platform/default/_googletest.py
+++ b/tensorflow/python/platform/default/_googletest.py
@@ -42,7 +42,7 @@ def main(*args, **kwargs):
def getShardedTestCaseNames(testCaseClass):
filtered_names = []
for testcase in sorted(delegate_get_names(testCaseClass)):
- bucket = bucket_iterator.next()
+ bucket = next(bucket_iterator)
if bucket == shard_index:
filtered_names.append(testcase)
return filtered_names
@@ -60,7 +60,7 @@ def GetTempDir():
tempfile.gettempdir(), os.path.basename(inspect.getfile(first_frame)))
temp_dir = temp_dir.rstrip('.py')
if not os.path.isdir(temp_dir):
- os.mkdir(temp_dir, 0755)
+ os.mkdir(temp_dir, 0o755)
return temp_dir
diff --git a/tensorflow/python/platform/default/_logging.py b/tensorflow/python/platform/default/_logging.py
index 2e289b1abe..5f0ace51fb 100644
--- a/tensorflow/python/platform/default/_logging.py
+++ b/tensorflow/python/platform/default/_logging.py
@@ -39,7 +39,7 @@ _level_names = {
# Mask to convert integer thread ids to unsigned quantities for logging
# purposes
-_THREAD_ID_MASK = 2 * sys.maxint + 1
+_THREAD_ID_MASK = 2 * sys.maxsize + 1
_log_prefix = None # later set to google2_log_prefix
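sys.maxint no longer exists on Python 3; sys.maxsize (the largest value a Py_ssize_t can hold) is the closest portable replacement, and the surrounding masking arithmetic is unchanged. A quick sketch of how such a mask is typically used, assuming only the standard threading module (not code from this file):

import sys
import threading

_THREAD_ID_MASK = 2 * sys.maxsize + 1    # all-ones mask for the platform word size

def _unsigned_thread_id():
    # Thread ids can be negative on some platforms; the mask maps them to an
    # unsigned value suitable for log prefixes.
    return threading.current_thread().ident & _THREAD_ID_MASK

print(_unsigned_thread_id())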
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
index d5b12d26df..85bb200e18 100644
--- a/tensorflow/python/platform/flags.py
+++ b/tensorflow/python/platform/flags.py
@@ -1,9 +1,10 @@
"""Switch between depending on pyglib.flags or open-source gflags."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS and control_imports.OSS_FLAGS:
from tensorflow.python.platform.default._flags import *
else:
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index fc28811821..a0737cd59b 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -1,9 +1,10 @@
"""Switch between depending on pyglib.gfile or an OSS replacement."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS and control_imports.OSS_GFILE:
from tensorflow.python.platform.default._gfile import *
else:
diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py
index ca22ec6e6b..2b4808552a 100644
--- a/tensorflow/python/platform/googletest.py
+++ b/tensorflow/python/platform/googletest.py
@@ -1,9 +1,10 @@
"""Switch between depending on googletest or unittest."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS and control_imports.OSS_GOOGLETEST:
from tensorflow.python.platform.default._googletest import *
else:
diff --git a/tensorflow/python/platform/logging.py b/tensorflow/python/platform/logging.py
index b6d2e53dd4..6a064398d5 100644
--- a/tensorflow/python/platform/logging.py
+++ b/tensorflow/python/platform/logging.py
@@ -1,9 +1,10 @@
"""Switch between depending on pyglib.logging or regular logging."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS and control_imports.OSS_LOGGING:
from tensorflow.python.platform.default._logging import *
else:
diff --git a/tensorflow/python/platform/parameterized.py b/tensorflow/python/platform/parameterized.py
index cf01512bc1..62b615474f 100644
--- a/tensorflow/python/platform/parameterized.py
+++ b/tensorflow/python/platform/parameterized.py
@@ -1,9 +1,10 @@
"""Switch between depending on pyglib.gfile or an OSS replacement."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS and control_imports.OSS_PARAMETERIZED:
from tensorflow.python.platform.default._parameterized import *
else:
diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py
index a0e6546c28..44ae05caf7 100644
--- a/tensorflow/python/platform/resource_loader.py
+++ b/tensorflow/python/platform/resource_loader.py
@@ -1,8 +1,9 @@
"""Load a file resource and return the contents."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
-import control_imports
+from . import control_imports
import tensorflow.python.platform
if control_imports.USE_OSS:
from tensorflow.python.platform.default._resource_loader import *
diff --git a/tensorflow/python/platform/status_bar.py b/tensorflow/python/platform/status_bar.py
index 720b9d82c0..87a84d9898 100644
--- a/tensorflow/python/platform/status_bar.py
+++ b/tensorflow/python/platform/status_bar.py
@@ -1,9 +1,10 @@
"""Switch between an internal status bar and a no-op version."""
+from __future__ import absolute_import
# pylint: disable=unused-import
# pylint: disable=g-import-not-at-top
# pylint: disable=wildcard-import
import tensorflow.python.platform
-import control_imports
+from . import control_imports
if control_imports.USE_OSS:
from tensorflow.python.platform.default._status_bar import *
else:
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
index 0571bc84cb..ac7c4be2b1 100644
--- a/tensorflow/python/summary/impl/event_file_loader.py
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -1,4 +1,5 @@
"""Functionality for loading events from a record file."""
+from __future__ import print_function
from tensorflow.core.util import event_pb2
from tensorflow.python import pywrap_tensorflow
@@ -38,11 +39,11 @@ class EventFileLoader(object):
def main(argv):
if len(argv) != 2:
- print 'Usage: event_file_loader <path-to-the-recordio-file>'
+ print('Usage: event_file_loader <path-to-the-recordio-file>')
return 1
loader = EventFileLoader(argv[1])
for event in loader.Load():
- print event
+ print(event)
if __name__ == '__main__':
diff --git a/tensorflow/python/training/coordinator_test.py b/tensorflow/python/training/coordinator_test.py
index ce9126caf4..bcd1234f3d 100644
--- a/tensorflow/python/training/coordinator_test.py
+++ b/tensorflow/python/training/coordinator_test.py
@@ -17,7 +17,7 @@ def RaiseInN(coord, n_secs, ex, report_exception):
try:
time.sleep(n_secs)
raise ex
- except RuntimeError, e:
+ except RuntimeError as e:
if report_exception:
coord.request_stop(e)
else:
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1186117169..c0480f6c5c 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -1,6 +1,5 @@
"""Base class for optimizers."""
# pylint: disable=g-bad-name
-import types
from tensorflow.python.framework import ops
from tensorflow.python.framework import types as tf_types
@@ -234,7 +233,7 @@ class Optimizer(object):
# by most optimizers. It relies on the subclass implementing the following
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
for g, v in grads_and_vars:
- if not isinstance(g, (ops.Tensor, ops.IndexedSlices, types.NoneType)):
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices, type(None))):
raise TypeError(
"Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
if not isinstance(v, variables.Variable):
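types.NoneType was removed from the types module in Python 3, so the portable way to accept None in an isinstance check is type(None), which evaluates to NoneType on both interpreters. A brief sketch with placeholder classes standing in for ops.Tensor and ops.IndexedSlices:

class Tensor(object):            # placeholder for ops.Tensor
    pass

class IndexedSlices(object):     # placeholder for ops.IndexedSlices
    pass

def check_gradient(g):
    # type(None) is portable; types.NoneType exists only on Python 2.
    if not isinstance(g, (Tensor, IndexedSlices, type(None))):
        raise TypeError("Gradient must be a Tensor, IndexedSlices, or None: %s" % g)

check_gradient(None)             # accepted
check_gradient(Tensor())         # accepted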
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
index fcf9927c79..af9048c114 100644
--- a/tensorflow/python/training/queue_runner.py
+++ b/tensorflow/python/training/queue_runner.py
@@ -99,11 +99,11 @@ class QueueRunner(object):
if self._runs == 0:
try:
sess.run(self._close_op)
- except Exception, e:
+ except Exception as e:
# Intentionally ignore errors from close_op.
logging.vlog(1, "Ignored exception: %s", str(e))
return
- except Exception, e:
+ except Exception as e:
# This catches all other exceptions.
if coord:
coord.request_stop(e)
@@ -129,7 +129,7 @@ class QueueRunner(object):
coord.wait_for_stop()
try:
sess.run(cancel_op)
- except Exception, e:
+ except Exception as e:
# Intentionally ignore errors from cancel_op.
logging.vlog(1, "Ignored exception: %s", str(e))
# pylint: enable=broad-except
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 1ef1313eea..321e1cdd34 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -505,7 +505,7 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
except gfile.FileError:
# It's ok if the file cannot be read
return None
- except text_format.ParseError, e:
+ except text_format.ParseError as e:
logging.warning(str(e))
logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
return None
@@ -754,7 +754,7 @@ class Saver(object):
for f in gfile.Glob(self._CheckpointFilename(p)):
try:
gfile.Remove(f)
- except gfile.GOSError, e:
+ except gfile.GOSError as e:
logging.warning("Ignoring: %s", str(e))
def as_saver_def(self):
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index db378e9637..bfc856cbdb 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -319,7 +319,7 @@ class MaxToKeepTest(tf.test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_non_sharded")
try:
gfile.DeleteRecursively(save_dir)
- except gfile.GOSError, _:
+ except gfile.GOSError as _:
pass # Ignore
gfile.MakeDirs(save_dir)
@@ -408,7 +408,7 @@ class MaxToKeepTest(tf.test.TestCase):
save_dir = os.path.join(self.get_temp_dir(), "max_to_keep_sharded")
try:
gfile.DeleteRecursively(save_dir)
- except gfile.GOSError, _:
+ except gfile.GOSError as _:
pass # Ignore
gfile.MakeDirs(save_dir)
@@ -446,7 +446,7 @@ class KeepCheckpointEveryNHoursTest(tf.test.TestCase):
"keep_checkpoint_every_n_hours")
try:
gfile.DeleteRecursively(save_dir)
- except gfile.GOSError, _:
+ except gfile.GOSError as _:
pass # Ignore
gfile.MakeDirs(save_dir)
diff --git a/tensorflow/python/util/protobuf/compare_test.py b/tensorflow/python/util/protobuf/compare_test.py
index 9a03d123ae..25d1fb2914 100644
--- a/tensorflow/python/util/protobuf/compare_test.py
+++ b/tensorflow/python/util/protobuf/compare_test.py
@@ -284,17 +284,17 @@ class NormalizeNumbersTest(googletest.TestCase):
compare.NormalizeNumberFields(pb)
self.assertTrue(isinstance(pb.int64_, long))
- pb.int64_ = 4L
+ pb.int64_ = 4
compare.NormalizeNumberFields(pb)
self.assertTrue(isinstance(pb.int64_, long))
- pb.int64_ = 9999999999999999L
+ pb.int64_ = 9999999999999999
compare.NormalizeNumberFields(pb)
self.assertTrue(isinstance(pb.int64_, long))
def testNormalizesRepeatedInts(self):
pb = compare_test_pb2.Large()
- pb.int64s.extend([1L, 400, 999999999999999L])
+ pb.int64s.extend([1, 400, 999999999999999])
compare.NormalizeNumberFields(pb)
self.assertTrue(isinstance(pb.int64s[0], long))
self.assertTrue(isinstance(pb.int64s[1], long))
@@ -472,20 +472,20 @@ class AssertTest(googletest.TestCase):
pb1 = compare_test_pb2.Large()
pb1.int64_ = 4
pb2 = compare_test_pb2.Large()
- pb2.int64_ = 4L
+ pb2.int64_ = 4
compare.assertProto2Equal(self, pb1, pb2)
def testNormalizesFloat(self):
pb1 = compare_test_pb2.Large()
pb1.double_ = 4.0
pb2 = compare_test_pb2.Large()
- pb2.double_ = 4L
+ pb2.double_ = 4
compare.assertProto2Equal(self, pb1, pb2, normalize_numbers=True)
pb1 = compare_test_pb2.Medium()
pb1.floats.extend([4.0, 6.0])
pb2 = compare_test_pb2.Medium()
- pb2.floats.extend([6L, 4L])
+ pb2.floats.extend([6, 4])
compare.assertProto2SameElements(self, pb1, pb2, normalize_numbers=True)
def testPrimitives(self):
diff --git a/tensorflow/tensorboard/tensorboard.py b/tensorflow/tensorboard/tensorboard.py
index dcbc50401c..c75db0d1f7 100644
--- a/tensorflow/tensorboard/tensorboard.py
+++ b/tensorflow/tensorboard/tensorboard.py
@@ -3,6 +3,7 @@
This is a simple web server to proxy data from the event_loader to the web, and
serve static web files.
"""
+from __future__ import print_function
import BaseHTTPServer
import functools
@@ -111,7 +112,7 @@ def main(unused_argv=None):
return -1
if FLAGS.debug:
- logging.info('Starting TensorBoard in directory %s' % os.getcwd())
+ logging.info('Starting TensorBoard in directory %s', os.getcwd())
path_to_run = ParseEventFilesFlag(FLAGS.logdir)
multiplexer = event_multiplexer.AutoloadingMultiplexer(
@@ -125,13 +126,13 @@ def main(unused_argv=None):
try:
server = ThreadedHTTPServer((FLAGS.host, FLAGS.port), factory)
except socket.error:
- logging.error('Tried to connect to port %d, but that address is in use.' %
+ logging.error('Tried to connect to port %d, but that address is in use.',
FLAGS.port)
return -2
status_bar.SetupStatusBarInsideGoogle('TensorBoard', FLAGS.port)
- print 'Starting TensorBoard on port %d' % FLAGS.port
- print '(You can navigate to http://localhost:%d)' % FLAGS.port
+ print('Starting TensorBoard on port %d' % FLAGS.port)
+ print('(You can navigate to http://localhost:%d)' % FLAGS.port)
server.serve_forever()
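The tensorboard.py hunks also switch from eager '%' interpolation to logging's deferred formatting: when the format arguments are passed separately, the logging module only builds the final string for records that actually pass the level filter. A small sketch against the standard library logger:

import logging
import os

logging.basicConfig(level=logging.INFO)

# Interpolation happens inside logging, and only for records that are emitted.
logging.info('Starting TensorBoard in directory %s', os.getcwd())
logging.error('Tried to connect to port %d, but that address is in use.', 6006)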