diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-23 15:05:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-23 15:29:53 -0700 |
commit | 931a3054d2c13c3438fc58978b3463a0bd268aee (patch) | |
tree | 8474f95a82c80115d8373e9da89c5c6481966417 /tensorflow/contrib/gan | |
parent | 09c4c387913c86247121589caa7fb2e85351fa58 (diff) |
[tfgan] Issue #18041: Make pooling consistent in `gan_loss`.
PiperOrigin-RevId: 205731279
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/train.py | 56 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/train_test.py | 71 |
2 files changed, 61 insertions, 66 deletions
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 49d9327333..df603d1f18 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -514,33 +514,42 @@ def _tensor_pool_adjusted_model(model, tensor_pool_fn): Raises: ValueError: If tensor pool does not support the `model`. """ - if tensor_pool_fn is None: - return model - - pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( - (model.generated_data, model.generator_inputs)) - if isinstance(model, namedtuples.GANModel): + pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( + (model.generator_inputs, model.generated_data)) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): dis_gen_outputs = model.discriminator_fn(pooled_generated_data, pooled_generator_inputs) - return model._replace(discriminator_gen_outputs=dis_gen_outputs) + return model._replace( + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + discriminator_gen_outputs=dis_gen_outputs) elif isinstance(model, namedtuples.ACGANModel): + pooled_generator_inputs, pooled_generated_data = tensor_pool_fn( + (model.generator_inputs, model.generated_data)) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (dis_pooled_gen_outputs, - dis_pooled_gen_classification_logits) = model.discriminator_fn( + (pooled_discriminator_gen_outputs, + pooled_discriminator_gen_classification_logits) = model.discriminator_fn( pooled_generated_data, pooled_generator_inputs) return model._replace( - discriminator_gen_outputs=dis_pooled_gen_outputs, + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + discriminator_gen_outputs=pooled_discriminator_gen_outputs, discriminator_gen_classification_logits= - dis_pooled_gen_classification_logits) + pooled_discriminator_gen_classification_logits) elif isinstance(model, namedtuples.InfoGANModel): + pooled_generator_inputs, pooled_generated_data, pooled_structured_input = ( + tensor_pool_fn((model.generator_inputs, model.generated_data, + model.structured_generator_inputs))) with variable_scope.variable_scope(model.discriminator_scope, reuse=True): - (dis_pooled_gen_outputs, + (pooled_discriminator_gen_outputs, pooled_predicted_distributions) = model.discriminator_and_aux_fn( pooled_generated_data, pooled_generator_inputs) return model._replace( - discriminator_gen_outputs=dis_pooled_gen_outputs, + generator_inputs=pooled_generator_inputs, + generated_data=pooled_generated_data, + structured_generator_inputs=pooled_structured_input, + discriminator_gen_outputs=pooled_discriminator_gen_outputs, predicted_distributions=pooled_predicted_distributions) else: raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) @@ -632,33 +641,38 @@ def gan_loss( 'is provided, `model` must be an `ACGANModel`. Instead, was %s.' % type(model)) + # Optionally create pooled model. + pooled_model = (_tensor_pool_adjusted_model(model, tensor_pool_fn) if + tensor_pool_fn else model) + # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn( - _tensor_pool_adjusted_model(model, tensor_pool_fn), - add_summaries=add_summaries) + dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries) # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): gp_loss = tfgan_losses.wasserstein_gradient_penalty( - model, + pooled_model, epsilon=gradient_penalty_epsilon, target=gradient_penalty_target, one_sided=gradient_penalty_one_sided, add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): - info_loss = tfgan_losses.mutual_information_penalty( + gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_loss += mutual_information_penalty_weight * info_loss - gen_loss += mutual_information_penalty_weight * info_loss + dis_info_loss = (gen_info_loss if tensor_pool_fn is None else + tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries)) + gen_loss += mutual_information_penalty_weight * gen_info_loss + dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): ac_gen_loss = tfgan_losses.acgan_generator_loss( model, add_summaries=add_summaries) gen_loss += aux_cond_generator_weight * ac_gen_loss if _use_aux_loss(aux_cond_discriminator_weight): ac_disc_loss = tfgan_losses.acgan_discriminator_loss( - model, add_summaries=add_summaries) + pooled_model, add_summaries=add_summaries) dis_loss += aux_cond_discriminator_weight * ac_disc_loss # Gathers auxiliary losses. if model.generator_scope: diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index cd99a33c03..fa52e9cca1 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -278,25 +278,6 @@ def get_sync_optimizer(): replicas_to_aggregate=1) -def get_tensor_pool_fn(pool_size): - - def tensor_pool_fn_impl(input_values): - return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size) - - return tensor_pool_fn_impl - - -def get_tensor_pool_fn_for_infogan(pool_size): - - def tensor_pool_fn_impl(input_values): - generated_data, generator_inputs = input_values - output_values = random_tensor_pool.tensor_pool( - [generated_data] + generator_inputs, pool_size=pool_size) - return output_values[0], output_values[1:] - - return tensor_pool_fn_impl - - class GANModelTest(test.TestCase, parameterized.TestCase): """Tests for `gan_model`.""" @@ -344,7 +325,6 @@ class StarGANModelTest(test.TestCase): @staticmethod def create_input_and_label_tensor(batch_size, img_size, c_size, num_domains): - input_tensor_list = [] label_tensor_list = [] for _ in range(num_domains): @@ -356,7 +336,6 @@ class StarGANModelTest(test.TestCase): return input_tensor_list, label_tensor_list def test_generate_stargan_random_domain_target(self): - batch_size = 8 domain_numbers = 3 @@ -371,7 +350,6 @@ class StarGANModelTest(test.TestCase): self.assertEqual(1, np.max(target)) def test_stargan_model_output_type(self): - batch_size = 2 img_size = 16 c_size = 3 @@ -395,7 +373,6 @@ class StarGANModelTest(test.TestCase): self.assertTrue(callable(model.generator_fn)) def test_stargan_model_generator_output(self): - batch_size = 2 img_size = 16 c_size = 3 @@ -426,7 +403,6 @@ class StarGANModelTest(test.TestCase): reconstructed_data.shape) def test_stargan_model_discriminator_output(self): - batch_size = 2 img_size = 16 c_size = 3 @@ -643,10 +619,7 @@ class GANLossTest(test.TestCase, parameterized.TestCase): def test_tensor_pool(self, create_gan_model_fn): """Test tensor pool option.""" model = create_gan_model_fn() - if isinstance(model, namedtuples.InfoGANModel): - tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5) - else: - tensor_pool_fn = get_tensor_pool_fn(pool_size=5) + tensor_pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=5) loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) self.assertIsInstance(loss, namedtuples.GANLoss) @@ -656,6 +629,25 @@ class GANLossTest(test.TestCase, parameterized.TestCase): for _ in range(10): sess.run([loss.generator_loss, loss.discriminator_loss]) + def test_discriminator_only_sees_pool(self): + """Checks that discriminator only sees pooled values.""" + def checker_gen_fn(_): + return constant_op.constant(0.0) + model = train.gan_model( + checker_gen_fn, + discriminator_model, + real_data=array_ops.zeros([]), + generator_inputs=random_ops.random_normal([])) + def tensor_pool_fn(_): + return (random_ops.random_uniform([]), random_ops.random_uniform([])) + def checker_dis_fn(inputs, _): + """Discriminator that checks that it only sees pooled Tensors.""" + self.assertFalse(constant_op.is_constant(inputs)) + return inputs + model = model._replace( + discriminator_fn=checker_dis_fn) + train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + def test_doesnt_crash_when_in_nested_scope(self): with variable_scope.variable_scope('outer_scope'): gan_model = train.gan_model( @@ -673,8 +665,8 @@ class GANLossTest(test.TestCase, parameterized.TestCase): class TensorPoolAdjusteModelTest(test.TestCase): - def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2, - pool_size): + def _check_tensor_pool_adjusted_model_outputs( + self, tensor1, tensor2, pool_size): history_values = [] with self.test_session(use_gpu=True) as sess: variables.global_variables_initializer().run() @@ -691,10 +683,9 @@ class TensorPoolAdjusteModelTest(test.TestCase): # pool). self.assertTrue(any([(v == t2).all() for v in history_values])) - def _make_new_model_and_check(self, model, pool_size, - pool_fn=get_tensor_pool_fn): - new_model = train._tensor_pool_adjusted_model( - model, pool_fn(pool_size=pool_size)) + def _make_new_model_and_check(self, model, pool_size): + pool_fn = lambda x: random_tensor_pool.tensor_pool(x, pool_size=pool_size) + new_model = train._tensor_pool_adjusted_model(model, pool_fn) # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) self.assertIsNot(new_model.discriminator_gen_outputs, @@ -702,15 +693,6 @@ class TensorPoolAdjusteModelTest(test.TestCase): return new_model - def test_tensor_pool_adjusted_model_no_pool(self): - """Test `_tensor_pool_adjusted_model` for no pool size.""" - model = create_gan_model() - new_model = train._tensor_pool_adjusted_model(model, None) - - # Check values. - self.assertIs(new_model.discriminator_gen_outputs, - model.discriminator_gen_outputs) - def test_tensor_pool_adjusted_model_gan(self): """Test `_tensor_pool_adjusted_model` for gan model.""" pool_size = 5 @@ -726,8 +708,7 @@ class TensorPoolAdjusteModelTest(test.TestCase): """Test _tensor_pool_adjusted_model for infogan model.""" pool_size = 5 model = create_infogan_model() - new_model = self._make_new_model_and_check( - model, pool_size, pool_fn=get_tensor_pool_fn_for_infogan) + new_model = self._make_new_model_and_check(model, pool_size) # Check values. self.assertIsNot(new_model.predicted_distributions, |