<!-- TODO(joelshor): Add images to the examples. -->
# TensorFlow-GAN (TFGAN)

TFGAN is a lightweight library for training and evaluating Generative
Adversarial Networks (GANs). This technique allows you to train a network
(called the 'generator') to sample from a distribution, without having to
explicitly model the distribution and without writing an explicit loss. For
example, the generator could learn to draw samples from the distribution of
natural images. For more details on this technique, see
['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by
Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an
introduction.

#### Usage
```python
import tensorflow as tf
tfgan = tf.contrib.gan
```

## Why TFGAN?

* Easily train generator and discriminator networks with well-tested, flexible [library calls](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py). You can
mix TFGAN, native TF, and other custom frameworks
* Use already-implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (e.g. Wasserstein loss, gradient penalty, mutual information penalty)
* [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) the trained models
* Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training
* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/)
* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model, as shown in the sketch after this list
* Improvements in TFGAN infrastructure will automatically benefit your TFGAN project
* Stay up-to-date with research as we add more algorithms
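
Here is the `GANEstimator` sketch mentioned above. It is a minimal example of one plausible wiring: `generator_fn`, `discriminator_fn`, and `train_input_fn` are hypothetical callables you define for your own model and input pipeline, and the optimizer settings are illustrative only.

```python
import tensorflow as tf
tfgan = tf.contrib.gan

# A minimal sketch of training with the TFGAN-backed `GANEstimator`.
# `generator_fn`, `discriminator_fn`, and `train_input_fn` are hypothetical
# functions you define; the learning rates are illustrative, not tuned.
gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/tmp/gan_estimator',
    generator_fn=generator_fn,
    discriminator_fn=discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5))

# Train like any other `tf.estimator.Estimator`.
gan_estimator.train(train_input_fn, max_steps=10000)
```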

## What are the TFGAN components?

TFGAN is composed of several parts, which were designed to exist independently.
These include the following main pieces (explained in detail below).

* [core](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/train.py):
provides the main infrastructure needed to train a GAN. Training occurs in four phases, and each phase
can be completed with custom code or with a TFGAN library call.

* [features](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/):
Many common GAN operations and normalization techniques are implemented for you
to use, such as instance normalization and conditioning.

* [losses](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/):
Easily experiment with already-implemented and well-tested losses and penalties,
such as the Wasserstein loss, gradient penalty, and mutual information penalty.

* [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/):
Use `Inception Score` or `Frechet Distance` with a pretrained Inception
network to evaluate your unconditional generative model. You can also use
your own pretrained classifier for more specific performance numbers, or use
other methods for evaluating conditional generative models. A short
evaluation sketch appears after this list.

* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb):
See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your
own project. These include unconditional and conditional GANs, InfoGANs,
adversarial losses on existing networks, and image-to-image translation.
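
Here is the evaluation sketch mentioned above, as a taste of the `evaluation` piece. It assumes hypothetical `real_images` and `generated_images` batches you already produce, scaled to `[-1, 1]` as the Inception-based metrics expect.

```python
import tensorflow as tf
tfgan = tf.contrib.gan

# `real_images` and `generated_images` are assumed image batches, scaled to
# [-1, 1]. The metric ops run a pretrained Inception network internally.
inception_score = tfgan.eval.inception_score(generated_images)
frechet_distance = tfgan.eval.frechet_inception_distance(
    real_images, generated_images)
```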

## Training a GAN model

Training in TFGAN typically consists of the following steps:

1. Specify the input to your networks.
1. Set up your generator and discriminator using a `GANModel`.
1. Specify your loss using a `GANLoss`.
1. Create your train ops using a `GANTrainOps`.
1. Run your train ops.

At each stage, you can either use TFGAN's convenience functions, or you can
perform the step manually for fine-grained control. We provide examples below.

There are various types of GAN setups. For instance, you can train a generator
to sample unconditionally from a learned distribution, or you can condition on
extra information such as a class label. TFGAN is compatible with many setups,
and we demonstrate a few below:

### Examples

#### Unconditional MNIST generation

This example trains a generator to produce handwritten MNIST digits. The generator maps
random draws from a multivariate normal distribution to MNIST digit images. See
['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by
Goodfellow et al.

```python
# Set up the input.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.unconditional_generator,  # you define
    discriminator_fn=mnist.unconditional_discriminator,  # you define
    real_data=images,
    generator_inputs=noise)

# Build the GAN loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss)

# Create the train ops, which calculate gradients and apply updates to weights.
train_ops = tfgan.gan_train_ops(
    gan_model,
    gan_loss,
    generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))

# Run the train ops in the alternating training scheme.
tfgan.gan_train(
    train_ops,
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
```
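
By default, `gan_train` alternates one discriminator step with one generator step. If your setup trains better with a different ratio, one option is to pass a `GANTrainSteps` tuple through `get_sequential_train_hooks`. A sketch, reusing `train_ops` from above (the 5:1 ratio is illustrative, not a recommendation):

```python
# A sketch of a customized alternation scheme: take 5 discriminator steps
# for every generator step. `train_ops` comes from `tfgan.gan_train_ops`.
train_steps = tfgan.GANTrainSteps(
    generator_train_steps=1,
    discriminator_train_steps=5)
tfgan.gan_train(
    train_ops,
    get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
    hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
    logdir=FLAGS.train_log_dir)
```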

#### Conditional MNIST generation
This example trains a generator to generate MNIST images *of a given class*.
The generator maps random draws from a multivariate normal distribution and a
one-hot label of the desired digit class to an MNIST digit image. See
['Conditional Generative Adversarial Nets'](https://arxiv.org/abs/1411.1784) by
Mirza and Osindero.

```python
# Set up the input.
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=mnist.conditional_generator,  # you define
    discriminator_fn=mnist.conditional_discriminator,  # you define
    real_data=images,
    generator_inputs=(noise, one_hot_labels))

# The rest is the same as in the unconditional case.
...
```

#### Adversarial loss
This example combines an L1 pixel loss and an adversarial loss to learn to
autoencode images. The bottleneck layer can be used to transmit compressed
representations of the image. Neural networks trained with a pixel-wise loss alone tend to
produce blurry results, so the GAN can be used to make the reconstructions more
plausible. See ['Full Resolution Image Compression with Recurrent Neural Networks'](https://arxiv.org/abs/1608.05148) by Toderici et al.
for an example of neural networks used for image compression, and ['Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network'](https://arxiv.org/abs/1609.04802) by Ledig et al. for a more detailed description of
how GANs can sharpen image output.

```python
# Set up the input pipeline.
images = image_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.autoencoder,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=images,
    generator_inputs=images)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
```

#### Image-to-image translation
This example maps images in one domain to images of the same size in a different
domain. For example, it can map segmentation masks to street images, or
grayscale images to color. See ['Image-to-Image Translation with Conditional Adversarial Networks'](https://arxiv.org/abs/1611.07004) by Isola et al. for more details.

```python
# Set up the input pipeline.
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.gan_model(
    generator_fn=nets.generator,  # you define
    discriminator_fn=nets.discriminator,  # you define
    real_data=target_image,
    generator_inputs=input_image)

# Build the GAN loss and standard pixel loss.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)

# Modify the loss tuple to include the pixel loss.
gan_loss = tfgan.losses.combine_adversarial_loss(
    gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)

# The rest is the same as in the unconditional case.
...
```

#### InfoGAN
Train a generator to generate specific MNIST digit images, and control digit style, *without using any labels*. See ['InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets'](https://arxiv.org/abs/1606.03657) by Chen et al. for more details.

```python
# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)

# Build the generator and discriminator.
gan_model = tfgan.infogan_model(
    generator_fn=mnist.infogan_generator,  # you define
    discriminator_fn=mnist.infogan_discriminator,  # you define
    real_data=images,
    unstructured_generator_inputs=unstructured_inputs,  # you define
    structured_generator_inputs=structured_inputs)  # you define

# Build the GAN loss with mutual information penalty.
gan_loss = tfgan.gan_loss(
    gan_model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty=1.0,
    mutual_information_penalty_weight=1.0)

# The rest is the same as in the unconditional case.
...
```

#### Custom model creation
Train an unconditional GAN to generate MNIST digits, but manually construct
the `GANModel` tuple for more fine-grained control.

```python
# Set up the input pipeline.
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])

# Manually build the generator and discriminator.
with tf.variable_scope('Generator') as gen_scope:
  generated_images = generator_fn(noise)
with tf.variable_scope('Discriminator') as dis_scope:
  discriminator_gen_outputs = discriminator_fn(generated_images)
with tf.variable_scope(dis_scope, reuse=True):
  discriminator_real_outputs = discriminator_fn(images)
generator_variables = tf.contrib.framework.get_trainable_variables(gen_scope)
discriminator_variables = tf.contrib.framework.get_trainable_variables(dis_scope)
# Depending on what TFGAN features you use, you don't always need to supply
# every `GANModel` field. At a minimum, you need to include the discriminator
# outputs and variables if you want to use TFGAN to construct losses.
gan_model = tfgan.GANModel(
    generator_inputs=noise,
    generated_data=generated_images,
    generator_variables=generator_variables,
    generator_scope=gen_scope,
    generator_fn=generator_fn,  # you define
    real_data=images,
    discriminator_real_outputs=discriminator_real_outputs,
    discriminator_gen_outputs=discriminator_gen_outputs,
    discriminator_variables=discriminator_variables,
    discriminator_scope=dis_scope,
    discriminator_fn=discriminator_fn)  # you define

# The rest is the same as the unconditional case.
...
```


## Authors
Joel Shor (github: [joel-shor](https://github.com/joel-shor)) and Sergio Guadarrama (github: [sguada](https://github.com/sguada))