author    Igor Ganichev <iga@google.com>  2018-01-04 21:00:08 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-01-04 21:04:20 -0800
commit    e8127588c4cdb8e8e46983c69ec70258ea329108 (patch)
tree      ec4adf867f5cf90afeb7c44f465cc9b5e7fc21f9
parent    969f5a06271f506ce53c0078d2cf706393a7ee56 (diff)
Add MNIST GAN example and benchmarks
Training benchmark results on GPU:

batch_size                64      128      256
training eager rate     2640     5330     7269
training graph rate     5340     7080     8380
training eager/graph     0.5     0.75     0.86
generating eager rate  45872    86192   142009
generating graph rate  63558    85610   104939
generating eager/graph  0.72     1.0      1.35

Rate is processed/generated MNIST images per second. Eager is faster when
generating because we don't need to copy "feeds" from CPU to GPU memory.

PiperOrigin-RevId: 180885299
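The reported rate corresponds to the examples_per_sec extra computed by the
_report helpers in the benchmark files below; a minimal sketch of that
arithmetic (the helper name here is illustrative, not part of the commit):

    import time

    def images_per_second(start, num_iters, batch_size):
      # Average wall time per benchmark iteration, then images per second,
      # mirroring extras = {'examples_per_sec': batch_size / avg_time}.
      avg_time = (time.time() - start) / num_iters
      return batch_size / avg_time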
-rw-r--r--  tensorflow/contrib/eager/python/examples/BUILD                       1
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/BUILD                  36
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/README.md              38
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist.py              368
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py   151
-rw-r--r--  tensorflow/contrib/eager/python/examples/gan/mnist_test.py         113
6 files changed, 707 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6aef010a21..15a21885f6 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -6,6 +6,7 @@ package(default_visibility = ["//tensorflow:internal"])
py_library(
name = "examples_pip",
deps = [
+ "//tensorflow/contrib/eager/python/examples/gan:mnist",
"//tensorflow/contrib/eager/python/examples/linear_regression",
"//tensorflow/contrib/eager/python/examples/mnist",
"//tensorflow/contrib/eager/python/examples/resnet50",
diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD
new file mode 100644
index 0000000000..c61ec2dbae
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/BUILD
@@ -0,0 +1,36 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_binary(
+ name = "mnist",
+ srcs = ["mnist.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+cuda_py_test(
+ name = "mnist_test",
+ srcs = ["mnist_test.py"],
+ additional_deps = [
+ ":mnist",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cuda_py_test(
+ name = "mnist_graph_test",
+ srcs = ["mnist_graph_test.py"],
+ additional_deps = [
+ ":mnist",
+ "//third_party/py/numpy",
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md b/tensorflow/contrib/eager/python/examples/gan/README.md
new file mode 100644
index 0000000000..e8c9db1a1e
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/README.md
@@ -0,0 +1,38 @@
+# GAN with TensorFlow eager execution
+
+A simple Generative Adversarial Network (GAN) example using eager execution.
+The discriminator and generator networks each contain a few convolution and
+fully connected layers.
+
+Other eager execution examples can be found under the parent directory.
+
+## Content
+
+- `mnist.py`: Model definitions and training routines.
+- `mnist_test.py`: Benchmarks for training and using the models with eager
+execution.
+- `mnist_graph_test.py`: Benchmarks for training and using the models with
+graph execution. The same model definitions and loss functions are used in
+all benchmarks.
+
+
+## To run
+
+- Make sure you have installed TensorFlow 1.5+ or the latest `tf-nightly`
+or `tf-nightly-gpu` pip package in order to access the eager execution feature.
+
+- Train the model. For example:
+
+ ```bash
+ python mnist.py
+ ```
+
+ Use `--output_dir=<DIR>` to have the script write TensorBoard summaries to
+ `<DIR>` during training. Summaries are disabled by default.
+
+ Use `--checkpoint_dir=<DIR>` to have the script save checkpoints to `<DIR>`
+ during training. `<DIR>` defaults to `/tmp/tensorflow/mnist/checkpoints/`.
+ The script will load the latest saved checkpoint from this directory if
+ one exists. (A combined example is shown at the end of this section.)
+
+ Use `-h` for other options.
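+
+ For example, to train with a larger batch size and write TensorBoard
+ summaries (the summary directory below is only an illustration):
+
+ ```bash
+ python mnist.py --batch-size=256 --output_dir=/tmp/gan_summaries
+ ```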
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
new file mode 100644
index 0000000000..b9ac79f46c
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A deep MNIST classifier using convolutional layers.
+
+Sample usage:
+ python mnist.py --help
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.examples.tutorials.mnist import input_data
+
+FLAGS = None
+
+
+class Discriminator(tfe.Network):
+ """GAN Discriminator.
+
+ A network to differentiate between generated and real handwritten digits.
+ """
+
+ def __init__(self, data_format):
+ """Creates a model for discriminating between real and generated digits.
+
+ Args:
+ data_format: Either 'channels_first' or 'channels_last'.
+ 'channels_first' is typically faster on GPUs while 'channels_last' is
+ typically faster on CPUs. See
+ https://www.tensorflow.org/performance/performance_guide#data_formats
+ """
+ super(Discriminator, self).__init__(name='')
+ if data_format == 'channels_first':
+ self._input_shape = [-1, 1, 28, 28]
+ else:
+ assert data_format == 'channels_last'
+ self._input_shape = [-1, 28, 28, 1]
+ self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME',
+ data_format=data_format,
+ activation=tf.tanh))
+ self.pool1 = self.track_layer(
+ tf.layers.AveragePooling2D(2, 2, data_format=data_format))
+ self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5,
+ data_format=data_format,
+ activation=tf.tanh))
+ self.pool2 = self.track_layer(
+ tf.layers.AveragePooling2D(2, 2, data_format=data_format))
+ self.flatten = self.track_layer(tf.layers.Flatten())
+ self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh))
+ self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None))
+
+ def call(self, inputs):
+ """Return two logits per image estimating input authenticity.
+
+ Users should invoke __call__ to run the network, which delegates to this
+ method (and not call this method directly).
+
+ Args:
+ inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
+ or [batch_size, 1, 28, 28]
+
+ Returns:
+ A Tensor with shape [batch_size, 1] containing logits estimating
+ the probability that the corresponding digit is real.
+ """
+ x = tf.reshape(inputs, self._input_shape)
+ x = self.conv1(x)
+ x = self.pool1(x)
+ x = self.conv2(x)
+ x = self.pool2(x)
+ x = self.flatten(x)
+ x = self.fc1(x)
+ x = self.fc2(x)
+ return x
+
+
+class Generator(tfe.Network):
+ """Generator of handwritten digits similar to the ones in the MNIST dataset.
+ """
+
+ def __init__(self, data_format):
+ """Creates a model for discriminating between real and generated digits.
+
+ Args:
+ data_format: Either 'channels_first' or 'channels_last'.
+ 'channels_first' is typically faster on GPUs while 'channels_last' is
+ typically faster on CPUs. See
+ https://www.tensorflow.org/performance/performance_guide#data_formats
+ """
+ super(Generator, self).__init__(name='')
+ self.data_format = data_format
+ # The first deconvolution layer takes 128 feature maps of size 6x6 as input.
+ if data_format == 'channels_first':
+ self._pre_conv_shape = [-1, 128, 6, 6]
+ else:
+ assert data_format == 'channels_last'
+ self._pre_conv_shape = [-1, 6, 6, 128]
+ self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128,
+ activation=tf.tanh))
+
+ # In call(), we reshape the output of fc1 to _pre_conv_shape
+
+ # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
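+ # (With the default 'valid' padding, a transposed convolution yields
+ # (input_size - 1) * stride + kernel_size pixels per spatial dimension:
+ # (6 - 1) * 2 + 4 = 14 here, and (14 - 1) * 2 + 2 = 28 for the next layer.)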
+ self.conv1 = self.track_layer(tf.layers.Conv2DTranspose(
+ 64, 4, strides=2, activation=None, data_format=data_format))
+
+ # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
+ self.conv2 = self.track_layer(tf.layers.Conv2DTranspose(
+ 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format))
+
+ def call(self, inputs):
+ """Return a batch of generated images.
+
+ Users should invoke __call__ to run the network (which delegates to this
+ method) rather than calling this method directly.
+
+ Args:
+ inputs: A batch of noise vectors as a Tensor with shape
+ [batch_size, length of noise vectors].
+
+ Returns:
+ A Tensor containing generated images. If data_format is 'channels_last',
+ the shape of returned images is [batch_size, 28, 28, 1], else
+ [batch_size, 1, 28, 28]
+ """
+
+ x = self.fc1(inputs)
+ x = tf.reshape(x, shape=self._pre_conv_shape)
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x
+
+
+def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
+ """Original discriminator loss for GANs, with label smoothing.
+
+ See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
+ details.
+
+ Args:
+ discriminator_real_outputs: Discriminator output on real data.
+ discriminator_gen_outputs: Discriminator output on generated data. Expected
+ to be in the range of (-inf, inf).
+
+ Returns:
+ A scalar loss Tensor.
+ """
+
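+ # label_smoothing=0.25 softens the "real" targets from 1.0 to
+ # 1.0 * (1 - 0.25) + 0.5 * 0.25 = 0.875, which regularizes the discriminator.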
+ loss_on_real = tf.losses.sigmoid_cross_entropy(
+ tf.ones_like(discriminator_real_outputs), discriminator_real_outputs,
+ label_smoothing=0.25)
+ loss_on_generated = tf.losses.sigmoid_cross_entropy(
+ tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
+ loss = loss_on_real + loss_on_generated
+ tf.contrib.summary.scalar('discriminator_loss', loss)
+ return loss
+
+
+def generator_loss(discriminator_gen_outputs):
+ """Original generator loss for GANs.
+
+ L = -log(sigmoid(D(G(z))))
+
+ See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
+ for more details.
+
+ Args:
+ discriminator_gen_outputs: Discriminator output on generated data. Expected
+ to be in the range of (-inf, inf).
+
+ Returns:
+ A scalar loss Tensor.
+ """
+ loss = tf.losses.sigmoid_cross_entropy(
+ tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
+ tf.contrib.summary.scalar('generator_loss', loss)
+ return loss
+
+
+def train_one_epoch(generator, discriminator,
+ generator_optimizer, discriminator_optimizer,
+ dataset, log_interval, noise_dim):
+ """Trains `generator` and `discriminator` models on `dataset`.
+
+ Args:
+ generator: Generator model.
+ discriminator: Discriminator model.
+ generator_optimizer: Optimizer to use for generator.
+ discriminator_optimizer: Optimizer to use for discriminator.
+ dataset: Dataset of images to train on.
+ log_interval: How many global steps to wait between logging and collecting
+ summaries.
+ noise_dim: Dimension of noise vector to use.
+ """
+
+ total_generator_loss = 0.0
+ total_discriminator_loss = 0.0
+ for (batch_index, images) in enumerate(tfe.Iterator(dataset)):
+ with tf.device('/cpu:0'):
+ tf.assign_add(tf.train.get_global_step(), 1)
+
+ with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval):
+ current_batch_size = images.shape[0]
+ noise = tf.random_uniform(shape=[current_batch_size, noise_dim],
+ minval=-1., maxval=1., seed=batch_index)
+
+ with tfe.GradientTape(persistent=True) as g:
+ generated_images = generator(noise)
+ tf.contrib.summary.image('generated_images',
+ tf.reshape(generated_images, [-1, 28, 28, 1]),
+ max_images=10)
+
+ discriminator_gen_outputs = discriminator(generated_images)
+ discriminator_real_outputs = discriminator(images)
+ discriminator_loss_val = discriminator_loss(discriminator_real_outputs,
+ discriminator_gen_outputs)
+ total_discriminator_loss += discriminator_loss_val
+
+ generator_loss_val = generator_loss(discriminator_gen_outputs)
+ total_generator_loss += generator_loss_val
+
+ generator_grad = g.gradient(generator_loss_val, generator.variables)
+ discriminator_grad = g.gradient(discriminator_loss_val,
+ discriminator.variables)
+
+ with tf.variable_scope('generator'):
+ generator_optimizer.apply_gradients(zip(generator_grad,
+ generator.variables))
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer.apply_gradients(zip(discriminator_grad,
+ discriminator.variables))
+
+ if log_interval and batch_index > 0 and batch_index % log_interval == 0:
+ print('Batch #%d\tAverage Generator Loss: %.6f\t'
+ 'Average Discriminator Loss: %.6f' % (
+ batch_index, total_generator_loss/batch_index,
+ total_discriminator_loss/batch_index))
+
+
+def main(_):
+ (device, data_format) = ('/gpu:0', 'channels_first')
+ if FLAGS.no_gpu or tfe.num_gpus() <= 0:
+ (device, data_format) = ('/cpu:0', 'channels_last')
+ print('Using device %s, and data format %s.' % (device, data_format))
+
+ # Load the datasets
+ data = input_data.read_data_sets(FLAGS.data_dir)
+ dataset = (tf.data.Dataset
+ .from_tensor_slices(data.train.images)
+ .shuffle(60000)
+ .batch(FLAGS.batch_size))
+
+ # Create the models and optimizers
+ generator = Generator(data_format)
+ discriminator = Discriminator(data_format)
+ with tf.variable_scope('generator'):
+ generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr)
+
+ # Prepare summary writer and checkpoint info
+ summary_writer = tf.contrib.summary.create_summary_file_writer(
+ FLAGS.output_dir, flush_millis=1000)
+ checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
+ latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+ if latest_cpkt:
+ print('Using latest checkpoint at ' + latest_cpkt)
+
+ with tf.device(device):
+ for epoch in range(1, 101):
+ with tfe.restore_variables_on_create(latest_cpkt):
+ global_step = tf.train.get_or_create_global_step()
+ start = time.time()
+ with summary_writer.as_default():
+ train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer,
+ dataset, FLAGS.log_interval, FLAGS.noise)
+ end = time.time()
+ print('\nTrain time for epoch #%d (global step %d): %f' % (
+ epoch, global_step.numpy(), end - start))
+
+ all_variables = (
+ generator.variables
+ + discriminator.variables
+ + generator_optimizer.variables()
+ + discriminator_optimizer.variables()
+ + [global_step])
+ tfe.Saver(all_variables).save(
+ checkpoint_prefix, global_step=global_step)
+
+
+if __name__ == '__main__':
+ tfe.enable_eager_execution()
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--data-dir',
+ type=str,
+ default='/tmp/tensorflow/mnist/input_data',
+ help=('Directory for storing input data (default '
+ '/tmp/tensorflow/mnist/input_data)'))
+ parser.add_argument(
+ '--batch-size',
+ type=int,
+ default=128,
+ metavar='N',
+ help='input batch size for training (default: 128)')
+ parser.add_argument(
+ '--log-interval',
+ type=int,
+ default=100,
+ metavar='N',
+ help=('number of batches between logging and writing summaries '
+ '(default: 100)'))
+ parser.add_argument(
+ '--output_dir',
+ type=str,
+ default=None,
+ metavar='DIR',
+ help='Directory to write TensorBoard summaries (defaults to none)')
+ parser.add_argument(
+ '--checkpoint_dir',
+ type=str,
+ default='/tmp/tensorflow/mnist/checkpoints/',
+ metavar='DIR',
+ help=('Directory to save checkpoints in (once per epoch) (default '
+ '/tmp/tensorflow/mnist/checkpoints/)'))
+ parser.add_argument(
+ '--lr',
+ type=float,
+ default=0.001,
+ metavar='LR',
+ help='learning rate (default: 0.001)')
+ parser.add_argument(
+ '--noise',
+ type=int,
+ default=100,
+ metavar='N',
+ help='Length of noise vector for generator input (default: 100)')
+ parser.add_argument(
+ '--no-gpu',
+ action='store_true',
+ default=False,
+ help='disables GPU usage even if a GPU is available')
+
+ FLAGS, unparsed = parser.parse_known_args()
+ tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
new file mode 100644
index 0000000000..12b39b0cde
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
@@ -0,0 +1,151 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import time
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.contrib.eager.python.examples.gan import mnist
+
+NOISE_DIM = 100
+# Big enough so that summaries are never recorded.
+# Lower this value if you would like to benchmark with some summaries.
+SUMMARY_INTERVAL = 10000
+SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms
+
+
+def data_format():
+ return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+class MnistGraphGanBenchmark(tf.test.Benchmark):
+
+ def _create_graph(self, batch_size):
+ # Generate some random data.
+ images_data = np.random.randn(batch_size, 784).astype(np.float32)
+ dataset = tf.data.Dataset.from_tensors(images_data)
+ images = dataset.repeat().make_one_shot_iterator().get_next()
+
+ # Create the models and optimizers
+ generator = mnist.Generator(data_format())
+ discriminator = mnist.Discriminator(data_format())
+ with tf.variable_scope('generator'):
+ generator_optimizer = tf.train.AdamOptimizer(0.001)
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer = tf.train.AdamOptimizer(0.001)
+
+ # Run models and compute loss
+ noise_placeholder = tf.placeholder(tf.float32,
+ shape=[batch_size, NOISE_DIM])
+ generated_images = generator(noise_placeholder)
+ tf.contrib.summary.image('generated_images',
+ tf.reshape(generated_images, [-1, 28, 28, 1]),
+ max_images=10)
+ discriminator_gen_outputs = discriminator(generated_images)
+ discriminator_real_outputs = discriminator(images)
+ generator_loss = mnist.generator_loss(discriminator_gen_outputs)
+ discriminator_loss = mnist.discriminator_loss(discriminator_real_outputs,
+ discriminator_gen_outputs)
+ # Get train ops
+ with tf.variable_scope('generator'):
+ generator_train = generator_optimizer.minimize(
+ generator_loss, var_list=generator.variables)
+ with tf.variable_scope('discriminator'):
+ discriminator_train = discriminator_optimizer.minimize(
+ discriminator_loss, var_list=discriminator.variables)
+
+ return (generator_train, discriminator_train, noise_placeholder)
+
+ def _report(self, test_name, start, num_iters, batch_size):
+ avg_time = (time.time() - start) / num_iters
+ dev = 'gpu' if tf.test.is_gpu_available() else 'cpu'
+ name = 'graph_%s_%s_batch_%d_%s' % (test_name, dev, batch_size,
+ data_format())
+ extras = {'examples_per_sec': batch_size / avg_time}
+ self.report_benchmark(
+ iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+ def benchmark_train(self):
+ for batch_size in [64, 128, 256]:
+ with tf.Graph().as_default():
+ global_step = tf.train.get_or_create_global_step()
+ increment_global_step = tf.assign_add(global_step, 1)
+ with tf.contrib.summary.create_file_writer(
+ tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(), (
+ tf.contrib.summary.record_summaries_every_n_global_steps(
+ SUMMARY_INTERVAL)):
+ (generator_train, discriminator_train, noise_placeholder
+ ) = self._create_graph(batch_size)
+
+ with tf.Session() as sess:
+ tf.contrib.summary.initialize(graph=tf.get_default_graph(),
+ session=sess)
+
+ sess.run(tf.global_variables_initializer())
+
+ num_burn, num_iters = (3, 100)
+ for _ in range(num_burn):
+ noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+ # Increment global step before evaluating summary ops to avoid
+ # race condition.
+ sess.run(increment_global_step)
+ sess.run([generator_train, discriminator_train,
+ tf.contrib.summary.all_summary_ops()],
+ feed_dict={noise_placeholder: noise})
+
+ # Run and time num_iters training iterations.
+ start = time.time()
+ for _ in range(num_iters):
+ noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+ sess.run(increment_global_step)
+ sess.run([generator_train, discriminator_train,
+ tf.contrib.summary.all_summary_ops()],
+ feed_dict={noise_placeholder: noise})
+ self._report('train', start, num_iters, batch_size)
+
+ def benchmark_generate(self):
+ for batch_size in [64, 128, 256]:
+ with tf.Graph().as_default():
+ # Using random weights. This will generate garbage.
+ generator = mnist.Generator(data_format())
+ noise_placeholder = tf.placeholder(tf.float32,
+ shape=[batch_size, NOISE_DIM])
+ generated_images = generator(noise_placeholder)
+
+ init = tf.global_variables_initializer()
+ with tf.Session() as sess:
+ sess.run(init)
+ noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM])
+ num_burn, num_iters = (30, 1000)
+ for _ in range(num_burn):
+ sess.run(generated_images, feed_dict={noise_placeholder: noise})
+
+ start = time.time()
+ for _ in range(num_iters):
+ # Comparison with the eager execution benchmark in mnist_test.py
+ # isn't entirely fair as the time here includes the cost of copying
+ # the feeds from CPU memory to GPU.
+ sess.run(generated_images, feed_dict={noise_placeholder: noise})
+ self._report('generate', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
new file mode 100644
index 0000000000..4a3ca8d82b
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py
@@ -0,0 +1,113 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.contrib.eager.python.examples.gan import mnist
+
+NOISE_DIM = 100
+# Big enough so that summaries are never recorded.
+# Lower this value if you would like to benchmark with some summaries.
+SUMMARY_INTERVAL = 10000
+SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms
+
+
+def data_format():
+ return 'channels_first' if tf.test.is_gpu_available() else 'channels_last'
+
+
+def device():
+ return '/gpu:0' if tfe.num_gpus() else '/cpu:0'
+
+
+class MnistEagerGanBenchmark(tf.test.Benchmark):
+
+ def _report(self, test_name, start, num_iters, batch_size):
+ avg_time = (time.time() - start) / num_iters
+ dev = 'gpu' if tfe.num_gpus() else 'cpu'
+ name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size,
+ data_format())
+ extras = {'examples_per_sec': batch_size / avg_time}
+ self.report_benchmark(
+ iters=num_iters, wall_time=avg_time, name=name, extras=extras)
+
+ def benchmark_train(self):
+ for batch_size in [64, 128, 256]:
+ # Generate some random data.
+ burn_batches, measure_batches = (3, 100)
+ burn_images = [tf.random_normal([batch_size, 784])
+ for _ in range(burn_batches)]
+ burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images)
+ measure_images = [tf.random_normal([batch_size, 784])
+ for _ in range(measure_batches)]
+ measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images)
+
+ tf.train.get_or_create_global_step()
+ with tf.device(device()):
+ # Create the models and optimizers
+ generator = mnist.Generator(data_format())
+ discriminator = mnist.Discriminator(data_format())
+ with tf.variable_scope('generator'):
+ generator_optimizer = tf.train.AdamOptimizer(0.001)
+ with tf.variable_scope('discriminator'):
+ discriminator_optimizer = tf.train.AdamOptimizer(0.001)
+
+ with tf.contrib.summary.create_file_writer(
+ tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default():
+
+ # warm up
+ mnist.train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer,
+ burn_dataset, log_interval=SUMMARY_INTERVAL,
+ noise_dim=NOISE_DIM)
+ # measure
+ start = time.time()
+ mnist.train_one_epoch(generator, discriminator, generator_optimizer,
+ discriminator_optimizer,
+ measure_dataset, log_interval=SUMMARY_INTERVAL,
+ noise_dim=NOISE_DIM)
+ self._report('train', start, measure_batches, batch_size)
+
+ def benchmark_generate(self):
+ for batch_size in [64, 128, 256]:
+ with tf.device(device()):
+ # Using random weights. This will generate garbage.
+ generator = mnist.Generator(data_format())
+
+ num_burn, num_iters = (30, 1000)
+ for _ in range(num_burn):
+ noise = tf.random_uniform(shape=[batch_size, NOISE_DIM],
+ minval=-1., maxval=1.)
+ generator(noise)
+
+ start = time.time()
+ for _ in range(num_iters):
+ noise = tf.random_uniform(shape=[batch_size, NOISE_DIM],
+ minval=-1., maxval=1.)
+ generator(noise)
+ self._report('generate', start, num_iters, batch_size)
+
+
+if __name__ == '__main__':
+ tfe.enable_eager_execution()
+ tf.test.main()