aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/train.py')
-rw-r--r--tensorflow/contrib/gan/python/train.py238
1 files changed, 182 insertions, 56 deletions
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 6fa43059f3..df603d1f18 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -36,10 +36,12 @@ from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
@@ -47,12 +49,12 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-
__all__ = [
'gan_model',
'infogan_model',
'acgan_model',
'cyclegan_model',
+ 'stargan_model',
'gan_loss',
'cyclegan_loss',
'gan_train_ops',
@@ -123,16 +125,9 @@ def gan_model(
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
return namedtuples.GANModel(
- generator_inputs,
- generated_data,
- generator_variables,
- gen_scope,
- generator_fn,
- real_data,
- discriminator_real_outputs,
- discriminator_gen_outputs,
- discriminator_variables,
- dis_scope,
+ generator_inputs, generated_data, generator_variables, gen_scope,
+ generator_fn, real_data, discriminator_real_outputs,
+ discriminator_gen_outputs, discriminator_variables, dis_scope,
discriminator_fn)
@@ -201,8 +196,7 @@ def infogan_model(
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
- discriminator_variables = variables_lib.get_trainable_variables(
- disc_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(disc_scope)
return namedtuples.InfoGANModel(
generator_inputs,
@@ -279,12 +273,12 @@ def acgan_model(
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
- with ops.name_scope(dis_scope.name+'/generated/'):
+ with ops.name_scope(dis_scope.name + '/generated/'):
(discriminator_gen_outputs, discriminator_gen_classification_logits
) = _validate_acgan_discriminator_outputs(
discriminator_fn(generated_data, generator_inputs))
with variable_scope.variable_scope(dis_scope, reuse=True):
- with ops.name_scope(dis_scope.name+'/real/'):
+ with ops.name_scope(dis_scope.name + '/real/'):
real_data = ops.convert_to_tensor(real_data)
(discriminator_real_outputs, discriminator_real_classification_logits
) = _validate_acgan_discriminator_outputs(
@@ -297,8 +291,7 @@ def acgan_model(
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
- discriminator_variables = variables_lib.get_trainable_variables(
- dis_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
return namedtuples.ACGANModel(
generator_inputs, generated_data, generator_variables, gen_scope,
@@ -379,6 +372,108 @@ def cyclegan_model(
reconstructed_y)
+def stargan_model(generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ generator_scope='Generator',
+ discriminator_scope='Discriminator'):
+  """Returns a StarGAN model's outputs and variables.
+
+ See https://arxiv.org/abs/1711.09020 for more details.
+
+ Args:
+    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
+      and returns `generated_data` as the transformed version of `inputs` based
+      on the `targets`. `inputs` has shape (n, h, w, c), `targets` has shape
+      (n, num_domains), and `generated_data` has the same shape as `inputs`.
+ discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
+ inputs and returns a tuple (`source_prediction`, `domain_prediction`).
+ `source_prediction` represents the source(real/generated) prediction by
+ the discriminator, and `domain_prediction` represents the domain
+ prediction/classification by the discriminator. `source_prediction` has
+ shape (n) and `domain_prediction` has shape (n, num_domains).
+    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
+      the real input images.
+    input_data_domain_label: Tensor or a list of tensors of shape (batch_size,
+      num_domains) representing the domain label associated with the real
+      images.
+ generator_scope: Optional generator variable scope. Useful if you want to
+ reuse a subgraph that has already been created.
+ discriminator_scope: Optional discriminator variable scope. Useful if you
+ want to reuse a subgraph that has already been created.
+
+ Returns:
+    StarGANModel namedtuple that contains the tensors needed to compute the
+    loss.
+
+ Raises:
+    ValueError: If the shape of `input_data_domain_label` is not rank 2 or not
+      fully defined in every dimension.
+ """
+
+ # Convert to tensor.
+ input_data = _convert_tensor_or_l_or_d(input_data)
+ input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)
+
+ # Convert list of tensor to a single tensor if applicable.
+ if isinstance(input_data, (list, tuple)):
+ input_data = array_ops.concat(
+ [ops.convert_to_tensor(x) for x in input_data], 0)
+ if isinstance(input_data_domain_label, (list, tuple)):
+ input_data_domain_label = array_ops.concat(
+ [ops.convert_to_tensor(x) for x in input_data_domain_label], 0)
+
+ # Get batch_size, num_domains from the labels.
+ input_data_domain_label.shape.assert_has_rank(2)
+ input_data_domain_label.shape.assert_is_fully_defined()
+ batch_size, num_domains = input_data_domain_label.shape.as_list()
+
+ # Transform input_data to random target domains.
+ with variable_scope.variable_scope(generator_scope) as generator_scope:
+ generated_data_domain_target = _generate_stargan_random_domain_target(
+ batch_size, num_domains)
+ generated_data = generator_fn(input_data, generated_data_domain_target)
+
+ # Transform generated_data back to the original input_data domain.
+ with variable_scope.variable_scope(generator_scope, reuse=True):
+ reconstructed_data = generator_fn(generated_data, input_data_domain_label)
+
+ # Predict source and domain for the generated_data using the discriminator.
+ with variable_scope.variable_scope(
+ discriminator_scope) as discriminator_scope:
+ disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
+ generated_data, num_domains)
+
+ # Predict source and domain for the input_data using the discriminator.
+ with variable_scope.variable_scope(discriminator_scope, reuse=True):
+ disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
+ input_data, num_domains)
+
+ # Collect trainable variables from the neural networks.
+ generator_variables = variables_lib.get_trainable_variables(generator_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(
+ discriminator_scope)
+
+ # Create the StarGANModel namedtuple.
+ return namedtuples.StarGANModel(
+ input_data=input_data,
+ input_data_domain_label=input_data_domain_label,
+ generated_data=generated_data,
+ generated_data_domain_target=generated_data_domain_target,
+ reconstructed_data=reconstructed_data,
+ discriminator_input_data_source_predication=disc_input_data_source_pred,
+ discriminator_generated_data_source_predication=disc_gen_data_source_pred,
+ discriminator_input_data_domain_predication=disc_input_data_domain_pred,
+ discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
+ generator_variables=generator_variables,
+ generator_scope=generator_scope,
+ generator_fn=generator_fn,
+ discriminator_variables=discriminator_variables,
+ discriminator_scope=discriminator_scope,
+ discriminator_fn=discriminator_fn)
+
+
def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
if isinstance(aux_loss_weight, ops.Tensor):
aux_loss_weight.shape.assert_is_compatible_with([])
@@ -419,33 +514,42 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn):
Raises:
ValueError: If tensor pool does not support the `model`.
"""
- if tensor_pool_fn is None:
- return model
-
- pooled_generated_data, pooled_generator_inputs = tensor_pool_fn(
- (model.generated_data, model.generator_inputs))
-
if isinstance(model, namedtuples.GANModel):
+ pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
+ (model.generator_inputs, model.generated_data))
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
pooled_generator_inputs)
- return model._replace(discriminator_gen_outputs=dis_gen_outputs)
+ return model._replace(
+ generator_inputs=pooled_generator_inputs,
+ generated_data=pooled_generated_data,
+ discriminator_gen_outputs=dis_gen_outputs)
elif isinstance(model, namedtuples.ACGANModel):
+ pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
+ (model.generator_inputs, model.generated_data))
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
- (dis_pooled_gen_outputs,
- dis_pooled_gen_classification_logits) = model.discriminator_fn(
+ (pooled_discriminator_gen_outputs,
+ pooled_discriminator_gen_classification_logits) = model.discriminator_fn(
pooled_generated_data, pooled_generator_inputs)
return model._replace(
- discriminator_gen_outputs=dis_pooled_gen_outputs,
+ generator_inputs=pooled_generator_inputs,
+ generated_data=pooled_generated_data,
+ discriminator_gen_outputs=pooled_discriminator_gen_outputs,
discriminator_gen_classification_logits=
- dis_pooled_gen_classification_logits)
+ pooled_discriminator_gen_classification_logits)
elif isinstance(model, namedtuples.InfoGANModel):
+ pooled_generator_inputs, pooled_generated_data, pooled_structured_input = (
+ tensor_pool_fn((model.generator_inputs, model.generated_data,
+ model.structured_generator_inputs)))
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
- (dis_pooled_gen_outputs,
+ (pooled_discriminator_gen_outputs,
pooled_predicted_distributions) = model.discriminator_and_aux_fn(
pooled_generated_data, pooled_generator_inputs)
return model._replace(
- discriminator_gen_outputs=dis_pooled_gen_outputs,
+ generator_inputs=pooled_generator_inputs,
+ generated_data=pooled_generated_data,
+ structured_generator_inputs=pooled_structured_input,
+ discriminator_gen_outputs=pooled_discriminator_gen_outputs,
predicted_distributions=pooled_predicted_distributions)
else:
raise ValueError('Tensor pool does not support `model`: %s.' % type(model))
@@ -512,8 +616,8 @@ def gan_loss(
`model` isn't an `InfoGANModel`.
"""
# Validate arguments.
- gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight,
- 'gradient_penalty_weight')
+ gradient_penalty_weight = _validate_aux_loss_weight(
+ gradient_penalty_weight, 'gradient_penalty_weight')
mutual_information_penalty_weight = _validate_aux_loss_weight(
mutual_information_penalty_weight, 'infogan_weight')
aux_cond_generator_weight = _validate_aux_loss_weight(
@@ -537,33 +641,38 @@ def gan_loss(
'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
type(model))
+ # Optionally create pooled model.
+ pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if
+ tensor_pool_fn else model)
+
# Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
- dis_loss = discriminator_loss_fn(
- _tensor_pool_adjusted_model(model, tensor_pool_fn),
- add_summaries=add_summaries)
+ dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
- model,
+ pooled_model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
- info_loss = tfgan_losses.mutual_information_penalty(
+ gen_info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries)
- dis_loss += mutual_information_penalty_weight * info_loss
- gen_loss += mutual_information_penalty_weight * info_loss
+ dis_info_loss = (gen_info_loss if tensor_pool_fn is None else
+ tfgan_losses.mutual_information_penalty(
+ pooled_model, add_summaries=add_summaries))
+ gen_loss += mutual_information_penalty_weight * gen_info_loss
+ dis_loss += mutual_information_penalty_weight * dis_info_loss
if _use_aux_loss(aux_cond_generator_weight):
ac_gen_loss = tfgan_losses.acgan_generator_loss(
model, add_summaries=add_summaries)
gen_loss += aux_cond_generator_weight * ac_gen_loss
if _use_aux_loss(aux_cond_discriminator_weight):
ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
- model, add_summaries=add_summaries)
+ pooled_model, add_summaries=add_summaries)
dis_loss += aux_cond_discriminator_weight * ac_disc_loss
# Gathers auxiliary losses.
if model.generator_scope:
@@ -631,8 +740,8 @@ def cyclegan_loss(
generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
**kwargs)
- return partial_loss._replace(
- generator_loss=partial_loss.generator_loss + aux_loss)
+ return partial_loss._replace(generator_loss=partial_loss.generator_loss +
+ aux_loss)
with ops.name_scope('cyclegan_loss_x2y'):
loss_x2y = _partial_loss(model.model_x2y)
@@ -822,12 +931,14 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
Returns:
A function that takes a GANTrainOps tuple and returns a list of hooks.
"""
+
def get_hooks(train_ops):
generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
train_steps.generator_train_steps)
discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
train_steps.discriminator_train_steps)
return [generator_hook, discriminator_hook]
+
return get_hooks
@@ -881,23 +992,23 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
d_hook = RunTrainOpsHook(d_op, num_d_steps)
return [joint_hook, g_hook, d_hook]
+
return get_hooks
# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
-def gan_train(
- train_ops,
- logdir,
- get_hooks_fn=get_sequential_train_hooks(),
- master='',
- is_chief=True,
- scaffold=None,
- hooks=None,
- chief_only_hooks=None,
- save_checkpoint_secs=600,
- save_summaries_steps=100,
- config=None):
+def gan_train(train_ops,
+ logdir,
+ get_hooks_fn=get_sequential_train_hooks(),
+ master='',
+ is_chief=True,
+ scaffold=None,
+ hooks=None,
+ chief_only_hooks=None,
+ save_checkpoint_secs=600,
+ save_summaries_steps=100,
+ config=None):
"""A wrapper around `contrib.training.train` that uses GAN hooks.
Args:
@@ -943,8 +1054,7 @@ def gan_train(
config=config)
-def get_sequential_train_steps(
- train_steps=namedtuples.GANTrainSteps(1, 1)):
+def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a thin wrapper around slim.learning.train_step, for GANs.
This function is to provide support for the Supervisor. For new code, please
@@ -1042,3 +1152,19 @@ def _validate_acgan_discriminator_outputs(discriminator_output):
'A discriminator function for ACGAN must output a tuple '
'consisting of (discrimination logits, classification logits).')
return a, b
+
+
+def _generate_stargan_random_domain_target(batch_size, num_domains):
+ """Generate random domain label.
+
+ Args:
+    batch_size: (int) Number of random domain labels to generate.
+    num_domains: (int) Number of domains represented by the label.
+
+ Returns:
+    Tensor of shape (batch_size, num_domains) representing random one-hot
+ """
+ domain_idx = random_ops.random_uniform(
+ [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)
+
+ return array_ops.one_hot(domain_idx, num_domains)