author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-09-12 10:09:54 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-09-12 10:13:53 -0700
commit    f63aa7f49f81a66112bfef6670a18658d5a479e5 (patch)
tree      3e8185bff7e41ee87a292cd1833412965da1b5e2
parent    bc6b60f1bc79c2753cea087cf0eba1d76c5702df (diff)
Migrate core TFGAN functions to opensource.
PiperOrigin-RevId: 168391923
-rw-r--r--  tensorflow/contrib/gan/BUILD                    54
-rw-r--r--  tensorflow/contrib/gan/__init__.py              21
-rw-r--r--  tensorflow/contrib/gan/python/namedtuples.py   149
-rw-r--r--  tensorflow/contrib/gan/python/train.py         804
-rw-r--r--  tensorflow/contrib/gan/python/train_test.py    745
5 files changed, 1769 insertions(+), 4 deletions(-)
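The pieces added below compose into a single training pipeline: build a GANModel, turn it into a GANLoss, turn that into GANTrainOps, and hand those to gan_train. A minimal sketch, assuming a TensorFlow build that contains this commit; the generator and discriminator functions are hypothetical stand-ins:

    import tensorflow as tf

    tfgan = tf.contrib.gan

    def generator_fn(noise):  # hypothetical stand-in generator
      return tf.layers.dense(noise, 2)

    def discriminator_fn(data, generator_inputs):  # hypothetical stand-in; emits logits
      return tf.layers.dense(data, 1)

    model = tfgan.gan_model(
        generator_fn, discriminator_fn,
        real_data=tf.zeros([16, 2]),
        generator_inputs=tf.random_normal([16, 2]))
    loss = tfgan.gan_loss(model)
    train_ops = tfgan.gan_train_ops(
        model, loss,
        generator_optimizer=tf.train.GradientDescentOptimizer(1e-3),
        discriminator_optimizer=tf.train.GradientDescentOptimizer(1e-3))
    tfgan.gan_train(train_ops, logdir='/tmp/tfgan',
                    hooks=[tf.train.StopAtStepHook(last_step=100)])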
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index cb2cd7c7ef..c3ae738acf 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -16,6 +16,60 @@ py_library(
deps = [
":features",
":losses",
+ ":namedtuples",
+ ":train",
+ ],
+)
+
+py_library(
+ name = "namedtuples",
+ srcs = ["python/namedtuples.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "train",
+ srcs = ["python/train.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":losses",
+ ":namedtuples",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/slim:learning",
+ "//tensorflow/contrib/training:training_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/ops/distributions",
+ "//tensorflow/python/ops/losses",
+ ],
+)
+
+py_test(
+ name = "train_test",
+ srcs = ["python/train_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":namedtuples",
+ ":train",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/slim:learning",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/ops/distributions",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py
index b2f4bf0119..3c423e72d0 100644
--- a/tensorflow/contrib/gan/__init__.py
+++ b/tensorflow/contrib/gan/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,20 @@ from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import train
-del absolute_import
-del division
-del print_function
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.gan.python.namedtuples import *
+from tensorflow.contrib.gan.python.train import *
+# pylint: enable=unused-import,wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ 'features',
+ 'losses',
+]
+_allowed_symbols += train.__all__
+_allowed_symbols += namedtuples.__all__
+remove_undocumented(__name__, _allowed_symbols)
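The upshot of the __init__.py change: names listed in `train.__all__` and `namedtuples.__all__` are promoted to the top of the `tf.contrib.gan` namespace, `features` and `losses` remain tiered sub-namespaces, and `remove_undocumented` strips everything else. A small illustrative check, again assuming a build containing this commit:

    import tensorflow as tf

    tfgan = tf.contrib.gan
    # Promoted by the wildcard imports:
    assert hasattr(tfgan, 'gan_model') and hasattr(tfgan, 'GANModel')
    # Kept as tiered sub-namespaces:
    assert hasattr(tfgan, 'losses') and hasattr(tfgan, 'features')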
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
new file mode 100644
index 0000000000..a99e3fbec8
--- /dev/null
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -0,0 +1,149 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Named tuples for TFGAN."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+__all__ = [
+ 'GANModel',
+ 'InfoGANModel',
+ 'ACGANModel',
+ 'GANLoss',
+ 'GANTrainOps',
+ 'GANTrainSteps',
+]
+
+
+class GANModel(
+ collections.namedtuple('GANModel', (
+ 'generator_inputs',
+ 'generated_data',
+ 'generator_variables',
+ 'generator_scope',
+ 'generator_fn',
+ 'real_data',
+ 'discriminator_real_outputs',
+ 'discriminator_gen_outputs',
+ 'discriminator_variables',
+ 'discriminator_scope',
+ 'discriminator_fn',
+ ))):
+ """A GANModel contains all the pieces needed for GAN training.
+
+ Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt
+  to create an implicit generative model of data by solving a two-agent game.
+ The generator generates candidate examples that are supposed to match the
+ data distribution, and the discriminator aims to tell the real examples
+ apart from the generated samples.
+
+ Args:
+ generator_inputs: The random noise source that acts as input to the
+ generator.
+ generated_data: The generated output data of the GAN.
+ generator_variables: A list of all generator variables.
+ generator_scope: Variable scope all generator variables live in.
+ generator_fn: The generator function.
+    real_data: A Tensor of real data.
+ discriminator_real_outputs: The discriminator's output on real data.
+ discriminator_gen_outputs: The discriminator's output on generated data.
+ discriminator_variables: A list of all discriminator variables.
+ discriminator_scope: Variable scope all discriminator variables live in.
+ discriminator_fn: The discriminator function.
+ """
+
+
+# TODO(joelshor): Have this class inherit from `GANModel`.
+class InfoGANModel(
+ collections.namedtuple('InfoGANModel', GANModel._fields + (
+ 'structured_generator_inputs',
+ 'predicted_distributions',
+ ))):
+ """An InfoGANModel contains all the pieces needed for InfoGAN training.
+
+ See https://arxiv.org/abs/1606.03657 for more details.
+
+ Args:
+ structured_generator_inputs: A list of Tensors representing the random noise
+ that must have high mutual information with the generator output. List
+ length should match `predicted_distributions`.
+ predicted_distributions: A list of tf.Distributions. Predicted by the
+ recognizer, and used to evaluate the likelihood of the structured noise.
+ List length should match `structured_generator_inputs`.
+ """
+
+
+class ACGANModel(
+ collections.namedtuple('ACGANModel', GANModel._fields +
+ ('one_hot_labels',
+ 'discriminator_real_classification_logits',
+ 'discriminator_gen_classification_logits',))):
+ """An ACGANModel contains all the pieces needed for ACGAN training.
+
+ See https://arxiv.org/abs/1610.09585 for more details.
+
+ Args:
+ one_hot_labels: A Tensor holding one-hot-labels for the batch.
+ discriminator_real_classification_logits: Classification logits for real
+ data.
+ discriminator_gen_classification_logits: Classification logits for generated
+ data.
+ """
+
+
+class GANLoss(
+ collections.namedtuple('GANLoss', (
+ 'generator_loss',
+ 'discriminator_loss'
+ ))):
+ """GANLoss contains the generator and discriminator losses.
+
+ Args:
+    generator_loss: A tensor for the generator loss.
+ discriminator_loss: A tensor for the discriminator loss.
+ """
+
+
+class GANTrainOps(
+ collections.namedtuple('GANTrainOps', (
+ 'generator_train_op',
+ 'discriminator_train_op',
+ 'global_step_inc_op'
+ ))):
+ """GANTrainOps contains the training ops.
+
+ Args:
+ generator_train_op: Op that performs a generator update step.
+ discriminator_train_op: Op that performs a discriminator update step.
+ global_step_inc_op: Op that increments the shared global step.
+ """
+
+
+class GANTrainSteps(
+ collections.namedtuple('GANTrainSteps', (
+ 'generator_train_steps',
+ 'discriminator_train_steps'
+ ))):
+ """Contains configuration for the GAN Training.
+
+ Args:
+ generator_train_steps: Number of generator steps to take in each GAN step.
+ discriminator_train_steps: Number of discriminator steps to take in each GAN
+ step.
+ """
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
new file mode 100644
index 0000000000..af7dbcf249
--- /dev/null
+++ b/tensorflow/contrib/gan/python/train.py
@@ -0,0 +1,804 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The TFGAN project provides a lightweight GAN training/testing framework.
+
+See examples in `tensorflow_models` for details on how to use.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.framework.python.ops import variables as variables_lib
+from tensorflow.contrib.gan.python import losses as tfgan_losses
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.slim.python.slim import learning as slim_learning
+from tensorflow.contrib.training.python.training import training
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.distributions import distribution as ds
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.training import sync_replicas_optimizer
+from tensorflow.python.training import training_util
+
+
+__all__ = [
+ 'gan_model',
+ 'infogan_model',
+ 'acgan_model',
+ 'gan_loss',
+ 'gan_train_ops',
+ 'gan_train',
+ 'get_sequential_train_hooks',
+ 'get_joint_train_hooks',
+ 'get_sequential_train_steps',
+]
+
+
+def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
+ """Convert input, list of inputs, or dictionary of inputs to Tensors."""
+ if isinstance(tensor_or_l_or_d, (list, tuple)):
+ return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
+ elif isinstance(tensor_or_l_or_d, dict):
+ return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
+ else:
+ return ops.convert_to_tensor(tensor_or_l_or_d)
+
+
+def gan_model(
+ # Lambdas defining models.
+ generator_fn,
+ discriminator_fn,
+ # Real data and conditioning.
+ real_data,
+ generator_inputs,
+ # Optional scopes.
+ generator_scope='Generator',
+ discriminator_scope='Discriminator',
+ # Options.
+ check_shapes=True):
+ """Returns GAN model outputs and variables.
+
+ Args:
+ generator_fn: A python lambda that takes `generator_inputs` as inputs and
+ returns the outputs of the GAN generator.
+ discriminator_fn: A python lambda that takes `real_data`/`generated data`
+ and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
+ real_data: A Tensor representing the real data.
+ generator_inputs: A Tensor or list of Tensors to the generator. In the
+ vanilla GAN case, this might be a single noise Tensor. In the conditional
+ GAN case, this might be the generator's conditioning.
+ generator_scope: Optional generator variable scope. Useful if you want to
+ reuse a subgraph that has already been created.
+ discriminator_scope: Optional discriminator variable scope. Useful if you
+ want to reuse a subgraph that has already been created.
+ check_shapes: If `True`, check that generator produces Tensors that are the
+ same shape as real data. Otherwise, skip this check.
+
+ Returns:
+ A GANModel namedtuple.
+
+ Raises:
+ ValueError: If the generator outputs a Tensor that isn't the same shape as
+ `real_data`.
+ """
+ # Create models
+ with variable_scope.variable_scope(generator_scope) as gen_scope:
+ generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
+ generated_data = generator_fn(generator_inputs)
+ with variable_scope.variable_scope(discriminator_scope) as dis_scope:
+ discriminator_gen_outputs = discriminator_fn(generated_data,
+ generator_inputs)
+ with variable_scope.variable_scope(dis_scope, reuse=True):
+ real_data = ops.convert_to_tensor(real_data)
+ discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)
+
+ if check_shapes:
+ if not generated_data.shape.is_compatible_with(real_data.shape):
+ raise ValueError(
+ 'Generator output shape (%s) must be the same shape as real data '
+ '(%s).' % (generated_data.shape, real_data.shape))
+
+ # Get model-specific variables.
+ generator_variables = variables_lib.get_trainable_variables(gen_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
+
+ return namedtuples.GANModel(
+ generator_inputs,
+ generated_data,
+ generator_variables,
+ gen_scope,
+ generator_fn,
+ real_data,
+ discriminator_real_outputs,
+ discriminator_gen_outputs,
+ discriminator_variables,
+ dis_scope,
+ discriminator_fn)
+
+
+def _validate_distributions(distributions_l, noise_l):
+ if not isinstance(distributions_l, (tuple, list)):
+ raise ValueError('`predicted_distributions` must be a list. Instead, found '
+ '%s.' % type(distributions_l))
+ for dist in distributions_l:
+ if not isinstance(dist, ds.Distribution):
+ raise ValueError('Every element in `predicted_distributions` must be a '
+ '`tf.Distribution`. Instead, found %s.' % type(dist))
+ if len(distributions_l) != len(noise_l):
+ raise ValueError('Length of `predicted_distributions` %i must be the same '
+ 'as the length of structured noise %i.' %
+ (len(distributions_l), len(noise_l)))
+
+
+def infogan_model(
+ # Lambdas defining models.
+ generator_fn,
+ discriminator_fn,
+ # Real data and conditioning.
+ real_data,
+ unstructured_generator_inputs,
+ structured_generator_inputs,
+ # Optional scopes.
+ generator_scope='Generator',
+ discriminator_scope='Discriminator'):
+ """Returns an InfoGAN model outputs and variables.
+
+ See https://arxiv.org/abs/1606.03657 for more details.
+
+ Args:
+ generator_fn: A python lambda that takes a list of Tensors as inputs and
+ returns the outputs of the GAN generator.
+ discriminator_fn: A python lambda that takes `real_data`/`generated data`
+ and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
+ `logits` are in the range [-inf, inf], and `distribution_list` is a list
+    of TensorFlow distributions representing the predicted noise distribution
+    of the ith structured noise.
+ real_data: A Tensor representing the real data.
+ unstructured_generator_inputs: A list of Tensors to the generator.
+ These tensors represent the unstructured noise or conditioning.
+ structured_generator_inputs: A list of Tensors to the generator.
+ These tensors must have high mutual information with the recognizer.
+ generator_scope: Optional generator variable scope. Useful if you want to
+ reuse a subgraph that has already been created.
+ discriminator_scope: Optional discriminator variable scope. Useful if you
+ want to reuse a subgraph that has already been created.
+
+ Returns:
+ An InfoGANModel namedtuple.
+
+ Raises:
+ ValueError: If the generator outputs a Tensor that isn't the same shape as
+ `real_data`.
+ ValueError: If the discriminator output is malformed.
+ """
+ # Create models
+ with variable_scope.variable_scope(generator_scope) as gen_scope:
+ unstructured_generator_inputs = _convert_tensor_or_l_or_d(
+ unstructured_generator_inputs)
+ structured_generator_inputs = _convert_tensor_or_l_or_d(
+ structured_generator_inputs)
+ generator_inputs = (
+ unstructured_generator_inputs + structured_generator_inputs)
+ generated_data = generator_fn(generator_inputs)
+ with variable_scope.variable_scope(discriminator_scope) as disc_scope:
+ dis_gen_outputs, predicted_distributions = discriminator_fn(
+ generated_data, generator_inputs)
+ _validate_distributions(predicted_distributions, structured_generator_inputs)
+ with variable_scope.variable_scope(disc_scope, reuse=True):
+ real_data = ops.convert_to_tensor(real_data)
+ dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)
+
+ if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
+ raise ValueError(
+ 'Generator output shape (%s) must be the same shape as real data '
+ '(%s).' % (generated_data.get_shape(), real_data.get_shape()))
+
+ # Get model-specific variables.
+ generator_variables = variables_lib.get_trainable_variables(gen_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(
+ disc_scope)
+
+ return namedtuples.InfoGANModel(
+ generator_inputs,
+ generated_data,
+ generator_variables,
+ gen_scope,
+ generator_fn,
+ real_data,
+ dis_real_outputs,
+ dis_gen_outputs,
+ discriminator_variables,
+ disc_scope,
+ lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API
+ structured_generator_inputs,
+ predicted_distributions)
+
+
+def _validate_acgan_discriminator_outputs(discriminator_output):
+ try:
+ a, b = discriminator_output
+ except (TypeError, ValueError):
+ raise TypeError(
+ 'A discriminator function for ACGAN must output a tuple '
+ 'consisting of (discrimination logits, classification logits).')
+ return a, b
+
+
+def acgan_model(
+ # Lambdas defining models.
+ generator_fn,
+ discriminator_fn,
+ # Real data and conditioning.
+ real_data,
+ generator_inputs,
+ one_hot_labels,
+ # Optional scopes.
+ generator_scope='Generator',
+ discriminator_scope='Discriminator',
+ check_shapes=True):
+ """Returns an ACGANModel contains all the pieces needed for ACGAN training.
+
+  The `acgan_model` is the same as the `gan_model`, with the one difference
+  that the discriminator additionally outputs class logits for its input
+  (whether real or generated).
+ Therefore, an explicit field holding one_hot_labels is necessary, as well as a
+ discriminator_fn that outputs a 2-tuple holding the logits for real/fake and
+ classification.
+
+ See https://arxiv.org/abs/1610.09585 for more details.
+
+ Args:
+ generator_fn: A python lambda that takes `generator_inputs` as inputs and
+ returns the outputs of the GAN generator.
+ discriminator_fn: A python lambda that takes `real_data`/`generated data`
+ and `generator_inputs`. Outputs a tuple consisting of two Tensors:
+ (1) real/fake logits in the range [-inf, inf]
+ (2) classification logits in the range [-inf, inf]
+ real_data: A Tensor representing the real data.
+ generator_inputs: A Tensor or list of Tensors to the generator. In the
+ vanilla GAN case, this might be a single noise Tensor. In the conditional
+ GAN case, this might be the generator's conditioning.
+ one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by
+ acgan_loss.
+ generator_scope: Optional generator variable scope. Useful if you want to
+ reuse a subgraph that has already been created.
+ discriminator_scope: Optional discriminator variable scope. Useful if you
+ want to reuse a subgraph that has already been created.
+ check_shapes: If `True`, check that generator produces Tensors that are the
+ same shape as real data. Otherwise, skip this check.
+
+ Returns:
+    An ACGANModel namedtuple.
+
+ Raises:
+ ValueError: If the generator outputs a Tensor that isn't the same shape as
+ `real_data`.
+ TypeError: If the discriminator does not output a tuple consisting of
+ (discrimination logits, classification logits).
+ """
+ # Create models
+ with variable_scope.variable_scope(generator_scope) as gen_scope:
+ generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
+ generated_data = generator_fn(generator_inputs)
+ with variable_scope.variable_scope(discriminator_scope) as dis_scope:
+ (discriminator_gen_outputs, discriminator_gen_classification_logits
+ ) = _validate_acgan_discriminator_outputs(
+ discriminator_fn(generated_data, generator_inputs))
+ with variable_scope.variable_scope(dis_scope, reuse=True):
+ real_data = ops.convert_to_tensor(real_data)
+ (discriminator_real_outputs, discriminator_real_classification_logits
+ ) = _validate_acgan_discriminator_outputs(
+ discriminator_fn(real_data, generator_inputs))
+ if check_shapes:
+ if not generated_data.shape.is_compatible_with(real_data.shape):
+ raise ValueError(
+ 'Generator output shape (%s) must be the same shape as real data '
+ '(%s).' % (generated_data.shape, real_data.shape))
+
+ # Get model-specific variables.
+ generator_variables = variables_lib.get_trainable_variables(gen_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(
+ dis_scope)
+
+ return namedtuples.ACGANModel(
+ generator_inputs, generated_data, generator_variables, gen_scope,
+ generator_fn, real_data, discriminator_real_outputs,
+ discriminator_gen_outputs, discriminator_variables, dis_scope,
+ discriminator_fn, one_hot_labels,
+ discriminator_real_classification_logits,
+ discriminator_gen_classification_logits)
+
+
+def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
+ if isinstance(aux_loss_weight, ops.Tensor):
+ aux_loss_weight.shape.assert_is_compatible_with([])
+ with ops.control_dependencies(
+ [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
+ aux_loss_weight = array_ops.identity(aux_loss_weight)
+ elif aux_loss_weight is not None and aux_loss_weight < 0:
+    raise ValueError('`%s` must be non-negative. Instead, was %s' %
+ (name, aux_loss_weight))
+ return aux_loss_weight
+
+
+def _use_aux_loss(aux_loss_weight):
+ if aux_loss_weight is not None:
+ if not isinstance(aux_loss_weight, ops.Tensor):
+ return aux_loss_weight > 0
+ else:
+ return True
+ else:
+ return False
+
+
+def gan_loss(
+ # GANModel.
+ model,
+ # Loss functions.
+ generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
+ discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
+ # Auxiliary losses.
+ gradient_penalty_weight=None,
+ gradient_penalty_epsilon=1e-10,
+ mutual_information_penalty_weight=None,
+ aux_cond_generator_weight=None,
+ aux_cond_discriminator_weight=None,
+ # Options.
+ add_summaries=True):
+ """Returns losses necessary to train generator and discriminator.
+
+ Args:
+ model: A GANModel tuple.
+ generator_loss_fn: The loss function on the generator. Takes a GANModel
+ tuple.
+ discriminator_loss_fn: The loss function on the discriminator. Takes a
+ GANModel tuple.
+ gradient_penalty_weight: If not `None`, must be a non-negative Python number
+ or Tensor indicating how much to weight the gradient penalty. See
+ https://arxiv.org/pdf/1704.00028.pdf for more details.
+ gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
+ small positive value used by the gradient penalty function for numerical
+ stability. Note some applications will need to increase this value to
+ avoid NaNs.
+ mutual_information_penalty_weight: If not `None`, must be a non-negative
+ Python number or Tensor indicating how much to weight the mutual
+ information penalty. See https://arxiv.org/abs/1606.03657 for more
+ details.
+ aux_cond_generator_weight: If not None: add a classification loss as in
+ https://arxiv.org/abs/1610.09585
+ aux_cond_discriminator_weight: If not None: add a classification loss as in
+ https://arxiv.org/abs/1610.09585
+ add_summaries: Whether or not to add summaries for the losses.
+
+ Returns:
+ A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
+ regularization losses.
+
+ Raises:
+ ValueError: If any of the auxiliary loss weights is provided and negative.
+ ValueError: If `mutual_information_penalty_weight` is provided, but the
+ `model` isn't an `InfoGANModel`.
+ """
+ # Validate arguments.
+ gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight,
+ 'gradient_penalty_weight')
+ mutual_information_penalty_weight = _validate_aux_loss_weight(
+ mutual_information_penalty_weight, 'infogan_weight')
+ aux_cond_generator_weight = _validate_aux_loss_weight(
+ aux_cond_generator_weight, 'aux_cond_generator_weight')
+ aux_cond_discriminator_weight = _validate_aux_loss_weight(
+ aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')
+
+ # Verify configuration for mutual information penalty
+ if (_use_aux_loss(mutual_information_penalty_weight) and
+ not isinstance(model, namedtuples.InfoGANModel)):
+ raise ValueError(
+ 'When `mutual_information_penalty_weight` is provided, `model` must be '
+ 'an `InfoGANModel`. Instead, was %s.' % type(model))
+
+  # Verify configuration for auxiliary condition loss (ACGAN).
+ if ((_use_aux_loss(aux_cond_generator_weight) or
+ _use_aux_loss(aux_cond_discriminator_weight)) and
+ not isinstance(model, namedtuples.ACGANModel)):
+ raise ValueError(
+ 'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
+ 'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
+ type(model))
+
+ # Create standard losses.
+ gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
+ dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries)
+
+ # Add optional extra losses.
+ if _use_aux_loss(gradient_penalty_weight):
+ gp_loss = tfgan_losses.wasserstein_gradient_penalty(
+ model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
+ dis_loss += gradient_penalty_weight * gp_loss
+ if _use_aux_loss(mutual_information_penalty_weight):
+ info_loss = tfgan_losses.mutual_information_penalty(
+ model, add_summaries=add_summaries)
+ dis_loss += mutual_information_penalty_weight * info_loss
+ gen_loss += mutual_information_penalty_weight * info_loss
+ if _use_aux_loss(aux_cond_generator_weight):
+ ac_gen_loss = tfgan_losses.acgan_generator_loss(
+ model, add_summaries=add_summaries)
+ gen_loss += aux_cond_generator_weight * ac_gen_loss
+ if _use_aux_loss(aux_cond_discriminator_weight):
+ ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
+ model, add_summaries=add_summaries)
+ dis_loss += aux_cond_discriminator_weight * ac_disc_loss
+  # Gather auxiliary regularization losses.
+ if model.generator_scope:
+ gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
+ else:
+ gen_reg_loss = 0
+ if model.discriminator_scope:
+ dis_reg_loss = losses.get_regularization_loss(
+ model.discriminator_scope.name)
+ else:
+ dis_reg_loss = 0
+
+ return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
+
+
+def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
+ """Gets generator and discriminator update ops.
+
+ Args:
+ kwargs: A dictionary of kwargs to be passed to `create_train_op`.
+ `update_ops` is removed, if present.
+ gen_scope: A scope for the generator.
+ dis_scope: A scope for the discriminator.
+    check_for_unused_ops: A Python bool. If `True`, raise a `ValueError` if
+      there are unused update ops.
+
+ Returns:
+    A 2-tuple of (generator update ops, discriminator update ops).
+
+ Raises:
+ ValueError: If there are update ops outside of the generator or
+ discriminator scopes.
+ """
+ if 'update_ops' in kwargs:
+ update_ops = set(kwargs['update_ops'])
+ del kwargs['update_ops']
+ else:
+ update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
+
+ all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
+ all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))
+
+ if check_for_unused_ops:
+ unused_ops = update_ops - all_gen_ops - all_dis_ops
+ if unused_ops:
+ raise ValueError('There are unused update ops: %s' % unused_ops)
+
+ gen_update_ops = list(all_gen_ops & update_ops)
+ dis_update_ops = list(all_dis_ops & update_ops)
+
+ return gen_update_ops, dis_update_ops
+
+
+def gan_train_ops(
+ model, # GANModel
+ loss, # GANLoss
+ generator_optimizer,
+ discriminator_optimizer,
+ # Optional check flags.
+ check_for_unused_update_ops=True,
+ # Optional args to pass directly to the `create_train_op`.
+ **kwargs):
+ """Returns GAN train ops.
+
+ The highest-level call in TFGAN. It is composed of functions that can also
+ be called, should a user require more control over some part of the GAN
+ training process.
+
+ Args:
+ model: A GANModel.
+ loss: A GANLoss.
+ generator_optimizer: The optimizer for generator updates.
+ discriminator_optimizer: The optimizer for the discriminator updates.
+ check_for_unused_update_ops: If `True`, throws an exception if there are
+ update ops outside of the generator or discriminator scopes.
+ **kwargs: Keyword args to pass directly to
+ `training.create_train_op` for both the generator and
+ discriminator train op.
+
+ Returns:
+    A GANTrainOps tuple of (generator_train_op, discriminator_train_op,
+    global_step_inc_op) that can be used to train a generator/discriminator
+    pair.
+ """
+ # Create global step increment op.
+ global_step = training_util.get_or_create_global_step()
+ global_step_inc = global_step.assign_add(1)
+
+ # Get generator and discriminator update ops. We split them so that update
+ # ops aren't accidentally run multiple times. For now, throw an error if
+ # there are update ops that aren't associated with either the generator or
+ # the discriminator. Might modify the `kwargs` dictionary.
+ gen_update_ops, dis_update_ops = _get_update_ops(
+ kwargs, model.generator_scope.name, model.discriminator_scope.name,
+ check_for_unused_update_ops)
+
+ generator_global_step = None
+ if isinstance(generator_optimizer,
+ sync_replicas_optimizer.SyncReplicasOptimizer):
+    # TODO(joelshor): Figure out a way to get this to work without including
+    # the dummy global step in the checkpoint.
+ # WARNING: Making this variable a local variable causes sync replicas to
+ # hang forever.
+ generator_global_step = variable_scope.get_variable(
+ 'dummy_global_step_generator',
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ gen_update_ops += [generator_global_step.assign(global_step)]
+ with ops.name_scope('generator_train'):
+ gen_train_op = training.create_train_op(
+ total_loss=loss.generator_loss,
+ optimizer=generator_optimizer,
+ variables_to_train=model.generator_variables,
+ global_step=generator_global_step,
+ update_ops=gen_update_ops,
+ **kwargs)
+
+ discriminator_global_step = None
+ if isinstance(discriminator_optimizer,
+ sync_replicas_optimizer.SyncReplicasOptimizer):
+ # See comment above `generator_global_step`.
+ discriminator_global_step = variable_scope.get_variable(
+ 'dummy_global_step_discriminator',
+ shape=[],
+ dtype=dtypes.int64,
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+ dis_update_ops += [discriminator_global_step.assign(global_step)]
+ with ops.name_scope('discriminator_train'):
+ disc_train_op = training.create_train_op(
+ total_loss=loss.discriminator_loss,
+ optimizer=discriminator_optimizer,
+ variables_to_train=model.discriminator_variables,
+ global_step=discriminator_global_step,
+ update_ops=dis_update_ops,
+ **kwargs)
+
+ return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
+
+
+# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
+# Image Compression` (https://arxiv.org/abs/1705.05823)
+class RunTrainOpsHook(session_run_hook.SessionRunHook):
+ """A hook to run train ops a fixed number of times."""
+
+ def __init__(self, train_ops, train_steps):
+ """Run train ops a certain number of times.
+
+ Args:
+ train_ops: A train op or iterable of train ops to run.
+ train_steps: The number of times to run the op(s).
+ """
+ if not isinstance(train_ops, (list, tuple)):
+ train_ops = [train_ops]
+ self._train_ops = train_ops
+ self._train_steps = train_steps
+
+ def before_run(self, run_context):
+ for _ in range(self._train_steps):
+ run_context.session.run(self._train_ops)
+
+
+def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
+ """Returns a hooks function for sequential GAN training.
+
+ Args:
+ train_steps: A `GANTrainSteps` tuple that determines how many generator
+ and discriminator training steps to take.
+
+ Returns:
+ A function that takes a GANTrainOps tuple and returns a list of hooks.
+ """
+ def get_hooks(train_ops):
+ generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
+ train_steps.generator_train_steps)
+ discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
+ train_steps.discriminator_train_steps)
+ return [generator_hook, discriminator_hook]
+ return get_hooks
+
+
+def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
+ """Returns a hooks function for sequential GAN training.
+
+ When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
+ ALL OPTIMIZERS TO AVOID RACE CONDITIONS.
+
+ The order of steps taken is:
+ 1) Combined generator and discriminator steps
+ 2) Generator only steps, if any remain
+ 3) Discriminator only steps, if any remain
+
+ **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
+ for the generator and discriminator simultaneously whenever possible. This
+ reduces the number of `tf.Session` calls, and can also change the training
+ semantics.
+
+ To illustrate the difference look at the following example:
+
+ `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
+ `get_sequential_train_hooks` to make 8 session calls:
+ 1) 3 generator steps
+ 2) 5 discriminator steps
+
+  In contrast, `get_joint_train_hooks` will make 5 session calls:
+ 1) 3 generator + discriminator steps
+ 2) 2 discriminator steps
+
+ Args:
+ train_steps: A `GANTrainSteps` tuple that determines how many generator
+ and discriminator training steps to take.
+
+ Returns:
+ A function that takes a GANTrainOps tuple and returns a list of hooks.
+ """
+ g_steps = train_steps.generator_train_steps
+ d_steps = train_steps.discriminator_train_steps
+ # Get the number of each type of step that should be run.
+ num_d_and_g_steps = min(g_steps, d_steps)
+ num_g_steps = g_steps - num_d_and_g_steps
+ num_d_steps = d_steps - num_d_and_g_steps
+
+ def get_hooks(train_ops):
+ g_op = train_ops.generator_train_op
+ d_op = train_ops.discriminator_train_op
+
+ joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
+ g_hook = RunTrainOpsHook(g_op, num_g_steps)
+ d_hook = RunTrainOpsHook(d_op, num_d_steps)
+
+ return [joint_hook, g_hook, d_hook]
+ return get_hooks
+
+
+# TODO(joelshor): This function currently returns the global step. Find a
+# good way for it to return the generator, discriminator, and final losses.
+def gan_train(
+ train_ops,
+ logdir,
+ get_hooks_fn=get_sequential_train_hooks(),
+ master='',
+ is_chief=True,
+ scaffold=None,
+ hooks=None,
+ chief_only_hooks=None,
+ save_checkpoint_secs=600,
+ save_summaries_steps=100,
+ config=None):
+ """A wrapper around `contrib.training.train` that uses GAN hooks.
+
+ Args:
+ train_ops: A GANTrainOps named tuple.
+ logdir: The directory where the graph and checkpoints are saved.
+ get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
+ of hooks.
+ master: The URL of the master.
+ is_chief: Specifies whether or not the training is being run by the primary
+ replica during replica training.
+    scaffold: A `tf.train.Scaffold` instance.
+ hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
+ training loop.
+ chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
+ inside the training loop for the chief trainer only.
+ save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
+ using a default checkpoint saver. If `save_checkpoint_secs` is set to
+ `None`, then the default checkpoint saver isn't used.
+ save_summaries_steps: The frequency, in number of global steps, that the
+ summaries are written to disk using a default summary saver. If
+ `save_summaries_steps` is set to `None`, then the default summary saver
+ isn't used.
+ config: An instance of `tf.ConfigProto`.
+
+ Returns:
+ Output of the call to `training.train`.
+ """
+ new_hooks = get_hooks_fn(train_ops)
+ if hooks is not None:
+ hooks = list(hooks) + list(new_hooks)
+ else:
+ hooks = new_hooks
+ return training.train(
+ train_ops.global_step_inc_op,
+ logdir,
+ master=master,
+ is_chief=is_chief,
+ scaffold=scaffold,
+ hooks=hooks,
+ chief_only_hooks=chief_only_hooks,
+ save_checkpoint_secs=save_checkpoint_secs,
+ save_summaries_steps=save_summaries_steps,
+ config=config)
+
+
+def get_sequential_train_steps(
+ train_steps=namedtuples.GANTrainSteps(1, 1)):
+ """Returns a thin wrapper around slim.learning.train_step, for GANs.
+
+  This function exists to support the Supervisor. For new code, please
+ use `MonitoredSession` and `get_sequential_train_hooks`.
+
+ Args:
+ train_steps: A `GANTrainSteps` tuple that determines how many generator
+ and discriminator training steps to take.
+
+ Returns:
+ A function that can be used for `train_step_fn` for GANs.
+ """
+
+ def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
+ """A thin wrapper around slim.learning.train_step, for GANs.
+
+ Args:
+      sess: A TensorFlow session.
+ train_ops: A GANTrainOps tuple of train ops to run.
+ global_step: The global step.
+ train_step_kwargs: Dictionary controlling `train_step` behavior.
+
+ Returns:
+      A scalar final loss and a bool indicating whether the train loop should
+      stop.
+ """
+ # Only run `should_stop` at the end, if required. Make a local copy of
+ # `train_step_kwargs`, if necessary, so as not to modify the caller's
+ # dictionary.
+ should_stop_op, train_kwargs = None, train_step_kwargs
+ if 'should_stop' in train_step_kwargs:
+ should_stop_op = train_step_kwargs['should_stop']
+ train_kwargs = train_step_kwargs.copy()
+ del train_kwargs['should_stop']
+
+ # Run generator training steps.
+ gen_loss = 0
+ for _ in range(train_steps.generator_train_steps):
+ cur_gen_loss, _ = slim_learning.train_step(
+ sess, train_ops.generator_train_op, global_step, train_kwargs)
+ gen_loss += cur_gen_loss
+
+ # Run discriminator training steps.
+ dis_loss = 0
+ for _ in range(train_steps.discriminator_train_steps):
+ cur_dis_loss, _ = slim_learning.train_step(
+ sess, train_ops.discriminator_train_op, global_step, train_kwargs)
+ dis_loss += cur_dis_loss
+
+ sess.run(train_ops.global_step_inc_op)
+
+ # Run the `should_stop` op after the global step has been incremented, so
+ # that the `should_stop` aligns with the proper `global_step` count.
+ if should_stop_op is not None:
+ should_stop = sess.run(should_stop_op)
+ else:
+ should_stop = False
+
+ return gen_loss + dis_loss, should_stop
+
+ return sequential_train_steps
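The `GANTrainSteps(3, 5)` walkthrough in the `get_joint_train_hooks` docstring above reduces to the min/remainder split computed inside that function; the arithmetic in isolation (plain Python, nothing TensorFlow-specific):

    g_steps, d_steps = 3, 5
    num_d_and_g_steps = min(g_steps, d_steps)  # 3 joint generator+discriminator calls
    num_g_steps = g_steps - num_d_and_g_steps  # 0 generator-only calls
    num_d_steps = d_steps - num_d_and_g_steps  # 2 discriminator-only calls
    # 3 + 0 + 2 = 5 session calls, versus 3 + 5 = 8 with the sequential hooks.
    assert num_d_and_g_steps + num_g_steps + num_d_steps == 5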
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
new file mode 100644
index 0000000000..83b763806c
--- /dev/null
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -0,0 +1,745 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for gan.python.train."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.framework.python.ops import variables as variables_lib
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import train
+from tensorflow.contrib.slim.python.slim import learning as slim_learning
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.distributions import categorical
+from tensorflow.python.platform import test
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import coordinator
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import sync_replicas_optimizer
+from tensorflow.python.training import training_util
+
+
+def generator_model(inputs):
+ return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
+
+
+class Generator(object):
+
+ def __call__(self, inputs):
+ return generator_model(inputs)
+
+
+def infogan_generator_model(inputs):
+ return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs[0]
+
+
+class InfoGANGenerator(object):
+
+ def __call__(self, inputs):
+ return infogan_generator_model(inputs)
+
+
+def discriminator_model(inputs, _):
+ return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
+
+
+class Discriminator(object):
+
+ def __call__(self, inputs, _):
+ return discriminator_model(inputs, _)
+
+
+def infogan_discriminator_model(inputs, _):
+ return (variable_scope.get_variable('dummy_d', initializer=2.0) * inputs,
+ [categorical.Categorical([1.0])])
+
+
+class InfoGANDiscriminator(object):
+
+ def __call__(self, inputs, _):
+ return infogan_discriminator_model(inputs, _)
+
+
+def acgan_discriminator_model(inputs, _, num_classes=10):
+ return (discriminator_model(inputs, _), array_ops.one_hot(
+ # TODO(haeusser): infer batch size from input
+ random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
+ num_classes))
+
+
+class ACGANDiscriminator(object):
+
+ def __call__(self, inputs, _, num_classes=10):
+ return (discriminator_model(inputs, _), array_ops.one_hot(
+ # TODO(haeusser): infer batch size from input
+ random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
+ num_classes))
+
+
+def get_gan_model():
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ pass
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+ return namedtuples.GANModel(
+ generator_inputs=None,
+ generated_data=None,
+ generator_variables=None,
+ generator_scope=gen_scope,
+ generator_fn=generator_model,
+ real_data=array_ops.ones([1, 2, 3]),
+ discriminator_real_outputs=array_ops.ones([1, 2, 3]),
+ discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=discriminator_model)
+
+
+def get_callable_gan_model():
+ ganmodel = get_gan_model()
+ return ganmodel._replace(
+ generator_fn=Generator(),
+ discriminator_fn=Discriminator())
+
+
+def create_gan_model():
+ return train.gan_model(
+ generator_model,
+ discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=random_ops.random_normal([1, 2]))
+
+
+def create_callable_gan_model():
+ return train.gan_model(
+ Generator(),
+ Discriminator(),
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=random_ops.random_normal([1, 2]))
+
+
+def get_infogan_model():
+ return namedtuples.InfoGANModel(
+ *get_gan_model(),
+ structured_generator_inputs=[constant_op.constant(0)],
+ predicted_distributions=[categorical.Categorical([1.0])])
+
+
+def get_callable_infogan_model():
+ return namedtuples.InfoGANModel(
+ *get_callable_gan_model(),
+ structured_generator_inputs=[constant_op.constant(0)],
+ predicted_distributions=[categorical.Categorical([1.0])])
+
+
+def create_infogan_model():
+ return train.infogan_model(
+ infogan_generator_model,
+ infogan_discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ unstructured_generator_inputs=[],
+ structured_generator_inputs=[random_ops.random_normal([1, 2])])
+
+
+def create_callable_infogan_model():
+ return train.infogan_model(
+ InfoGANGenerator(),
+ InfoGANDiscriminator(),
+ real_data=array_ops.zeros([1, 2]),
+ unstructured_generator_inputs=[],
+ structured_generator_inputs=[random_ops.random_normal([1, 2])])
+
+
+def get_acgan_model():
+ return namedtuples.ACGANModel(
+ *get_gan_model(),
+ one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
+ discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
+ discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
+
+
+def get_callable_acgan_model():
+ return namedtuples.ACGANModel(
+ *get_callable_gan_model(),
+ one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
+ discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
+ discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
+
+
+def create_acgan_model():
+ return train.acgan_model(
+ generator_model,
+ acgan_discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=random_ops.random_normal([1, 2]),
+ one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
+
+
+def create_callable_acgan_model():
+ return train.acgan_model(
+ Generator(),
+ ACGANDiscriminator(),
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=random_ops.random_normal([1, 2]),
+ one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
+
+
+def get_sync_optimizer():
+ return sync_replicas_optimizer.SyncReplicasOptimizer(
+ gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
+ replicas_to_aggregate=1)
+
+
+class GANModelTest(test.TestCase):
+ """Tests for `gan_model`."""
+
+ def _test_output_type_helper(self, create_fn, tuple_type):
+ self.assertTrue(isinstance(create_fn(), tuple_type))
+
+ def test_output_type_gan(self):
+ self._test_output_type_helper(get_gan_model, namedtuples.GANModel)
+
+ def test_output_type_callable_gan(self):
+ self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel)
+
+ def test_output_type_infogan(self):
+ self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel)
+
+ def test_output_type_callable_infogan(self):
+ self._test_output_type_helper(
+ get_callable_infogan_model, namedtuples.InfoGANModel)
+
+ def test_output_type_acgan(self):
+ self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel)
+
+ def test_output_type_callable_acgan(self):
+ self._test_output_type_helper(
+ get_callable_acgan_model, namedtuples.ACGANModel)
+
+ def test_no_shape_check(self):
+ def dummy_generator_model(_):
+ return (None, None)
+ def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument
+ return 1
+ with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
+ train.gan_model(
+ dummy_generator_model,
+ dummy_discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=array_ops.zeros([1]),
+ check_shapes=True)
+ train.gan_model(
+ dummy_generator_model,
+ dummy_discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=array_ops.zeros([1]),
+ check_shapes=False)
+
+
+class GANLossTest(test.TestCase):
+ """Tests for `gan_loss`."""
+
+ # Test output type.
+ def _test_output_type_helper(self, get_gan_model_fn):
+ loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
+ self.assertTrue(isinstance(loss, namedtuples.GANLoss))
+ self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
+
+ def test_output_type_gan(self):
+ self._test_output_type_helper(get_gan_model)
+
+ def test_output_type_callable_gan(self):
+ self._test_output_type_helper(get_callable_gan_model)
+
+ def test_output_type_infogan(self):
+ self._test_output_type_helper(get_infogan_model)
+
+ def test_output_type_callable_infogan(self):
+ self._test_output_type_helper(get_callable_infogan_model)
+
+ def test_output_type_acgan(self):
+ self._test_output_type_helper(get_acgan_model)
+
+ def test_output_type_callable_acgan(self):
+ self._test_output_type_helper(get_callable_acgan_model)
+
+ # Test gradient penalty option.
+ def _test_grad_penalty_helper(self, create_gan_model_fn):
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+ loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
+ self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
+
+ # Check values.
+ with self.test_session(use_gpu=True) as sess:
+ variables.global_variables_initializer().run()
+ loss_gen_np, loss_gen_gp_np = sess.run(
+ [loss.generator_loss, loss_gp.generator_loss])
+ loss_dis_np, loss_dis_gp_np = sess.run(
+ [loss.discriminator_loss, loss_gp.discriminator_loss])
+
+ self.assertEqual(loss_gen_np, loss_gen_gp_np)
+ self.assertTrue(loss_dis_np < loss_dis_gp_np)
+
+ def test_grad_penalty_gan(self):
+ self._test_grad_penalty_helper(create_gan_model)
+
+ def test_grad_penalty_callable_gan(self):
+ self._test_grad_penalty_helper(create_callable_gan_model)
+
+ def test_grad_penalty_infogan(self):
+ self._test_grad_penalty_helper(create_infogan_model)
+
+ def test_grad_penalty_callable_infogan(self):
+ self._test_grad_penalty_helper(create_callable_infogan_model)
+
+ def test_grad_penalty_acgan(self):
+ self._test_grad_penalty_helper(create_acgan_model)
+
+ def test_grad_penalty_callable_acgan(self):
+ self._test_grad_penalty_helper(create_callable_acgan_model)
+
+ # Test mutual information penalty option.
+ def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
+ train.gan_loss(create_gan_model_fn(),
+ mutual_information_penalty_weight=constant_op.constant(1.0))
+
+ def test_mutual_info_penalty_infogan(self):
+ self._test_mutual_info_penalty_helper(get_infogan_model)
+
+ def test_mutual_info_penalty_callable_infogan(self):
+ self._test_mutual_info_penalty_helper(get_callable_infogan_model)
+
+ # Test regularization loss.
+ def _test_regularization_helper(self, get_gan_model_fn):
+ # Evaluate losses without regularization.
+ no_reg_loss = train.gan_loss(get_gan_model_fn())
+ with self.test_session(use_gpu=True):
+ no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
+ no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()
+
+ with ops.name_scope(get_gan_model_fn().generator_scope.name):
+ ops.add_to_collection(
+ ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
+ with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
+ ops.add_to_collection(
+ ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
+
+ # Check that losses now include the correct regularization values.
+ reg_loss = train.gan_loss(get_gan_model_fn())
+ with self.test_session(use_gpu=True):
+ reg_loss_gen_np = reg_loss.generator_loss.eval()
+ reg_loss_dis_np = reg_loss.discriminator_loss.eval()
+
+    self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
+    self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)
+
+ def test_regularization_gan(self):
+ self._test_regularization_helper(get_gan_model)
+
+ def test_regularization_callable_gan(self):
+ self._test_regularization_helper(get_callable_gan_model)
+
+ def test_regularization_infogan(self):
+ self._test_regularization_helper(get_infogan_model)
+
+ def test_regularization_callable_infogan(self):
+ self._test_regularization_helper(get_callable_infogan_model)
+
+ def test_regularization_acgan(self):
+ self._test_regularization_helper(get_acgan_model)
+
+ def test_regularization_callable_acgan(self):
+ self._test_regularization_helper(get_callable_acgan_model)
+
+ # Test that ACGan models work.
+ def _test_acgan_helper(self, create_gan_model_fn):
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+ loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
+ loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
+ self.assertTrue(isinstance(loss, namedtuples.GANLoss))
+ self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
+ self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))
+
+ # Check values.
+ with self.test_session(use_gpu=True) as sess:
+ variables.global_variables_initializer().run()
+ loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
+ [loss.generator_loss,
+ loss_ac_gen.generator_loss,
+ loss_ac_dis.generator_loss])
+ loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
+ [loss.discriminator_loss,
+ loss_ac_gen.discriminator_loss,
+ loss_ac_dis.discriminator_loss])
+
+ self.assertTrue(loss_gen_np < loss_dis_np)
+ self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
+ self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
+ self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
+ self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
+
+ def test_acgan(self):
+ self._test_acgan_helper(create_acgan_model)
+
+ def test_callable_acgan(self):
+ self._test_acgan_helper(create_callable_acgan_model)
+
+ def test_doesnt_crash_when_in_nested_scope(self):
+ with variable_scope.variable_scope('outer_scope'):
+ gan_model = train.gan_model(
+ generator_model,
+ discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=random_ops.random_normal([1, 2]))
+
+ # This should work inside a scope.
+ train.gan_loss(gan_model, gradient_penalty_weight=1.0)
+
+ # This should also work outside a scope.
+ train.gan_loss(gan_model, gradient_penalty_weight=1.0)
+
+
+class GANTrainOpsTest(test.TestCase):
+ """Tests for `gan_train_ops`."""
+
+ def _test_output_type_helper(self, create_gan_model_fn):
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+
+ g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ train_ops = train.gan_train_ops(
+ model,
+ loss,
+ g_opt,
+ d_opt,
+ summarize_gradients=True,
+ colocate_gradients_with_ops=True)
+
+ self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
+
+ def test_output_type_gan(self):
+ self._test_output_type_helper(create_gan_model)
+
+ def test_output_type_callable_gan(self):
+ self._test_output_type_helper(create_callable_gan_model)
+
+ def test_output_type_infogan(self):
+ self._test_output_type_helper(create_infogan_model)
+
+ def test_output_type_callable_infogan(self):
+ self._test_output_type_helper(create_callable_infogan_model)
+
+ def test_output_type_acgan(self):
+ self._test_output_type_helper(create_acgan_model)
+
+ def test_output_type_callable_acgan(self):
+ self._test_output_type_helper(create_callable_acgan_model)
+
+ # TODO(joelshor): Add a test to check that custom update op is run.
+ def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+
+ # Add generator and discriminator update ops.
+ with variable_scope.variable_scope(model.generator_scope):
+ gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
+ gen_update_op = gen_update_count.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
+ with variable_scope.variable_scope(model.discriminator_scope):
+ dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
+ dis_update_op = dis_update_count.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)
+
+ # Add an update op outside the generator and discriminator scopes.
+ if provide_update_ops:
+ kwargs = {'update_ops':
+ [constant_op.constant(1.0), gen_update_op, dis_update_op]}
+ else:
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
+ kwargs = {}
+
+ g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+
+ with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
+ train.gan_train_ops(model, loss, g_opt, d_opt,
+ check_for_unused_update_ops=True, **kwargs)
+ train_ops = train.gan_train_ops(
+ model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertEqual(0, gen_update_count.eval())
+ self.assertEqual(0, dis_update_count.eval())
+
+ train_ops.generator_train_op.eval()
+ self.assertEqual(1, gen_update_count.eval())
+ self.assertEqual(0, dis_update_count.eval())
+
+ train_ops.discriminator_train_op.eval()
+ self.assertEqual(1, gen_update_count.eval())
+ self.assertEqual(1, dis_update_count.eval())
+
+ def test_unused_update_ops_gan(self):
+ self._test_unused_update_ops(create_gan_model, False)
+
+ def test_unused_update_ops_gan_provideupdates(self):
+ self._test_unused_update_ops(create_gan_model, True)
+
+ def test_unused_update_ops_callable_gan(self):
+ self._test_unused_update_ops(create_callable_gan_model, False)
+
+ def test_unused_update_ops_callable_gan_provideupdates(self):
+ self._test_unused_update_ops(create_callable_gan_model, True)
+
+ def test_unused_update_ops_infogan(self):
+ self._test_unused_update_ops(create_infogan_model, False)
+
+ def test_unused_update_ops_infogan_provideupdates(self):
+ self._test_unused_update_ops(create_infogan_model, True)
+
+ def test_unused_update_ops_callable_infogan(self):
+ self._test_unused_update_ops(create_callable_infogan_model, False)
+
+ def test_unused_update_ops_callable_infogan_provideupdates(self):
+ self._test_unused_update_ops(create_callable_infogan_model, True)
+
+ def test_unused_update_ops_acgan(self):
+ self._test_unused_update_ops(create_acgan_model, False)
+
+ def test_unused_update_ops_acgan_provideupdates(self):
+ self._test_unused_update_ops(create_acgan_model, True)
+
+ def test_unused_update_ops_callable_acgan(self):
+ self._test_unused_update_ops(create_callable_acgan_model, False)
+
+ def test_unused_update_ops_callable_acgan_provideupdates(self):
+ self._test_unused_update_ops(create_callable_acgan_model, True)
+
+ def _test_sync_replicas_helper(self, create_gan_model_fn):
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+ num_trainable_vars = len(variables_lib.get_trainable_variables())
+
+ g_opt = get_sync_optimizer()
+ d_opt = get_sync_optimizer()
+ train_ops = train.gan_train_ops(
+ model,
+ loss,
+ generator_optimizer=g_opt,
+ discriminator_optimizer=d_opt)
+    self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
+ # No new trainable variables should have been added.
+ self.assertEqual(num_trainable_vars,
+ len(variables_lib.get_trainable_variables()))
+
+ g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
+ d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)
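+    # The token queues must be primed with initial tokens before the sync
+    # optimizers let a worker proceed.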
+
+    # Check that the train ops run and leave the global step untouched.
+ global_step = training_util.get_or_create_global_step()
+ with self.test_session(use_gpu=True) as sess:
+ variables.global_variables_initializer().run()
+ variables.local_variables_initializer().run()
+
+ g_opt.chief_init_op.run()
+ d_opt.chief_init_op.run()
+
+ gstep_before = global_step.eval()
+
+      # Start the queue runners required by SyncReplicasOptimizer.
+ coord = coordinator.Coordinator()
+ g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
+ d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)
+
+ g_sync_init_op.run()
+ d_sync_init_op.run()
+
+ train_ops.generator_train_op.eval()
+ # Check that global step wasn't incremented.
+ self.assertEqual(gstep_before, global_step.eval())
+
+ train_ops.discriminator_train_op.eval()
+ # Check that global step wasn't incremented.
+ self.assertEqual(gstep_before, global_step.eval())
+
+ coord.request_stop()
+ coord.join(g_threads + d_threads)
+
+ def test_sync_replicas_gan(self):
+ self._test_sync_replicas_helper(create_gan_model)
+
+ def test_sync_replicas_callable_gan(self):
+ self._test_sync_replicas_helper(create_callable_gan_model)
+
+ def test_sync_replicas_infogan(self):
+ self._test_sync_replicas_helper(create_infogan_model)
+
+ def test_sync_replicas_callable_infogan(self):
+ self._test_sync_replicas_helper(create_callable_infogan_model)
+
+ def test_sync_replicas_acgan(self):
+ self._test_sync_replicas_helper(create_acgan_model)
+
+ def test_sync_replicas_callable_acgan(self):
+ self._test_sync_replicas_helper(create_callable_acgan_model)
+
+
+class GANTrainTest(test.TestCase):
+ """Tests for `gan_train`."""
+
+ def _gan_train_ops(self, generator_add, discriminator_add):
+ step = training_util.create_global_step()
+    # Increment the global step every time a train op is run so we can count
+    # how many times each op runs.
+ # NOTE: `use_locking=True` is required to avoid race conditions with
+ # joint training.
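+    # The final global step then decodes how many times each op ran:
+    # 1 + generator_add * (#gen steps) + discriminator_add * (#dis steps).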
+ train_ops = namedtuples.GANTrainOps(
+ generator_train_op=step.assign_add(generator_add, use_locking=True),
+ discriminator_train_op=step.assign_add(discriminator_add,
+ use_locking=True),
+ global_step_inc_op=step.assign_add(1))
+ return train_ops
+
+ def _test_run_helper(self, create_gan_model_fn):
+ random_seed.set_random_seed(1234)
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+
+ g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
+
+ final_step = train.gan_train(
+ train_ops,
+ logdir='',
+ hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
+ self.assertTrue(np.isscalar(final_step))
+ self.assertEqual(2, final_step)
+
+ def test_run_gan(self):
+ self._test_run_helper(create_gan_model)
+
+ def test_run_callable_gan(self):
+ self._test_run_helper(create_callable_gan_model)
+
+ def test_run_infogan(self):
+ self._test_run_helper(create_infogan_model)
+
+ def test_run_callable_infogan(self):
+ self._test_run_helper(create_callable_infogan_model)
+
+ def test_run_acgan(self):
+ self._test_run_helper(create_acgan_model)
+
+ def test_run_callable_acgan(self):
+ self._test_run_helper(create_callable_acgan_model)
+
+ # Test multiple train steps.
+ def _test_multiple_steps_helper(self, get_hooks_fn_fn):
+ train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
+ train_steps = namedtuples.GANTrainSteps(
+ generator_train_steps=3,
+ discriminator_train_steps=4)
+ final_step = train.gan_train(
+ train_ops,
+ get_hooks_fn=get_hooks_fn_fn(train_steps),
+ logdir='',
+ hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])
+
+ self.assertTrue(np.isscalar(final_step))
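+    # 1 (global_step_inc_op) + 3 generator steps * 10 +
+    # 4 discriminator steps * 100 = 431.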
+ self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
+
+ def test_multiple_steps_seq_train_steps(self):
+ self._test_multiple_steps_helper(train.get_sequential_train_hooks)
+
+ def test_multiple_steps_efficient_seq_train_steps(self):
+ self._test_multiple_steps_helper(train.get_joint_train_hooks)
+
+ def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
+ step = training_util.create_global_step()
+ train_ops = namedtuples.GANTrainOps(
+ generator_train_op=constant_op.constant(3.0),
+ discriminator_train_op=constant_op.constant(2.0),
+ global_step_inc_op=step.assign_add(1))
+ train_steps = namedtuples.GANTrainSteps(
+ generator_train_steps=3,
+ discriminator_train_steps=4)
+
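+    # `slim_learning.train` accepts the GANTrainOps tuple as `train_op` here
+    # because the custom `train_step_fn` (not the default step) consumes it.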
+ final_loss = slim_learning.train(
+ train_op=train_ops,
+ logdir='',
+ global_step=step,
+ number_of_steps=1,
+ train_step_fn=train.get_sequential_train_steps(train_steps))
+ self.assertTrue(np.isscalar(final_loss))
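+    # 3 generator steps * 3.0 + 4 discriminator steps * 2.0 = 17.0.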
+ self.assertEqual(17.0, final_loss)
+
+
+class PatchGANTest(test.TestCase):
+ """Tests that functions work on PatchGAN style output."""
+
+ def _test_patchgan_helper(self, create_gan_model_fn):
+ """Ensure that patch-based discriminators work end-to-end."""
+ random_seed.set_random_seed(1234)
+ model = create_gan_model_fn()
+ loss = train.gan_loss(model)
+
+ g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
+
+ final_step = train.gan_train(
+ train_ops,
+ logdir='',
+ hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
+ self.assertTrue(np.isscalar(final_step))
+ self.assertEqual(2, final_step)
+
+ def test_patchgan_gan(self):
+ self._test_patchgan_helper(create_gan_model)
+
+ def test_patchgan_callable_gan(self):
+ self._test_patchgan_helper(create_callable_gan_model)
+
+ def test_patchgan_infogan(self):
+ self._test_patchgan_helper(create_infogan_model)
+
+ def test_patchgan_callable_infogan(self):
+ self._test_patchgan_helper(create_callable_infogan_model)
+
+ def test_patchgan_acgan(self):
+ self._test_patchgan_helper(create_acgan_model)
+
+ def test_patchgan_callable_acgan(self):
+ self._test_patchgan_helper(create_callable_acgan_model)
+
+
+if __name__ == '__main__':
+ test.main()