aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Wesley Qian <wwq@google.com>2018-07-17 09:30:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 09:34:15 -0700
commit97b1ef3ee8432d9a3bf664d367377028e95e0e1f (patch)
tree6cf2a88ed7c17aa5acc18637b2551d627d0e8edd /tensorflow/contrib/gan
parentfc26a5829c668ea49187a989e6b9657b6b8b1f02 (diff)
Add StarGAN model for TFGAN.
- Defined namedtuple for StarGAN model. - Function for StarGAN model creation. - Test for StarGAN model creation. - Fix small lint issue in train.py. PiperOrigin-RevId: 204923505
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD5
-rw-r--r--tensorflow/contrib/gan/python/namedtuples.py50
-rw-r--r--tensorflow/contrib/gan/python/train.py182
-rw-r--r--tensorflow/contrib/gan/python/train_test.py259
4 files changed, 412 insertions, 84 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 10a8796bcb..c8c2af49d4 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -42,8 +42,10 @@ py_library(
"//tensorflow/contrib/training:training_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:training",
"//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
@@ -58,17 +60,18 @@ py_test(
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
- ":features",
":namedtuples",
":random_tensor_pool",
":train",
"//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/slim:learning",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index 25cfeafeec..a462b68e28 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -25,12 +25,12 @@ from __future__ import print_function
import collections
-
__all__ = [
'GANModel',
'InfoGANModel',
'ACGANModel',
'CycleGANModel',
+ 'StarGANModel',
'GANLoss',
'CycleGANLoss',
'GANTrainOps',
@@ -136,6 +136,54 @@ class CycleGANModel(
"""
+class StarGANModel(
+ collections.namedtuple('StarGANModel', (
+ 'input_data',
+ 'input_data_domain_label',
+ 'generated_data',
+ 'generated_data_domain_target',
+ 'reconstructed_data',
+ 'discriminator_input_data_source_predication',
+ 'discriminator_generated_data_source_predication',
+ 'discriminator_input_data_domain_predication',
+ 'discriminator_generated_data_domain_predication',
+ 'generator_variables',
+ 'generator_scope',
+ 'generator_fn',
+ 'discriminator_variables',
+ 'discriminator_scope',
+ 'discriminator_fn',
+ ))):
+ """A StarGANModel contains all the pieces needed for StarGAN training.
+
+ Args:
+ input_data: The real images that need to be transferred by the generator.
+ input_data_domain_label: The real domain labels associated with the real
+ images.
+ generated_data: The generated images produced by the generator. It has the
+ same shape as the input_data.
+ generated_data_domain_target: The target domain that the generated images
+ belong to. It has the same shape as the input_data_domain_label.
+ reconstructed_data: The reconstructed images produced by the G(enerator).
+ reconstructed_data = G(G(input_data, generated_data_domain_target),
+ input_data_domain_label).
+ discriminator_input_data_source_predication: The discriminator's output for
+ predicting the source (real/generated) of input_data.
+ discriminator_generated_data_source_predication: The discriminator's output
+ for predicting the source (real/generated) of generated_data.
+ discriminator_input_data_domain_predication: The discriminator's output for
+ predicting the domain_label for the input_data.
+ discriminator_generated_data_domain_predication: The discriminator's output
+ for predicting the domain_target for the generated_data.
+ generator_variables: A list of all generator variables.
+ generator_scope: Variable scope all generator variables live in.
+ generator_fn: The generator function.
+ discriminator_variables: A list of all discriminator variables.
+ discriminator_scope: Variable scope all discriminator variables live in.
+ discriminator_fn: The discriminator function.
+ """
+
+
class GANLoss(
collections.namedtuple('GANLoss', (
'generator_loss',
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 6fa43059f3..49d9327333 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -36,10 +36,12 @@ from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
@@ -47,12 +49,12 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
-
__all__ = [
'gan_model',
'infogan_model',
'acgan_model',
'cyclegan_model',
+ 'stargan_model',
'gan_loss',
'cyclegan_loss',
'gan_train_ops',
@@ -123,16 +125,9 @@ def gan_model(
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
return namedtuples.GANModel(
- generator_inputs,
- generated_data,
- generator_variables,
- gen_scope,
- generator_fn,
- real_data,
- discriminator_real_outputs,
- discriminator_gen_outputs,
- discriminator_variables,
- dis_scope,
+ generator_inputs, generated_data, generator_variables, gen_scope,
+ generator_fn, real_data, discriminator_real_outputs,
+ discriminator_gen_outputs, discriminator_variables, dis_scope,
discriminator_fn)
@@ -201,8 +196,7 @@ def infogan_model(
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
- discriminator_variables = variables_lib.get_trainable_variables(
- disc_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(disc_scope)
return namedtuples.InfoGANModel(
generator_inputs,
@@ -279,12 +273,12 @@ def acgan_model(
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
- with ops.name_scope(dis_scope.name+'/generated/'):
+ with ops.name_scope(dis_scope.name + '/generated/'):
(discriminator_gen_outputs, discriminator_gen_classification_logits
) = _validate_acgan_discriminator_outputs(
discriminator_fn(generated_data, generator_inputs))
with variable_scope.variable_scope(dis_scope, reuse=True):
- with ops.name_scope(dis_scope.name+'/real/'):
+ with ops.name_scope(dis_scope.name + '/real/'):
real_data = ops.convert_to_tensor(real_data)
(discriminator_real_outputs, discriminator_real_classification_logits
) = _validate_acgan_discriminator_outputs(
@@ -297,8 +291,7 @@ def acgan_model(
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
- discriminator_variables = variables_lib.get_trainable_variables(
- dis_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
return namedtuples.ACGANModel(
generator_inputs, generated_data, generator_variables, gen_scope,
@@ -379,6 +372,108 @@ def cyclegan_model(
reconstructed_y)
+def stargan_model(generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ generator_scope='Generator',
+ discriminator_scope='Discriminator'):
+ """Returns a StarGAN model outputs and variables.
+
+ See https://arxiv.org/abs/1711.09020 for more details.
+
+ Args:
+ generator_fn: A python lambda that takes `inputs` and `targets` as inputs
+ and returns `generated_data` as the transformed version of `inputs` based
+ on the `targets`. `inputs` has shape (n, h, w, c), `targets` has shape (n,
+ num_domains), and `generated_data` has the same shape as `inputs`.
+ discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
+ inputs and returns a tuple (`source_prediction`, `domain_prediction`).
+ `source_prediction` represents the source(real/generated) prediction by
+ the discriminator, and `domain_prediction` represents the domain
+ prediction/classification by the discriminator. `source_prediction` has
+ shape (n) and `domain_prediction` has shape (n, num_domains).
+ input_data: Tensor or a list of tensor of shape (n, h, w, c) representing
+ the real input images.
+ input_data_domain_label: Tensor or a list of tensor of shape (batch_size,
+ num_domains) representing the domain label associated with the real
+ images.
+ generator_scope: Optional generator variable scope. Useful if you want to
+ reuse a subgraph that has already been created.
+ discriminator_scope: Optional discriminator variable scope. Useful if you
+ want to reuse a subgraph that has already been created.
+
+ Returns:
+ StarGANModel namedtuple containing the tensors that are needed to compute
+ the loss.
+
+ Raises:
+ ValueError: If the shape of `input_data_domain_label` is not rank 2 or not
+ fully defined in every dimension.
+ """
+
+ # Convert to tensor.
+ input_data = _convert_tensor_or_l_or_d(input_data)
+ input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)
+
+ # Convert list of tensor to a single tensor if applicable.
+ if isinstance(input_data, (list, tuple)):
+ input_data = array_ops.concat(
+ [ops.convert_to_tensor(x) for x in input_data], 0)
+ if isinstance(input_data_domain_label, (list, tuple)):
+ input_data_domain_label = array_ops.concat(
+ [ops.convert_to_tensor(x) for x in input_data_domain_label], 0)
+
+ # Get batch_size, num_domains from the labels.
+ input_data_domain_label.shape.assert_has_rank(2)
+ input_data_domain_label.shape.assert_is_fully_defined()
+ batch_size, num_domains = input_data_domain_label.shape.as_list()
+
+ # Transform input_data to random target domains.
+ with variable_scope.variable_scope(generator_scope) as generator_scope:
+ generated_data_domain_target = _generate_stargan_random_domain_target(
+ batch_size, num_domains)
+ generated_data = generator_fn(input_data, generated_data_domain_target)
+
+ # Transform generated_data back to the original input_data domain.
+ with variable_scope.variable_scope(generator_scope, reuse=True):
+ reconstructed_data = generator_fn(generated_data, input_data_domain_label)
+
+ # Predict source and domain for the generated_data using the discriminator.
+ with variable_scope.variable_scope(
+ discriminator_scope) as discriminator_scope:
+ disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
+ generated_data, num_domains)
+
+ # Predict source and domain for the input_data using the discriminator.
+ with variable_scope.variable_scope(discriminator_scope, reuse=True):
+ disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
+ input_data, num_domains)
+
+ # Collect trainable variables from the neural networks.
+ generator_variables = variables_lib.get_trainable_variables(generator_scope)
+ discriminator_variables = variables_lib.get_trainable_variables(
+ discriminator_scope)
+
+ # Create the StarGANModel namedtuple.
+ return namedtuples.StarGANModel(
+ input_data=input_data,
+ input_data_domain_label=input_data_domain_label,
+ generated_data=generated_data,
+ generated_data_domain_target=generated_data_domain_target,
+ reconstructed_data=reconstructed_data,
+ discriminator_input_data_source_predication=disc_input_data_source_pred,
+ discriminator_generated_data_source_predication=disc_gen_data_source_pred,
+ discriminator_input_data_domain_predication=disc_input_data_domain_pred,
+ discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
+ generator_variables=generator_variables,
+ generator_scope=generator_scope,
+ generator_fn=generator_fn,
+ discriminator_variables=discriminator_variables,
+ discriminator_scope=discriminator_scope,
+ discriminator_fn=discriminator_fn)
+
+
def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
if isinstance(aux_loss_weight, ops.Tensor):
aux_loss_weight.shape.assert_is_compatible_with([])
@@ -512,8 +607,8 @@ def gan_loss(
`model` isn't an `InfoGANModel`.
"""
# Validate arguments.
- gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight,
- 'gradient_penalty_weight')
+ gradient_penalty_weight = _validate_aux_loss_weight(
+ gradient_penalty_weight, 'gradient_penalty_weight')
mutual_information_penalty_weight = _validate_aux_loss_weight(
mutual_information_penalty_weight, 'infogan_weight')
aux_cond_generator_weight = _validate_aux_loss_weight(
@@ -631,8 +726,8 @@ def cyclegan_loss(
generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
**kwargs)
- return partial_loss._replace(
- generator_loss=partial_loss.generator_loss + aux_loss)
+ return partial_loss._replace(generator_loss=partial_loss.generator_loss +
+ aux_loss)
with ops.name_scope('cyclegan_loss_x2y'):
loss_x2y = _partial_loss(model.model_x2y)
@@ -822,12 +917,14 @@ def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
Returns:
A function that takes a GANTrainOps tuple and returns a list of hooks.
"""
+
def get_hooks(train_ops):
generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
train_steps.generator_train_steps)
discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
train_steps.discriminator_train_steps)
return [generator_hook, discriminator_hook]
+
return get_hooks
@@ -881,23 +978,23 @@ def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
d_hook = RunTrainOpsHook(d_op, num_d_steps)
return [joint_hook, g_hook, d_hook]
+
return get_hooks
# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
-def gan_train(
- train_ops,
- logdir,
- get_hooks_fn=get_sequential_train_hooks(),
- master='',
- is_chief=True,
- scaffold=None,
- hooks=None,
- chief_only_hooks=None,
- save_checkpoint_secs=600,
- save_summaries_steps=100,
- config=None):
+def gan_train(train_ops,
+ logdir,
+ get_hooks_fn=get_sequential_train_hooks(),
+ master='',
+ is_chief=True,
+ scaffold=None,
+ hooks=None,
+ chief_only_hooks=None,
+ save_checkpoint_secs=600,
+ save_summaries_steps=100,
+ config=None):
"""A wrapper around `contrib.training.train` that uses GAN hooks.
Args:
@@ -943,8 +1040,7 @@ def gan_train(
config=config)
-def get_sequential_train_steps(
- train_steps=namedtuples.GANTrainSteps(1, 1)):
+def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a thin wrapper around slim.learning.train_step, for GANs.
This function is to provide support for the Supervisor. For new code, please
@@ -1042,3 +1138,19 @@ def _validate_acgan_discriminator_outputs(discriminator_output):
'A discriminator function for ACGAN must output a tuple '
'consisting of (discrimination logits, classification logits).')
return a, b
+
+
+def _generate_stargan_random_domain_target(batch_size, num_domains):
+ """Generate random domain label.
+
+ Args:
+ batch_size: (int) Number of random domain labels to generate.
+ num_domains: (int) Number of domains represented by the label.
+
+ Returns:
+ Tensor of shape (batch_size, num_domains) representing random label.
+ """
+ domain_idx = random_ops.random_uniform(
+ [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)
+
+ return array_ops.one_hot(domain_idx, num_domains)
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 3ebbe55d05..06681eaf83 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib import layers
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python import train
@@ -30,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -84,19 +86,47 @@ class InfoGANDiscriminator(object):
def acgan_discriminator_model(inputs, _, num_classes=10):
- return (discriminator_model(inputs, _), array_ops.one_hot(
- # TODO(haeusser): infer batch size from input
- random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
- num_classes))
+ return (
+ discriminator_model(inputs, _),
+ array_ops.one_hot(
+ # TODO(haeusser): infer batch size from input
+ random_ops.random_uniform(
+ [3], maxval=num_classes, dtype=dtypes.int32),
+ num_classes))
class ACGANDiscriminator(object):
def __call__(self, inputs, _, num_classes=10):
- return (discriminator_model(inputs, _), array_ops.one_hot(
- # TODO(haeusser): infer batch size from input
- random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
- num_classes))
+ return (
+ discriminator_model(inputs, _),
+ array_ops.one_hot(
+ # TODO(haeusser): infer batch size from input
+ random_ops.random_uniform(
+ [3], maxval=num_classes, dtype=dtypes.int32),
+ num_classes))
+
+
+def stargan_generator_model(inputs, _):
+ """Dummy generator for StarGAN."""
+
+ return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs
+
+
+def stargan_discriminator_model(inputs, num_domains):
+ """Differentiable dummy discriminator for StarGAN."""
+
+ hidden = layers.flatten(inputs)
+
+ output_src = math_ops.reduce_mean(hidden, axis=1)
+
+ output_cls = layers.fully_connected(
+ inputs=hidden,
+ num_outputs=num_domains,
+ activation_fn=None,
+ normalizer_fn=None,
+ biases_initializer=None)
+ return output_src, output_cls
def get_gan_model():
@@ -122,8 +152,7 @@ def get_gan_model():
def get_callable_gan_model():
ganmodel = get_gan_model()
return ganmodel._replace(
- generator_fn=Generator(),
- discriminator_fn=Discriminator())
+ generator_fn=Generator(), discriminator_fn=Discriminator())
def create_gan_model():
@@ -283,15 +312,15 @@ class GANModelTest(test.TestCase):
self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel)
def test_output_type_callable_infogan(self):
- self._test_output_type_helper(
- get_callable_infogan_model, namedtuples.InfoGANModel)
+ self._test_output_type_helper(get_callable_infogan_model,
+ namedtuples.InfoGANModel)
def test_output_type_acgan(self):
self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel)
def test_output_type_callable_acgan(self):
- self._test_output_type_helper(
- get_callable_acgan_model, namedtuples.ACGANModel)
+ self._test_output_type_helper(get_callable_acgan_model,
+ namedtuples.ACGANModel)
def test_output_type_cyclegan(self):
self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel)
@@ -301,10 +330,13 @@ class GANModelTest(test.TestCase):
namedtuples.CycleGANModel)
def test_no_shape_check(self):
+
def dummy_generator_model(_):
return (None, None)
+
def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument
return 1
+
with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
train.gan_model(
dummy_generator_model,
@@ -320,6 +352,138 @@ class GANModelTest(test.TestCase):
check_shapes=False)
+class StarGANModelTest(test.TestCase):
+ """Tests for `stargan_model`."""
+
+ @staticmethod
+ def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains):
+
+ input_tensor_list = []
+ label_tensor_list = []
+ for _ in range(num_domains):
+ input_tensor_list.append(
+ random_ops.random_uniform((batch_size, img_size, img_size, c_size)))
+ domain_idx = random_ops.random_uniform(
+ [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)
+ label_tensor_list.append(array_ops.one_hot(domain_idx, num_domains))
+ return input_tensor_list, label_tensor_list
+
+ def test_generate_stargan_random_domain_target(self):
+
+ batch_size = 8
+ domain_numbers = 3
+
+ target_tensor = train._generate_stargan_random_domain_target(
+ batch_size, domain_numbers)
+
+ with self.test_session() as sess:
+ targets = sess.run(target_tensor)
+ self.assertTupleEqual((batch_size, domain_numbers), targets.shape)
+ for target in targets:
+ self.assertEqual(1, np.sum(target))
+ self.assertEqual(1, np.max(target))
+
+ def test_stargan_model_output_type(self):
+
+ batch_size = 2
+ img_size = 16
+ c_size = 3
+ num_domains = 5
+
+ input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
+ batch_size, img_size, c_size, num_domains)
+ model = train.stargan_model(
+ generator_fn=stargan_generator_model,
+ discriminator_fn=stargan_discriminator_model,
+ input_data=input_tensor,
+ input_data_domain_label=label_tensor)
+
+ self.assertIsInstance(model, namedtuples.StarGANModel)
+ self.assertTrue(isinstance(model.discriminator_variables, list))
+ self.assertTrue(isinstance(model.generator_variables, list))
+ self.assertIsInstance(model.discriminator_scope,
+ variable_scope.VariableScope)
+ self.assertTrue(model.generator_scope, variable_scope.VariableScope)
+ self.assertTrue(callable(model.discriminator_fn))
+ self.assertTrue(callable(model.generator_fn))
+
+ def test_stargan_model_generator_output(self):
+
+ batch_size = 2
+ img_size = 16
+ c_size = 3
+ num_domains = 5
+
+ input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
+ batch_size, img_size, c_size, num_domains)
+ model = train.stargan_model(
+ generator_fn=stargan_generator_model,
+ discriminator_fn=stargan_discriminator_model,
+ input_data=input_tensor,
+ input_data_domain_label=label_tensor)
+
+ with self.test_session(use_gpu=True) as sess:
+
+ sess.run(variables.global_variables_initializer())
+
+ input_data, generated_data, reconstructed_data = sess.run(
+ [model.input_data, model.generated_data, model.reconstructed_data])
+ self.assertTupleEqual(
+ (batch_size * num_domains, img_size, img_size, c_size),
+ input_data.shape)
+ self.assertTupleEqual(
+ (batch_size * num_domains, img_size, img_size, c_size),
+ generated_data.shape)
+ self.assertTupleEqual(
+ (batch_size * num_domains, img_size, img_size, c_size),
+ reconstructed_data.shape)
+
+ def test_stargan_model_discriminator_output(self):
+
+ batch_size = 2
+ img_size = 16
+ c_size = 3
+ num_domains = 5
+
+ input_tensor, label_tensor = StarGANModelTest.create_input_and_label_tensor(
+ batch_size, img_size, c_size, num_domains)
+ model = train.stargan_model(
+ generator_fn=stargan_generator_model,
+ discriminator_fn=stargan_discriminator_model,
+ input_data=input_tensor,
+ input_data_domain_label=label_tensor)
+
+ with self.test_session(use_gpu=True) as sess:
+
+ sess.run(variables.global_variables_initializer())
+
+ disc_input_data_source_pred, disc_gen_data_source_pred = sess.run([
+ model.discriminator_input_data_source_predication,
+ model.discriminator_generated_data_source_predication
+ ])
+ self.assertEqual(1, len(disc_input_data_source_pred.shape))
+ self.assertEqual(batch_size * num_domains,
+ disc_input_data_source_pred.shape[0])
+ self.assertEqual(1, len(disc_gen_data_source_pred.shape))
+ self.assertEqual(batch_size * num_domains,
+ disc_gen_data_source_pred.shape[0])
+
+ input_label, disc_input_label, gen_label, disc_gen_label = sess.run([
+ model.input_data_domain_label,
+ model.discriminator_input_data_domain_predication,
+ model.generated_data_domain_target,
+ model.discriminator_generated_data_domain_predication
+ ])
+ self.assertTupleEqual((batch_size * num_domains, num_domains),
+ input_label.shape)
+ self.assertTupleEqual((batch_size * num_domains, num_domains),
+ disc_input_label.shape)
+ self.assertTupleEqual((batch_size * num_domains, num_domains),
+ gen_label.shape)
+ self.assertTupleEqual((batch_size * num_domains, num_domains),
+ disc_gen_label.shape)
+
+
class GANLossTest(test.TestCase):
"""Tests for `gan_loss`."""
@@ -362,9 +526,10 @@ class GANLossTest(test.TestCase):
def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False):
model = create_gan_model_fn()
loss = train.gan_loss(model)
- loss_gp = train.gan_loss(model,
- gradient_penalty_weight=1.0,
- gradient_penalty_one_sided=one_sided)
+ loss_gp = train.gan_loss(
+ model,
+ gradient_penalty_weight=1.0,
+ gradient_penalty_one_sided=one_sided)
self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
# Check values.
@@ -417,8 +582,9 @@ class GANLossTest(test.TestCase):
# Test mutual information penalty option.
def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
- train.gan_loss(create_gan_model_fn(),
- mutual_information_penalty_weight=constant_op.constant(1.0))
+ train.gan_loss(
+ create_gan_model_fn(),
+ mutual_information_penalty_weight=constant_op.constant(1.0))
def test_mutual_info_penalty_infogan(self):
self._test_mutual_info_penalty_helper(get_infogan_model)
@@ -435,11 +601,11 @@ class GANLossTest(test.TestCase):
no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()
with ops.name_scope(get_gan_model_fn().generator_scope.name):
- ops.add_to_collection(
- ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
+ ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
+ constant_op.constant(3.0))
with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
- ops.add_to_collection(
- ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
+ ops.add_to_collection(ops.GraphKeys.REGULARIZATION_LOSSES,
+ constant_op.constant(2.0))
# Check that losses now include the correct regularization values.
reg_loss = train.gan_loss(get_gan_model_fn())
@@ -481,14 +647,14 @@ class GANLossTest(test.TestCase):
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
- loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
- [loss.generator_loss,
- loss_ac_gen.generator_loss,
- loss_ac_dis.generator_loss])
- loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
- [loss.discriminator_loss,
- loss_ac_gen.discriminator_loss,
- loss_ac_dis.discriminator_loss])
+ loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run([
+ loss.generator_loss, loss_ac_gen.generator_loss,
+ loss_ac_dis.generator_loss
+ ])
+ loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run([
+ loss.discriminator_loss, loss_ac_gen.discriminator_loss,
+ loss_ac_dis.discriminator_loss
+ ])
self.assertTrue(loss_gen_np < loss_dis_np)
self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
@@ -707,8 +873,11 @@ class GANTrainOpsTest(test.TestCase):
# Add an update op outside the generator and discriminator scopes.
if provide_update_ops:
- kwargs = {'update_ops':
- [constant_op.constant(1.0), gen_update_op, dis_update_op]}
+ kwargs = {
+ 'update_ops': [
+ constant_op.constant(1.0), gen_update_op, dis_update_op
+ ]
+ }
else:
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
kwargs = {}
@@ -717,8 +886,8 @@ class GANTrainOpsTest(test.TestCase):
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
- train.gan_train_ops(model, loss, g_opt, d_opt,
- check_for_unused_update_ops=True, **kwargs)
+ train.gan_train_ops(
+ model, loss, g_opt, d_opt, check_for_unused_update_ops=True, **kwargs)
train_ops = train.gan_train_ops(
model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)
@@ -771,8 +940,9 @@ class GANTrainOpsTest(test.TestCase):
def test_unused_update_ops_callable_acgan_provideupdates(self):
self._test_unused_update_ops(create_callable_acgan_model, True)
- def _test_sync_replicas_helper(
- self, create_gan_model_fn, create_global_step=False):
+ def _test_sync_replicas_helper(self,
+ create_gan_model_fn,
+ create_global_step=False):
model = create_gan_model_fn()
loss = train.gan_loss(model)
num_trainable_vars = len(variables_lib.get_trainable_variables())
@@ -785,10 +955,7 @@ class GANTrainOpsTest(test.TestCase):
g_opt = get_sync_optimizer()
d_opt = get_sync_optimizer()
train_ops = train.gan_train_ops(
- model,
- loss,
- generator_optimizer=g_opt,
- discriminator_optimizer=d_opt)
+ model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
# No new trainable variables should have been added.
self.assertEqual(num_trainable_vars,
@@ -860,8 +1027,8 @@ class GANTrainTest(test.TestCase):
# joint training.
train_ops = namedtuples.GANTrainOps(
generator_train_op=step.assign_add(generator_add, use_locking=True),
- discriminator_train_op=step.assign_add(discriminator_add,
- use_locking=True),
+ discriminator_train_op=step.assign_add(
+ discriminator_add, use_locking=True),
global_step_inc_op=step.assign_add(1))
return train_ops
@@ -903,8 +1070,7 @@ class GANTrainTest(test.TestCase):
def _test_multiple_steps_helper(self, get_hooks_fn_fn):
train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
train_steps = namedtuples.GANTrainSteps(
- generator_train_steps=3,
- discriminator_train_steps=4)
+ generator_train_steps=3, discriminator_train_steps=4)
final_step = train.gan_train(
train_ops,
get_hooks_fn=get_hooks_fn_fn(train_steps),
@@ -927,8 +1093,7 @@ class GANTrainTest(test.TestCase):
discriminator_train_op=constant_op.constant(2.0),
global_step_inc_op=step.assign_add(1))
train_steps = namedtuples.GANTrainSteps(
- generator_train_steps=3,
- discriminator_train_steps=4)
+ generator_train_steps=3, discriminator_train_steps=4)
final_loss = slim_learning.train(
train_op=train_ops,