From e8127588c4cdb8e8e46983c69ec70258ea329108 Mon Sep 17 00:00:00 2001
From: Igor Ganichev
Date: Thu, 4 Jan 2018 21:00:08 -0800
Subject: Add MNIST GAN example and benchmarks

Training benchmark results on GPU:

batch_size              64     128    256
training eager rate     2640   5330   7269
training graph rate     5340   7080   8380
training eager/graph    0.5    0.75   0.86
generating eager rate   45872  86192  142009
generating graph rate   63558  85610  104939
generating eager/graph  0.72   1.0    1.35

Rates are MNIST images processed/generated per second.

Eager is faster when generating because we don't need to copy "feeds" from
CPU to GPU memory.

PiperOrigin-RevId: 180885299
---
 tensorflow/contrib/eager/python/examples/BUILD     |   1 +
 tensorflow/contrib/eager/python/examples/gan/BUILD |  36 ++
 .../contrib/eager/python/examples/gan/README.md    |  38 +++
 .../contrib/eager/python/examples/gan/mnist.py     | 368 +++++++++++++++++++++
 .../eager/python/examples/gan/mnist_graph_test.py  | 151 +++++++++
 .../eager/python/examples/gan/mnist_test.py        | 113 +++++++
 6 files changed, 707 insertions(+)
 create mode 100644 tensorflow/contrib/eager/python/examples/gan/BUILD
 create mode 100644 tensorflow/contrib/eager/python/examples/gan/README.md
 create mode 100644 tensorflow/contrib/eager/python/examples/gan/mnist.py
 create mode 100644 tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py
 create mode 100644 tensorflow/contrib/eager/python/examples/gan/mnist_test.py

diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD
index 6aef010a21..15a21885f6 100644
--- a/tensorflow/contrib/eager/python/examples/BUILD
+++ b/tensorflow/contrib/eager/python/examples/BUILD
@@ -6,6 +6,7 @@ package(default_visibility = ["//tensorflow:internal"])
 py_library(
     name = "examples_pip",
     deps = [
+        "//tensorflow/contrib/eager/python/examples/gan:mnist",
        "//tensorflow/contrib/eager/python/examples/linear_regression",
        "//tensorflow/contrib/eager/python/examples/mnist",
        "//tensorflow/contrib/eager/python/examples/resnet50",
diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD
new file mode 100644
index 0000000000..c61ec2dbae
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/BUILD
@@ -0,0 +1,36 @@
+licenses(["notice"])  # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+
+py_binary(
+    name = "mnist",
+    srcs = ["mnist.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow:tensorflow_py",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//tensorflow/examples/tutorials/mnist:input_data",
+    ],
+)
+
+cuda_py_test(
+    name = "mnist_test",
+    srcs = ["mnist_test.py"],
+    additional_deps = [
+        ":mnist",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//tensorflow:tensorflow_py",
+    ],
+)
+
+cuda_py_test(
+    name = "mnist_graph_test",
+    srcs = ["mnist_graph_test.py"],
+    additional_deps = [
+        ":mnist",
+        "//third_party/py/numpy",
+        "//tensorflow:tensorflow_py",
+    ],
+)
diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md b/tensorflow/contrib/eager/python/examples/gan/README.md
new file mode 100644
index 0000000000..e8c9db1a1e
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/README.md
@@ -0,0 +1,38 @@
+# GAN with TensorFlow eager execution
+
+A simple Generative Adversarial Network (GAN) example using eager execution.
+The discriminator and generator networks each contain a few convolution and
+fully connected layers.
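+
+At a glance, the data flow is: a noise vector goes into the generator, which
+produces a 28x28 image; the discriminator maps an image (real or generated)
+to a single real/fake logit. A minimal sketch of that flow, assuming this
+directory is on the Python path and eager execution is available (see
+`mnist.py` for the actual layer definitions):
+
+```python
+import tensorflow as tf
+import tensorflow.contrib.eager as tfe
+
+import mnist  # the module defined in this directory
+
+tfe.enable_eager_execution()
+
+generator = mnist.Generator(data_format='channels_last')
+discriminator = mnist.Discriminator(data_format='channels_last')
+
+noise = tf.random_uniform([16, 100], minval=-1., maxval=1.)  # 16 noise vectors of length 100
+fake_images = generator(noise)            # shape [16, 28, 28, 1], values in (0, 1)
+fake_logits = discriminator(fake_images)  # shape [16, 1], one real/fake logit per image
+print(fake_images.shape, fake_logits.shape)
+```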
+
+Other eager execution examples can be found under the parent directory.
+
+## Contents
+
+- `mnist.py`: Model definitions and training routines.
+- `mnist_test.py`: Benchmarks for training the models and generating images
+  with eager execution.
+- `mnist_graph_test.py`: Benchmarks for training the models and generating
+  images with graph execution. The same model definitions and loss functions
+  are used in all benchmarks.
+
+
+## To run
+
+- Make sure you have installed TensorFlow 1.5+ or the latest `tf-nightly`
+  or `tf-nightly-gpu` pip package in order to access the eager execution
+  feature.
+
+- Train the model. For example:
+
+  ```bash
+  python mnist.py
+  ```
+
+  Use `--output_dir=<DIR>` to direct the script to save TensorBoard summaries
+  during training. Disabled by default.
+
+  Use `--checkpoint_dir=<DIR>` to direct the script to save checkpoints to
+  `<DIR>` during training. `<DIR>` defaults to `/tmp/tensorflow/mnist/checkpoints/`.
+  The script will load the latest saved checkpoint from this directory if
+  one exists.
+
+  Use `-h` for other options.
diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py
new file mode 100644
index 0000000000..b9ac79f46c
--- /dev/null
+++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A GAN (generator plus discriminator) that learns to generate MNIST digits.
+
+Sample usage:
+  python mnist.py --help
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+import time
+
+import tensorflow as tf
+
+import tensorflow.contrib.eager as tfe
+from tensorflow.examples.tutorials.mnist import input_data
+
+FLAGS = None
+
+
+class Discriminator(tfe.Network):
+  """GAN Discriminator.
+
+  A network to differentiate between generated and real handwritten digits.
+  """
+
+  def __init__(self, data_format):
+    """Creates a model for discriminating between real and generated digits.
+
+    Args:
+      data_format: Either 'channels_first' or 'channels_last'.
+        'channels_first' is typically faster on GPUs while 'channels_last' is
+        typically faster on CPUs.
See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Discriminator, self).__init__(name='') + if data_format == 'channels_first': + self._input_shape = [-1, 1, 28, 28] + else: + assert data_format == 'channels_last' + self._input_shape = [-1, 28, 28, 1] + self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', + data_format=data_format, + activation=tf.tanh)) + self.pool1 = self.track_layer( + tf.layers.AveragePooling2D(2, 2, data_format=data_format)) + self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, + data_format=data_format, + activation=tf.tanh)) + self.pool2 = self.track_layer( + tf.layers.AveragePooling2D(2, 2, data_format=data_format)) + self.flatten = self.track_layer(tf.layers.Flatten()) + self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) + self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + + def call(self, inputs): + """Return two logits per image estimating input authenticity. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1] + or [batch_size, 1, 28, 28] + + Returns: + A Tensor with shape [batch_size] containing logits estimating + the probability that corresponding digit is real. + """ + x = tf.reshape(inputs, self._input_shape) + x = self.conv1(x) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + return x + + +class Generator(tfe.Network): + """Generator of handwritten digits similar to the ones in the MNIST dataset. + """ + + def __init__(self, data_format): + """Creates a model for discriminating between real and generated digits. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + 'channels_first' is typically faster on GPUs while 'channels_last' is + typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Generator, self).__init__(name='') + self.data_format = data_format + # We are using 128 6x6 channels as input to the first deconvolution layer + if data_format == 'channels_first': + self._pre_conv_shape = [-1, 128, 6, 6] + else: + assert data_format == 'channels_last' + self._pre_conv_shape = [-1, 6, 6, 128] + self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, + activation=tf.tanh)) + + # In call(), we reshape the output of fc1 to _pre_conv_shape + + # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) + self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format)) + + # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) + self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + + def call(self, inputs): + """Return a batch of generated images. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of noise vectors as a Tensor with shape + [batch_size, length of noise vectors]. + + Returns: + A Tensor containing generated images. 
If data_format is 'channels_last', + the shape of returned images is [batch_size, 28, 28, 1], else + [batch_size, 1, 28, 28] + """ + + x = self.fc1(inputs) + x = tf.reshape(x, shape=self._pre_conv_shape) + x = self.conv1(x) + x = self.conv2(x) + return x + + +def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): + """Original discriminator loss for GANs, with label smoothing. + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more + details. + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + + Returns: + A scalar loss Tensor. + """ + + loss_on_real = tf.losses.sigmoid_cross_entropy( + tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + label_smoothing=0.25) + loss_on_generated = tf.losses.sigmoid_cross_entropy( + tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) + loss = loss_on_real + loss_on_generated + tf.contrib.summary.scalar('discriminator_loss', loss) + return loss + + +def generator_loss(discriminator_gen_outputs): + """Original generator loss for GANs. + + L = -log(sigmoid(D(G(z)))) + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) + for more details. + + Args: + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + + Returns: + A scalar loss Tensor. + """ + loss = tf.losses.sigmoid_cross_entropy( + tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs) + tf.contrib.summary.scalar('generator_loss', loss) + return loss + + +def train_one_epoch(generator, discriminator, + generator_optimizer, discriminator_optimizer, + dataset, log_interval, noise_dim): + """Trains `generator` and `discriminator` models on `dataset`. + + Args: + generator: Generator model. + discriminator: Discriminator model. + generator_optimizer: Optimizer to use for generator. + discriminator_optimizer: Optimizer to use for discriminator. + dataset: Dataset of images to train on. + log_interval: How many global steps to wait between logging and collecting + summaries. + noise_dim: Dimension of noise vector to use. 
+ """ + + total_generator_loss = 0.0 + total_discriminator_loss = 0.0 + for (batch_index, images) in enumerate(tfe.Iterator(dataset)): + with tf.device('/cpu:0'): + tf.assign_add(tf.train.get_global_step(), 1) + + with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + current_batch_size = images.shape[0] + noise = tf.random_uniform(shape=[current_batch_size, noise_dim], + minval=-1., maxval=1., seed=batch_index) + + with tfe.GradientTape(persistent=True) as g: + generated_images = generator(noise) + tf.contrib.summary.image('generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) + + discriminator_gen_outputs = discriminator(generated_images) + discriminator_real_outputs = discriminator(images) + discriminator_loss_val = discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs) + total_discriminator_loss += discriminator_loss_val + + generator_loss_val = generator_loss(discriminator_gen_outputs) + total_generator_loss += generator_loss_val + + generator_grad = g.gradient(generator_loss_val, generator.variables) + discriminator_grad = g.gradient(discriminator_loss_val, + discriminator.variables) + + with tf.variable_scope('generator'): + generator_optimizer.apply_gradients(zip(generator_grad, + generator.variables)) + with tf.variable_scope('discriminator'): + discriminator_optimizer.apply_gradients(zip(discriminator_grad, + discriminator.variables)) + + if log_interval and batch_index > 0 and batch_index % log_interval == 0: + print('Batch #%d\tAverage Generator Loss: %.6f\t' + 'Average Discriminator Loss: %.6f' % ( + batch_index, total_generator_loss/batch_index, + total_discriminator_loss/batch_index)) + + +def main(_): + (device, data_format) = ('/gpu:0', 'channels_first') + if FLAGS.no_gpu or tfe.num_gpus() <= 0: + (device, data_format) = ('/cpu:0', 'channels_last') + print('Using device %s, and data format %s.' 
% (device, data_format)) + + # Load the datasets + data = input_data.read_data_sets(FLAGS.data_dir) + dataset = (tf.data.Dataset + .from_tensor_slices(data.train.images) + .shuffle(60000) + .batch(FLAGS.batch_size)) + + # Create the models and optimizers + generator = Generator(data_format) + discriminator = Discriminator(data_format) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + + # Prepare summary writer and checkpoint info + summary_writer = tf.contrib.summary.create_summary_file_writer( + FLAGS.output_dir, flush_millis=1000) + checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') + latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) + if latest_cpkt: + print('Using latest checkpoint at ' + latest_cpkt) + + with tf.device(device): + for epoch in range(1, 101): + with tfe.restore_variables_on_create(latest_cpkt): + global_step = tf.train.get_or_create_global_step() + start = time.time() + with summary_writer.as_default(): + train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + dataset, FLAGS.log_interval, FLAGS.noise) + end = time.time() + print('\nTrain time for epoch #%d (global step %d): %f' % ( + epoch, global_step.numpy(), end - start)) + + all_variables = ( + generator.variables + + discriminator.variables + + generator_optimizer.variables() + + discriminator_optimizer.variables() + + [global_step]) + tfe.Saver(all_variables).save( + checkpoint_prefix, global_step=global_step) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + + parser = argparse.ArgumentParser() + parser.add_argument( + '--data-dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help=('Directory for storing input data (default ' + '/tmp/tensorflow/mnist/input_data)')) + parser.add_argument( + '--batch-size', + type=int, + default=128, + metavar='N', + help='input batch size for training (default: 128)') + parser.add_argument( + '--log-interval', + type=int, + default=100, + metavar='N', + help=('number of batches between logging and writing summaries ' + '(default: 100)')) + parser.add_argument( + '--output_dir', + type=str, + default=None, + metavar='DIR', + help='Directory to write TensorBoard summaries (defaults to none)') + parser.add_argument( + '--checkpoint_dir', + type=str, + default='/tmp/tensorflow/mnist/checkpoints/', + metavar='DIR', + help=('Directory to save checkpoints in (once per epoch) (default ' + '/tmp/tensorflow/mnist/checkpoints/)')) + parser.add_argument( + '--lr', + type=float, + default=0.001, + metavar='LR', + help='learning rate (default: 0.001)') + parser.add_argument( + '--noise', + type=int, + default=100, + metavar='N', + help='Length of noise vector for generator input (default: 100)') + parser.add_argument( + '--no-gpu', + action='store_true', + default=False, + help='disables GPU usage even if a GPU is available') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py new file mode 100644 index 0000000000..12b39b0cde --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -0,0 +1,151 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.gan import mnist + +NOISE_DIM = 100 +# Big enough so that summaries are never recorded. +# Lower this value if would like to benchmark with some summaries. +SUMMARY_INTERVAL = 10000 +SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +class MnistGraphGanBenchmark(tf.test.Benchmark): + + def _create_graph(self, batch_size): + # Generate some random data. + images_data = np.random.randn(batch_size, 784).astype(np.float32) + dataset = tf.data.Dataset.from_tensors(images_data) + images = dataset.repeat().make_one_shot_iterator().get_next() + + # Create the models and optimizers + generator = mnist.Generator(data_format()) + discriminator = mnist.Discriminator(data_format()) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(0.001) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(0.001) + + # Run models and compute loss + noise_placeholder = tf.placeholder(tf.float32, + shape=[batch_size, NOISE_DIM]) + generated_images = generator(noise_placeholder) + tf.contrib.summary.image('generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) + discriminator_gen_outputs = discriminator(generated_images) + discriminator_real_outputs = discriminator(images) + generator_loss = mnist.generator_loss(discriminator_gen_outputs) + discriminator_loss = mnist.discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs) + # Get train ops + with tf.variable_scope('generator'): + generator_train = generator_optimizer.minimize( + generator_loss, var_list=generator.variables) + with tf.variable_scope('discriminator'): + discriminator_train = discriminator_optimizer.minimize( + discriminator_loss, var_list=discriminator.variables) + + return (generator_train, discriminator_train, noise_placeholder) + + def _report(self, test_name, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' + name = 'graph_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, + data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_train(self): + for batch_size in [64, 128, 256]: + with tf.Graph().as_default(): + global_step = tf.train.get_or_create_global_step() + increment_global_step = tf.assign_add(global_step, 1) + with tf.contrib.summary.create_file_writer( + tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(), ( + 
tf.contrib.summary.record_summaries_every_n_global_steps( + SUMMARY_INTERVAL)): + (generator_train, discriminator_train, noise_placeholder + ) = self._create_graph(batch_size) + + with tf.Session() as sess: + tf.contrib.summary.initialize(graph=tf.get_default_graph(), + session=sess) + + sess.run(tf.global_variables_initializer()) + + num_burn, num_iters = (3, 100) + for _ in range(num_burn): + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + # Increment global step before evaluating summary ops to avoid + # race condition. + sess.run(increment_global_step) + sess.run([generator_train, discriminator_train, + tf.contrib.summary.all_summary_ops()], + feed_dict={noise_placeholder: noise}) + + # Run and benchmark 2 epochs + start = time.time() + for _ in range(num_iters): + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + sess.run(increment_global_step) + sess.run([generator_train, discriminator_train, + tf.contrib.summary.all_summary_ops()], + feed_dict={noise_placeholder: noise}) + self._report('train', start, num_iters, batch_size) + + def benchmark_generate(self): + for batch_size in [64, 128, 256]: + with tf.Graph().as_default(): + # Using random weights. This will generate garbage. + generator = mnist.Generator(data_format()) + noise_placeholder = tf.placeholder(tf.float32, + shape=[batch_size, NOISE_DIM]) + generated_images = generator(noise_placeholder) + + init = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init) + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + num_burn, num_iters = (30, 1000) + for _ in range(num_burn): + sess.run(generated_images, feed_dict={noise_placeholder: noise}) + + start = time.time() + for _ in range(num_iters): + # Comparison with the eager execution benchmark in mnist_test.py + # isn't entirely fair as the time here includes the cost of copying + # the feeds from CPU memory to GPU. + sess.run(generated_images, feed_dict={noise_placeholder: noise}) + self._report('generate', start, num_iters, batch_size) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py new file mode 100644 index 0000000000..4a3ca8d82b --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.gan import mnist + +NOISE_DIM = 100 +# Big enough so that summaries are never recorded. +# Lower this value if would like to benchmark with some summaries. 
+SUMMARY_INTERVAL = 10000 +SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +def device(): + return '/gpu:0' if tfe.num_gpus() else '/cpu:0' + + +class MnistEagerGanBenchmark(tf.test.Benchmark): + + def _report(self, test_name, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tfe.num_gpus() else 'cpu' + name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, + data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_train(self): + for batch_size in [64, 128, 256]: + # Generate some random data. + burn_batches, measure_batches = (3, 100) + burn_images = [tf.random_normal([batch_size, 784]) + for _ in range(burn_batches)] + burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images) + measure_images = [tf.random_normal([batch_size, 784]) + for _ in range(measure_batches)] + measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images) + + tf.train.get_or_create_global_step() + with tf.device(device()): + # Create the models and optimizers + generator = mnist.Generator(data_format()) + discriminator = mnist.Discriminator(data_format()) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(0.001) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(0.001) + + with tf.contrib.summary.create_file_writer( + tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(): + + # warm up + mnist.train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + burn_dataset, log_interval=SUMMARY_INTERVAL, + noise_dim=NOISE_DIM) + # measure + start = time.time() + mnist.train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + measure_dataset, log_interval=SUMMARY_INTERVAL, + noise_dim=NOISE_DIM) + self._report('train', start, measure_batches, batch_size) + + def benchmark_generate(self): + for batch_size in [64, 128, 256]: + with tf.device(device()): + # Using random weights. This will generate garbage. + generator = mnist.Generator(data_format()) + + num_burn, num_iters = (30, 1000) + for _ in range(num_burn): + noise = tf.random_uniform(shape=[batch_size, NOISE_DIM], + minval=-1., maxval=1.) + generator(noise) + + start = time.time() + for _ in range(num_iters): + noise = tf.random_uniform(shape=[batch_size, NOISE_DIM], + minval=-1., maxval=1.) + generator(noise) + self._report('generate', start, num_iters, batch_size) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + tf.test.main() -- cgit v1.2.3
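
The rates quoted in the commit message correspond to the `examples_per_sec` extra
reported by these benchmarks: both `mnist_test.py` and `mnist_graph_test.py` measure
the average wall time per iteration and report `batch_size / avg_time`. A small
illustrative calculation using the numbers from the table in the commit message:

```python
# Convert the quoted throughput rates back into per-batch wall times.
# Values are taken from the commit message table for batch_size=256.
batch_size = 256
eager_train_rate = 7269.0  # MNIST images/sec, eager training
graph_train_rate = 8380.0  # MNIST images/sec, graph training

eager_ms_per_batch = 1000.0 * batch_size / eager_train_rate  # ~35.2 ms
graph_ms_per_batch = 1000.0 * batch_size / graph_train_rate  # ~30.5 ms
print('eager/graph throughput ratio: %.2f'
      % (eager_train_rate / graph_train_rate))  # ~0.87, the last column of the training rows
```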