aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-03 17:17:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-03 17:21:13 -0700
commit4cf61262ae34d342d8cf094f12ea19ffc02e84bc (patch)
tree6c7bc313c78b2e6ddd0295eeca67684fbef86baf /tensorflow/contrib/gan
parent0068086b9a288281ead6300ff9bec3c1d7afcc1d (diff)
Improve TFGAN documentation.
PiperOrigin-RevId: 170940188
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py37
-rw-r--r--tensorflow/contrib/gan/python/namedtuples.py7
-rw-r--r--tensorflow/contrib/gan/python/train.py89
3 files changed, 91 insertions, 42 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
index fca8063891..b341f03a0d 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
@@ -14,10 +14,41 @@
# ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
-Example:
+The losses and penalties in this file all correspond to losses in
+`losses_impl.py`. Losses in that file take individual arguments, whereas in this
+file they take a `GANModel` tuple. For example:
+
+losses_impl.py:
+ ```python
+ def wasserstein_discriminator_loss(
+ discriminator_real_outputs,
+ discriminator_gen_outputs,
+ real_weights=1.0,
+ generated_weights=1.0,
+ scope=None,
+ loss_collection=ops.GraphKeys.LOSSES,
+ reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+ add_summaries=False)
+ ```
+
+tuple_losses_impl.py:
+ ```python
+ def wasserstein_discriminator_loss(
+ gan_model,
+ real_weights=1.0,
+ generated_weights=1.0,
+ scope=None,
+ loss_collection=ops.GraphKeys.LOSSES,
+ reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+ add_summaries=False)
+ ```
+
+
+
+Example usage:
```python
- # `tfgan.losses.args` losses take individual arguments.
- w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
+ # `tfgan.losses.wargs` losses take individual arguments.
+ w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs)
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index a99e3fbec8..27512526c4 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Named tuples for TFGAN."""
+"""Named tuples for TFGAN.
+
+TFGAN training occurs in four steps, and each step communicates with the next
+step via one of these named tuples. At each step, you can either use a TFGAN
+helper function in `train.py`, or you can manually construct a tuple.
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index cdc4d78e5b..06dd281489 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -14,7 +14,17 @@
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.
-See examples in `tensorflow_models` for details on how to use.
+This file contains the core helper functions to create and train a GAN model.
+See the README or examples in `tensorflow_models` for details on how to use.
+
+TFGAN training occurs in four steps:
+1) Create a model
+2) Add a loss
+3) Create train ops
+4) Run the train ops
+
+The functions in this file are organized around these four steps. Each function
+corresponds to one of the steps.
"""
from __future__ import absolute_import
@@ -51,16 +61,6 @@ __all__ = [
]
-def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
- """Convert input, list of inputs, or dictionary of inputs to Tensors."""
- if isinstance(tensor_or_l_or_d, (list, tuple)):
- return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
- elif isinstance(tensor_or_l_or_d, dict):
- return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
- else:
- return ops.convert_to_tensor(tensor_or_l_or_d)
-
-
def gan_model(
# Lambdas defining models.
generator_fn,
@@ -133,20 +133,6 @@ def gan_model(
discriminator_fn)
-def _validate_distributions(distributions_l, noise_l):
- if not isinstance(distributions_l, (tuple, list)):
- raise ValueError('`predicted_distributions` must be a list. Instead, found '
- '%s.' % type(distributions_l))
- for dist in distributions_l:
- if not isinstance(dist, ds.Distribution):
- raise ValueError('Every element in `predicted_distributions` must be a '
- '`tf.Distribution`. Instead, found %s.' % type(dist))
- if len(distributions_l) != len(noise_l):
- raise ValueError('Length of `predicted_distributions` %i must be the same '
- 'as the length of structured noise %i.' %
- (len(distributions_l), len(noise_l)))
-
-
def infogan_model(
# Lambdas defining models.
generator_fn,
@@ -231,16 +217,6 @@ def infogan_model(
predicted_distributions)
-def _validate_acgan_discriminator_outputs(discriminator_output):
- try:
- a, b = discriminator_output
- except (TypeError, ValueError):
- raise TypeError(
- 'A discriminator function for ACGAN must output a tuple '
- 'consisting of (discrimination logits, classification logits).')
- return a, b
-
-
def acgan_model(
# Lambdas defining models.
generator_fn,
@@ -252,6 +228,7 @@ def acgan_model(
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
+ # Options.
check_shapes=True):
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
@@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
def gan_train_ops(
- model, # GANModel
- loss, # GANLoss
+ model,
+ loss,
generator_optimizer,
discriminator_optimizer,
- # Optional check flags.
check_for_unused_update_ops=True,
# Optional args to pass directly to the `create_train_op`.
**kwargs):
@@ -801,3 +777,40 @@ def get_sequential_train_steps(
return gen_loss + dis_loss, should_stop
return sequential_train_steps
+
+
+# Helpers
+
+
+def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
+ """Convert input, list of inputs, or dictionary of inputs to Tensors."""
+ if isinstance(tensor_or_l_or_d, (list, tuple)):
+ return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
+ elif isinstance(tensor_or_l_or_d, dict):
+ return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
+ else:
+ return ops.convert_to_tensor(tensor_or_l_or_d)
+
+
+def _validate_distributions(distributions_l, noise_l):
+ if not isinstance(distributions_l, (tuple, list)):
+ raise ValueError('`predicted_distributions` must be a list. Instead, found '
+ '%s.' % type(distributions_l))
+ for dist in distributions_l:
+ if not isinstance(dist, ds.Distribution):
+ raise ValueError('Every element in `predicted_distributions` must be a '
+ '`tf.Distribution`. Instead, found %s.' % type(dist))
+ if len(distributions_l) != len(noise_l):
+ raise ValueError('Length of `predicted_distributions` %i must be the same '
+ 'as the length of structured noise %i.' %
+ (len(distributions_l), len(noise_l)))
+
+
+def _validate_acgan_discriminator_outputs(discriminator_output):
+ try:
+ a, b = discriminator_output
+ except (TypeError, ValueError):
+ raise TypeError(
+ 'A discriminator function for ACGAN must output a tuple '
+ 'consisting of (discrimination logits, classification logits).')
+ return a, b