diff options
author | 2017-10-03 17:17:07 -0700 | |
---|---|---|
committer | 2017-10-03 17:21:13 -0700 | |
commit | 4cf61262ae34d342d8cf094f12ea19ffc02e84bc (patch) | |
tree | 6c7bc313c78b2e6ddd0295eeca67684fbef86baf /tensorflow/contrib/gan | |
parent | 0068086b9a288281ead6300ff9bec3c1d7afcc1d (diff) |
Improve TFGAN documentation.
PiperOrigin-RevId: 170940188
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py | 37 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/namedtuples.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/train.py | 89 |
3 files changed, 91 insertions, 42 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index fca8063891..b341f03a0d 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -14,10 +14,41 @@ # ============================================================================== """TFGAN utilities for loss functions that accept GANModel namedtuples. -Example: +The losses and penalties in this file all correspond to losses in +`losses_impl.py`. Losses in that file take individual arguments, whereas in this +file they take a `GANModel` tuple. For example: + +losses_impl.py: + ```python + def wasserstein_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False) + ``` + +tuple_losses_impl.py: + ```python + def wasserstein_discriminator_loss( + gan_model, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False) + ``` + + + +Example usage: ```python - # `tfgan.losses.args` losses take individual arguments. - w_loss = tfgan.losses.args.wasserstein_discriminator_loss( + # `tfgan.losses.wargs` losses take individual arguments. + w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss( discriminator_real_outputs, discriminator_gen_outputs) diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index a99e3fbec8..27512526c4 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN.""" +"""Named tuples for TFGAN. + +TFGAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TFGAN +helper function in `train.py`, or you can manually construct a tuple. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index cdc4d78e5b..06dd281489 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -14,7 +14,17 @@ # ============================================================================== """The TFGAN project provides a lightweight GAN training/testing framework. -See examples in `tensorflow_models` for details on how to use. +This file contains the core helper functions to create and train a GAN model. +See the README or examples in `tensorflow_models` for details on how to use. + +TFGAN training occurs in four steps: +1) Create a model +2) Add a loss +3) Create train ops +4) Run the train ops + +The functions in this file are organized around these four steps. Each function +corresponds to one of the steps. """ from __future__ import absolute_import @@ -51,16 +61,6 @@ __all__ = [ ] -def _convert_tensor_or_l_or_d(tensor_or_l_or_d): - """Convert input, list of inputs, or dictionary of inputs to Tensors.""" - if isinstance(tensor_or_l_or_d, (list, tuple)): - return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] - elif isinstance(tensor_or_l_or_d, dict): - return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} - else: - return ops.convert_to_tensor(tensor_or_l_or_d) - - def gan_model( # Lambdas defining models. generator_fn, @@ -133,20 +133,6 @@ def gan_model( discriminator_fn) -def _validate_distributions(distributions_l, noise_l): - if not isinstance(distributions_l, (tuple, list)): - raise ValueError('`predicted_distributions` must be a list. Instead, found ' - '%s.' % type(distributions_l)) - for dist in distributions_l: - if not isinstance(dist, ds.Distribution): - raise ValueError('Every element in `predicted_distributions` must be a ' - '`tf.Distribution`. Instead, found %s.' % type(dist)) - if len(distributions_l) != len(noise_l): - raise ValueError('Length of `predicted_distributions` %i must be the same ' - 'as the length of structured noise %i.' % - (len(distributions_l), len(noise_l))) - - def infogan_model( # Lambdas defining models. generator_fn, @@ -231,16 +217,6 @@ def infogan_model( predicted_distributions) -def _validate_acgan_discriminator_outputs(discriminator_output): - try: - a, b = discriminator_output - except (TypeError, ValueError): - raise TypeError( - 'A discriminator function for ACGAN must output a tuple ' - 'consisting of (discrimination logits, classification logits).') - return a, b - - def acgan_model( # Lambdas defining models. generator_fn, @@ -252,6 +228,7 @@ def acgan_model( # Optional scopes. generator_scope='Generator', discriminator_scope='Discriminator', + # Options. check_shapes=True): """Returns an ACGANModel contains all the pieces needed for ACGAN training. @@ -497,11 +474,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): def gan_train_ops( - model, # GANModel - loss, # GANLoss + model, + loss, generator_optimizer, discriminator_optimizer, - # Optional check flags. check_for_unused_update_ops=True, # Optional args to pass directly to the `create_train_op`. **kwargs): @@ -801,3 +777,40 @@ def get_sequential_train_steps( return gen_loss + dis_loss, should_stop return sequential_train_steps + + +# Helpers + + +def _convert_tensor_or_l_or_d(tensor_or_l_or_d): + """Convert input, list of inputs, or dictionary of inputs to Tensors.""" + if isinstance(tensor_or_l_or_d, (list, tuple)): + return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] + elif isinstance(tensor_or_l_or_d, dict): + return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} + else: + return ops.convert_to_tensor(tensor_or_l_or_d) + + +def _validate_distributions(distributions_l, noise_l): + if not isinstance(distributions_l, (tuple, list)): + raise ValueError('`predicted_distributions` must be a list. Instead, found ' + '%s.' % type(distributions_l)) + for dist in distributions_l: + if not isinstance(dist, ds.Distribution): + raise ValueError('Every element in `predicted_distributions` must be a ' + '`tf.Distribution`. Instead, found %s.' % type(dist)) + if len(distributions_l) != len(noise_l): + raise ValueError('Length of `predicted_distributions` %i must be the same ' + 'as the length of structured noise %i.' % + (len(distributions_l), len(noise_l))) + + +def _validate_acgan_discriminator_outputs(discriminator_output): + try: + a, b = discriminator_output + except (TypeError, ValueError): + raise TypeError( + 'A discriminator function for ACGAN must output a tuple ' + 'consisting of (discrimination logits, classification logits).') + return a, b |