aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-23 15:05:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 15:29:53 -0700
commit931a3054d2c13c3438fc58978b3463a0bd268aee (patch)
tree8474f95a82c80115d8373e9da89c5c6481966417 /tensorflow/contrib/gan
parent09c4c387913c86247121589caa7fb2e85351fa58 (diff)
[tfgan] Issue #18041: Make pooling consistent in `gan_loss`.
PiperOrigin-RevId: 205731279
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/train.py56
-rw-r--r--tensorflow/contrib/gan/python/train_test.py71
2 files changed, 61 insertions, 66 deletions
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 49d9327333..df603d1f18 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -514,33 +514,42 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn):
Raises:
ValueError: If tensor pool does not support the `model`.
"""
- if tensor_pool_fn is None:
- return model
-
- pooled_generated_data, pooled_generator_inputs = tensor_pool_fn(
- (model.generated_data, model.generator_inputs))
-
if isinstance(model, namedtuples.GANModel):
+ pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
+ (model.generator_inputs, model.generated_data))
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
pooled_generator_inputs)
- return model._replace(discriminator_gen_outputs=dis_gen_outputs)
+ return model._replace(
+ generator_inputs=pooled_generator_inputs,
+ generated_data=pooled_generated_data,
+ discriminator_gen_outputs=dis_gen_outputs)
elif isinstance(model, namedtuples.ACGANModel):
+ pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
+ (model.generator_inputs, model.generated_data))
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
- (dis_pooled_gen_outputs,
- dis_pooled_gen_classification_logits) = model.discriminator_fn(
+ (pooled_discriminator_gen_outputs,
+ pooled_discriminator_gen_classification_logits) = model.discriminator_fn(
pooled_generated_data, pooled_generator_inputs)
return model._replace(
- discriminator_gen_outputs=dis_pooled_gen_outputs,
+ generator_inputs=pooled_generator_inputs,
+ generated_data=pooled_generated_data,
+ discriminator_gen_outputs=pooled_discriminator_gen_outputs,
discriminator_gen_classification_logits=
- dis_pooled_gen_classification_logits)
+ pooled_discriminator_gen_classification_logits)
elif isinstance(model, namedtuples.InfoGANModel):
+ pooled_generator_inputs, pooled_generated_data, pooled_structured_input = (
+ tensor_pool_fn((model.generator_inputs, model.generated_data,
+ model.structured_generator_inputs)))
with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
- (dis_pooled_gen_outputs,
+ (pooled_discriminator_gen_outputs,
pooled_predicted_distributions) = model.discriminator_and_aux_fn(
pooled_generated_data, pooled_generator_inputs)
return model._replace(
- discriminator_gen_outputs=dis_pooled_gen_outputs,
+ generator_inputs=pooled_generator_inputs,
+ generated_data=pooled_generated_data,
+ structured_generator_inputs=pooled_structured_input,
+ discriminator_gen_outputs=pooled_discriminator_gen_outputs,
predicted_distributions=pooled_predicted_distributions)
else:
raise ValueError('Tensor pool does not support `model`: %s.' % type(model))
@@ -632,33 +641,38 @@ def gan_loss(
'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
type(model))
+ # Optionally create pooled model.
+ pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if
+ tensor_pool_fn else model)
+
# Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
- dis_loss = discriminator_loss_fn(
- _tensor_pool_adjusted_model(model, tensor_pool_fn),
- add_summaries=add_summaries)
+ dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
- model,
+ pooled_model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
- info_loss = tfgan_losses.mutual_information_penalty(
+ gen_info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries)
- dis_loss += mutual_information_penalty_weight * info_loss
- gen_loss += mutual_information_penalty_weight * info_loss
+ dis_info_loss = (gen_info_loss if tensor_pool_fn is None else
+ tfgan_losses.mutual_information_penalty(
+ pooled_model, add_summaries=add_summaries))
+ gen_loss += mutual_information_penalty_weight * gen_info_loss
+ dis_loss += mutual_information_penalty_weight * dis_info_loss
if _use_aux_loss(aux_cond_generator_weight):
ac_gen_loss = tfgan_losses.acgan_generator_loss(
model, add_summaries=add_summaries)
gen_loss += aux_cond_generator_weight * ac_gen_loss
if _use_aux_loss(aux_cond_discriminator_weight):
ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
- model, add_summaries=add_summaries)
+ pooled_model, add_summaries=add_summaries)
dis_loss += aux_cond_discriminator_weight * ac_disc_loss
# Gathers auxiliary losses.
if model.generator_scope:
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index cd99a33c03..fa52e9cca1 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -278,25 +278,6 @@ def get_sync_optimizer():
replicas_to_aggregate=1)
-def get_tensor_pool_fn(pool_size):
-
- def tensor_pool_fn_impl(input_values):
- return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size)
-
- return tensor_pool_fn_impl
-
-
-def get_tensor_pool_fn_for_infogan(pool_size):
-
- def tensor_pool_fn_impl(input_values):
- generated_data, generator_inputs = input_values
- output_values = random_tensor_pool.tensor_pool(
- [generated_data] + generator_inputs, pool_size=pool_size)
- return output_values[0], output_values[1:]
-
- return tensor_pool_fn_impl
-
-
class GANModelTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_model`."""
@@ -344,7 +325,6 @@ class StarGANModelTest(test.TestCase):
@staticmethod
def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains):
-
input_tensor_list = []
label_tensor_list = []
for _ in range(num_domains):
@@ -356,7 +336,6 @@ class StarGANModelTest(test.TestCase):
return input_tensor_list, label_tensor_list
def test_generate_stargan_random_domain_target(self):
-
batch_size = 8
domain_numbers = 3
@@ -371,7 +350,6 @@ class StarGANModelTest(test.TestCase):
self.assertEqual(1, np.max(target))
def test_stargan_model_output_type(self):
-
batch_size = 2
img_size = 16
c_size = 3
@@ -395,7 +373,6 @@ class StarGANModelTest(test.TestCase):
self.assertTrue(callable(model.generator_fn))
def test_stargan_model_generator_output(self):
-
batch_size = 2
img_size = 16
c_size = 3
@@ -426,7 +403,6 @@ class StarGANModelTest(test.TestCase):
reconstructed_data.shape)
def test_stargan_model_discriminator_output(self):
-
batch_size = 2
img_size = 16
c_size = 3
@@ -643,10 +619,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
def test_tensor_pool(self, create_gan_model_fn):
"""Test tensor pool option."""
model = create_gan_model_fn()
- if isinstance(model, namedtuples.InfoGANModel):
- tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
- else:
- tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
+ tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5)
loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
self.assertIsInstance(loss, namedtuples.GANLoss)
@@ -656,6 +629,25 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
for _ in range(10):
sess.run([loss.generator_loss, loss.discriminator_loss])
+ def test_discriminator_only_sees_pool(self):
+ """Checks that discriminator only sees pooled values."""
+ def checker_gen_fn(_):
+ return constant_op.constant(0.0)
+ model = train.gan_model(
+ checker_gen_fn,
+ discriminator_model,
+ real_data=array_ops.zeros([]),
+ generator_inputs=random_ops.random_normal([]))
+ def tensor_pool_fn(_):
+ return (random_ops.random_uniform([]), random_ops.random_uniform([]))
+ def checker_dis_fn(inputs, _):
+ """Discriminator that checks that it only sees pooled Tensors."""
+ self.assertFalse(constant_op.is_constant(inputs))
+ return inputs
+ model = model._replace(
+ discriminator_fn=checker_dis_fn)
+ train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
+
def test_doesnt_crash_when_in_nested_scope(self):
with variable_scope.variable_scope('outer_scope'):
gan_model = train.gan_model(
@@ -673,8 +665,8 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
class TensorPoolAdjusteModelTest(test.TestCase):
- def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2,
- pool_size):
+ def _check_tensor_pool_adjusted_model_outputs(
+ self, tensor1, tensor2, pool_size):
history_values = []
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
@@ -691,10 +683,9 @@ class TensorPoolAdjusteModelTest(test.TestCase):
# pool).
self.assertTrue(any([(v == t2).all() for v in history_values]))
- def _make_new_model_and_check(self, model, pool_size,
- pool_fn=get_tensor_pool_fn):
- new_model = train._tensor_pool_adjusted_model(
- model, pool_fn(pool_size=pool_size))
+ def _make_new_model_and_check(self, model, pool_size):
+ pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size)
+ new_model = train._tensor_pool_adjusted_model(model, pool_fn)
# 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
self.assertIsNot(new_model.discriminator_gen_outputs,
@@ -702,15 +693,6 @@ class TensorPoolAdjusteModelTest(test.TestCase):
return new_model
- def test_tensor_pool_adjusted_model_no_pool(self):
- """Test `_tensor_pool_adjusted_model` for no pool size."""
- model = create_gan_model()
- new_model = train._tensor_pool_adjusted_model(model, None)
-
- # Check values.
- self.assertIs(new_model.discriminator_gen_outputs,
- model.discriminator_gen_outputs)
-
def test_tensor_pool_adjusted_model_gan(self):
"""Test `_tensor_pool_adjusted_model` for gan model."""
pool_size = 5
@@ -726,8 +708,7 @@ class TensorPoolAdjusteModelTest(test.TestCase):
"""Test _tensor_pool_adjusted_model for infogan model."""
pool_size = 5
model = create_infogan_model()
- new_model = self._make_new_model_and_check(
- model, pool_size, pool_fn=get_tensor_pool_fn_for_infogan)
+ new_model = self._make_new_model_and_check(model, pool_size)
# Check values.
self.assertIsNot(new_model.predicted_distributions,