diff options
author | Alexander Pantyukhin <apantykhin@gmail.com> | 2018-05-22 18:06:26 +0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-05-22 18:06:26 +0400 |
commit | b98adf4f09632d5f46d82d622d6627aed310541b (patch) | |
tree | 071f9109bcecae2271260e1a330b169e83daf483 /tensorflow/contrib/gan | |
parent | 8dc3b3c453180211f4be5302f957664004e1ec04 (diff) | |
parent | 294fdb0c5452715ec46d0555245e3b130308ebb4 (diff) |
Merge branch 'master' into ganhead_constructor_validate
Diffstat (limited to 'tensorflow/contrib/gan')
10 files changed, 85 insertions, 32 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e3fc6bf0f0..4092b32004 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator): generator_optimizer=None, discriminator_optimizer=None, get_hooks_fn=None, + get_eval_metric_ops_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator): list of hooks. These hooks are run on the generator and discriminator train ops, and can be used to implement the GAN training scheme. Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn) + use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 387a62bd74..955482599b 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache @@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) + def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( @@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase): discriminator_loss_fn=losses.wasserstein_discriminator_loss, generator_optimizer=gopt, discriminator_optimizer=dopt, + get_eval_metric_ops_fn=get_metrics, model_dir=self._model_dir) # TRAIN @@ -213,6 +221,9 @@ class GANEstimatorIntegrationTest(test.TestCase): scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) + self.assertIn('mse_custom_metric', six.iterkeys(scores)) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 652ffee30a..18105f9425 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -26,17 +26,21 @@ from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops from tensorflow.python.training import optimizer +from tensorflow.python.ops import metrics as metrics_lib __all__ = [ 'GANHead', 'gan_head', ] +def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - name=None): + get_eval_metric_ops_fn=None, name=None): """Creates a `GANHead`. Args: @@ -48,9 +52,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + If `None`, uses defaults. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. @@ -63,6 +70,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer=discriminator_optimizer, use_loss_summaries=use_loss_summaries, get_hooks_fn=get_hooks_fn, + get_eval_metric_ops_fn=get_eval_metric_ops_fn, name=name) @@ -73,6 +81,7 @@ class GANHead(head._Head): # pylint: disable=protected-access generator_optimizer, discriminator_optimizer, use_loss_summaries=True, get_hooks_fn=None, + get_eval_metric_ops_fn=None, name=None): """`Head` for GAN training. @@ -86,8 +95,11 @@ class GANHead(head._Head): # pylint: disable=protected-access discriminator updates. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. - get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. Defaults to `train.get_sequential_train_hooks()` + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. Defaults to `train.get_sequential_train_hooks()` + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ @@ -120,6 +132,8 @@ class GANHead(head._Head): # pylint: disable=protected-access self._generator_optimizer = generator_optimizer self._discriminator_optimizer = discriminator_optimizer self._get_hooks_fn = get_hooks_fn + self._get_eval_metric_ops_fn = get_eval_metric_ops_fn + self._name = name @property def name(self): @@ -189,13 +203,26 @@ class GANHead(head._Head): # pylint: disable=protected-access gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + eval_metric_ops = { + _summary_key(self._name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(self._name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if self._get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.EVAL, predictions=gan_model.generated_data, loss=scalar_loss, - # TODO(joelshor): Add metrics. If head name provided, append it to - # metric keys. - eval_metric_ops={}) + eval_metric_ops=eval_metric_ops) elif mode == model_fn_lib.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None.') diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py index 8168f005cd..6587f1fc60 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase): generator_loss_fn=dummy_loss, discriminator_loss_fn=dummy_loss, generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + discriminator_optimizer=training.GradientDescentOptimizer(1.0), + get_eval_metric_ops_fn=self.get_metrics) self.assertTrue(isinstance(self.gan_head, head.GANHead)) + def get_metrics(self, gan_model): + self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) + return {} + def _test_modes_helper(self, mode): self.gan_head.create_estimator_spec( features=None, diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 47e51415fd..d914f54945 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -488,25 +488,25 @@ def frechet_classifier_distance(real_images, The Frechet Inception distance. A floating-point scalar of the same type as the output of `classifier_fn`. """ - real_images_list = array_ops.split( real_images, num_or_size_splits=num_batches) generated_images_list = array_ops.split( generated_images, num_or_size_splits=num_batches) - imgs = array_ops.stack(real_images_list + generated_images_list) + real_imgs = array_ops.stack(real_images_list) + generated_imgs = array_ops.stack(generated_images_list) # Compute the activations using the memory-efficient `map_fn`. - activations = functional_ops.map_fn( - fn=classifier_fn, - elems=imgs, - parallel_iterations=1, - back_prop=False, - swap_memory=True, - name='RunClassifier') + def compute_activations(elems): + return functional_ops.map_fn(fn=classifier_fn, + elems=elems, + parallel_iterations=1, + back_prop=False, + swap_memory=True, + name='RunClassifier') - # Split the activations by the real and generated images. - real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) + real_a = compute_activations(real_imgs) + gen_a = compute_activations(generated_imgs) # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) @@ -697,18 +697,20 @@ def frechet_classifier_distance_from_activations(real_activations, # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) m_w = math_ops.reduce_mean(generated_activations, 0) - num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) + num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0]) + num_examples_generated = math_ops.to_double( + array_ops.shape(generated_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( real_centered, real_centered, transpose_a=True) / ( - num_examples - 1) + num_examples_real - 1) gen_centered = generated_activations - m_w sigma_w = math_ops.matmul( gen_centered, gen_centered, transpose_a=True) / ( - num_examples - 1) + num_examples_generated - 1) # Find the Tr(sqrt(sigma sigma_w)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py index 4b10bc0f8e..4b1105f6bd 100644 --- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): proj = random_ops.random_normal( [array_ops.shape(a)[1], random_projection_dim]) proj *= math_ops.rsqrt( - math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True)) # Project both distributions and sort them. proj_a = math_ops.matmul(a, proj) proj_b = math_ops.matmul(b, proj) diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py index df71187fbd..a9b8faa712 100644 --- a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py +++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Miscellanous utilities for TFGAN code and examples.""" +"""Miscellaneous utilities for TFGAN code and examples.""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py index f8b372546b..650eab97a3 100644 --- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -64,11 +64,11 @@ def _statistics(x, axes): y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x # Compute true mean while keeping the dims for proper broadcasting. - shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True)) + shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True)) - shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True) + shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) mean = shifted_mean + shift - mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True) + mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True) mean = array_ops.squeeze(mean, axes) mean_squared = array_ops.squeeze(mean_squared, axes) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 2889e93743..9f5fee4542 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -570,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mutual_information_loss/mul' + self._expected_op_name = 'mutual_information_loss/mul_1' self._batch_size = 2 diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 73acd05b60..6fa43059f3 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -710,7 +710,10 @@ def gan_train_ops( be used to train a generator/discriminator pair. """ if isinstance(model, namedtuples.CycleGANModel): - saved_params = locals() + # Get and store all arguments other than model and loss from locals. + # Contents of locals should not be modified, may not affect values. So make + # a copy. https://docs.python.org/2/library/functions.html#locals. + saved_params = dict(locals()) saved_params.pop('model', None) saved_params.pop('loss', None) kwargs = saved_params.pop('kwargs', {}) |