diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-27 02:01:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 02:05:57 -0700 |
commit | 86abbaa083beaca05ee32675ac7bfafb58a4557d (patch) | |
tree | 022e6d08f295c1c6378e4185af033d1261f4f7c7 /tensorflow/contrib/gan | |
parent | 632e3d66334ac3718a0fd41524c7dfc499363cab (diff) |
[TFGAN] StarGAN Estimator Implementation
PiperOrigin-RevId: 210334354
Diffstat (limited to 'tensorflow/contrib/gan')
5 files changed, 753 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 9866fccfba..9d0e6e1335 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -105,6 +105,7 @@ py_library( deps = [ ":gan_estimator", ":head", + ":stargan_estimator", "//tensorflow/python:util", ], ) @@ -534,6 +535,57 @@ py_test( ) py_library( + name = "stargan_estimator", + srcs = [ + "python/estimator/python/stargan_estimator.py", + "python/estimator/python/stargan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":namedtuples", + ":summaries", + ":train", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + ], +) + +py_test( + name = "stargan_estimator_test", + srcs = ["python/estimator/python/stargan_estimator_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":namedtuples", + ":stargan_estimator", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:estimator_py", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@six_archive//:six", + ], +) + +py_library( name = "sliced_wasserstein", srcs = [ "python/eval/python/sliced_wasserstein.py", diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index c9f7bc61b2..99d38011ba 100644 --- 
a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -26,15 +26,18 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator from tensorflow.contrib.gan.python.estimator.python import head +from tensorflow.contrib.gan.python.estimator.python import stargan_estimator from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * from tensorflow.contrib.gan.python.estimator.python.head import * +from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'gan_estimator', + 'stargan_estimator', 'head', -] + gan_estimator.__all__ + head.__all__ +] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py new file mode 100644 index 0000000000..341bdf9fbb --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
"""`tf.Learn` components for `StarGANEstimator`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.estimator.python.stargan_estimator_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

# Re-export exactly the impl module's public API and strip everything else.
__all__ = stargan_estimator_impl.__all__
remove_undocumented(__name__, __all__)
+# ============================================================================== +"""A TFGAN-backed StarGAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import enum + +from tensorflow.contrib.framework.python.ops import variables as variable_lib +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_inspect as inspect + +__all__ = ['StarGANEstimator', 'SummaryType'] + + +class SummaryType(enum.IntEnum): + NONE = 0 + VARIABLES = 1 + IMAGES = 2 + IMAGE_COMPARISON = 3 + + +_summary_type_map = { + SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, + SummaryType.IMAGES: tfgan_summaries.add_stargan_image_summaries, +} + + +class StarGANEstimator(estimator.Estimator): + """An estimator for Generative Adversarial Networks (GANs). + + This Estimator is backed by TFGAN. The network functions follow the TFGAN API + except for one exception: if either `generator_fn` or `discriminator_fn` have + an argument called `mode`, then the tf.Estimator mode is passed in for that + argument. This helps with operations like batch normalization, which have + different train and evaluation behavior. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... 
+ return logits + + # Create GAN estimator. + stargan_estimator = tfgan.estimator.StarGANEstimator( + model_dir, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + loss_fn=loss_fn, + generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5)) + + # Train estimator. + stargan_estimator.train(train_input_fn, steps) + + # Evaluate resulting estimator. + stargan_estimator.evaluate(eval_input_fn) + + # Generate samples from generator. + stargan_estimator = np.array([ + x for x in stargan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + model_dir=None, + generator_fn=None, + discriminator_fn=None, + loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + get_hooks_fn=None, + get_eval_metric_ops_fn=None, + add_summaries=None, + use_loss_summaries=True, + config=None): + """Initializes a StarGANEstimator instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `input_data`. Outputs + a Tensor in the range [-inf, inf]. See `TFGAN` for more details and + examples. + loss_fn: The loss function on the generator. Takes a `StarGANModel` + namedtuple and return a `GANLoss` namedtuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. 
This function will be + called when the default graph is the `StarGANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. + get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a + dict of metric results keyed by name. The output of this function is + passed into `tf.estimator.EstimatorSpec` during evaluation. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + config: `RunConfig` object to configure the runtime settings. + + Raises: + ValueError: If loss functions aren't callable. + ValueError: If `use_loss_summaries` isn't boolean or `None`. + ValueError: If `get_hooks_fn` isn't callable or `None`. 
+ """ + if not callable(loss_fn): + raise ValueError('loss_fn must be callable.') + if use_loss_summaries not in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + + def _model_fn(features, labels, mode): + """StarGANEstimator model function.""" + if mode not in [ + model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT + ]: + raise ValueError('Mode not recognized: %s' % mode) + + if mode == model_fn_lib.ModeKeys.PREDICT: + input_data = features[0] + input_data_domain_label = features[1] + else: + input_data = features # rename inputs for clarity + input_data_domain_label = labels # rename inputs for clarity + + # Make StarGANModel, which encapsulates the GAN model architectures. + gan_model = _get_gan_model(mode, generator_fn, discriminator_fn, + input_data, input_data_domain_label, + add_summaries) + + # Make the EstimatorSpec, which incorporates the StarGANModel, losses, + # eval, metrics, and optimizers (if required). 
+ return _get_estimator_spec(mode, gan_model, loss_fn, + get_eval_metric_ops_fn, generator_optimizer, + discriminator_optimizer, get_hooks_fn) + + super(StarGANEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +def _get_gan_model(mode, + generator_fn, + discriminator_fn, + input_data, + input_data_domain_label, + add_summaries, + generator_scope='Generator'): + """Makes the StarGANModel tuple.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + gan_model = _make_prediction_gan_model(input_data, input_data_domain_label, + generator_fn, generator_scope) + else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL + gan_model = _make_gan_model(generator_fn, discriminator_fn, input_data, + input_data_domain_label, generator_scope, + add_summaries, mode) + + return gan_model + + +def _get_estimator_spec(mode, + gan_model, + loss_fn, + get_eval_metric_ops_fn, + generator_optimizer, + discriminator_optimizer, + get_hooks_fn=None): + """Get the EstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + gan_loss = loss_fn(gan_model) + if mode == model_fn_lib.ModeKeys.EVAL: + estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss, + get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gopt = ( + generator_optimizer() + if callable(generator_optimizer) else generator_optimizer) + dopt = ( + discriminator_optimizer() + if callable(discriminator_optimizer) else discriminator_optimizer) + get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() + estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt, + dopt, get_hooks_fn) + + return estimator_spec + + +def _make_gan_model(generator_fn, discriminator_fn, input_data, + input_data_domain_label, generator_scope, add_summaries, + mode): + """Construct a `StarGANModel`, and optionally pass in `mode`.""" + # If 
def _make_prediction_gan_model(input_data, input_data_domain_label,
                               generator_fn, generator_scope):
  """Make a `StarGANModel` from just the generator.

  Only the generator sub-graph is built; all discriminator-related fields of
  the returned `StarGANModel` are `None`.
  """
  # Forward PREDICT mode to the generator when it accepts a `mode` argument.
  if 'mode' in inspect.getargspec(generator_fn).args:
    generator_fn = functools.partial(
        generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)

  with variable_scope.variable_scope(generator_scope) as gen_var_scope:
    # Normalize inputs (tensor, list, or dict) to tensors before calling the
    # generator.
    # pylint:disable=protected-access
    input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
    input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
        input_data_domain_label)
    # pylint:enable=protected-access
    generated_data = generator_fn(input_data, input_data_domain_label)
  generator_variables = variable_lib.get_trainable_variables(gen_var_scope)

  return tfgan_tuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=None,
      generated_data=generated_data,
      generated_data_domain_target=input_data_domain_label,
      reconstructed_data=None,
      discriminator_input_data_source_predication=None,
      discriminator_generated_data_source_predication=None,
      discriminator_input_data_domain_predication=None,
      discriminator_generated_data_domain_predication=None,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=None,
      discriminator_scope=None,
      discriminator_fn=None)
def _get_eval_estimator_spec(gan_model,
                             gan_loss,
                             get_eval_metric_ops_fn=None,
                             name=None):
  """Return an EstimatorSpec for the eval case."""
  gen_loss = gan_loss.generator_loss
  dis_loss = gan_loss.discriminator_loss
  # The reported scalar loss is the sum of both sub-network losses.
  scalar_loss = gen_loss + dis_loss
  with ops.name_scope(None, 'metrics', [gen_loss, dis_loss]):

    def _summary_key(head_name, val):
      if head_name:
        return '%s/%s' % (val, head_name)
      return val

    eval_metric_ops = {
        _summary_key(name, 'generator_loss'): metrics_lib.mean(gen_loss),
        _summary_key(name, 'discriminator_loss'): metrics_lib.mean(dis_loss),
    }
    if get_eval_metric_ops_fn is not None:
      custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
      if not isinstance(custom_eval_metric_ops, dict):
        raise TypeError('get_eval_metric_ops_fn must return a dict, '
                        'received: {}'.format(custom_eval_metric_ops))
      eval_metric_ops.update(custom_eval_metric_ops)
    return model_fn_lib.EstimatorSpec(
        mode=model_fn_lib.ModeKeys.EVAL,
        predictions=gan_model.generated_data,
        loss=scalar_loss,
        eval_metric_ops=eval_metric_ops)


def _get_train_estimator_spec(gan_model,
                              gan_loss,
                              generator_optimizer,
                              discriminator_optimizer,
                              get_hooks_fn,
                              train_op_fn=tfgan_train.gan_train_ops):
  """Return an EstimatorSpec for the train case."""
  scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
  train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
                          discriminator_optimizer)
  # The hooks drive the alternating generator/discriminator updates; the
  # train_op itself only increments the global step.
  return model_fn_lib.EstimatorSpec(
      loss=scalar_loss,
      mode=model_fn_lib.ModeKeys.TRAIN,
      train_op=train_ops.global_step_inc_op,
      training_hooks=get_hooks_fn(train_ops))
def stargan_prediction_input_fn_wrapper(fn):
  """StarGAN Estimator prediction input_fn wrapper.

  Since estimator will disregard the "label" variable pass to the model, we
  will use a wrapper to pack the (feature, label) tuple as feature passed to
  the model.

  Args:
    fn: input_fn for the prediction.

  Returns:
    A tuple ((feature, label), None) where the second element is the dummy
    label to be disregarded and the first element is the true input to the
    estimator.
  """

  def _wrapped_input_fn():
    # Pack the original (feature, label) output as the feature; the dummy
    # `None` label is ignored by the Estimator at predict time.
    packed_features = fn()
    return packed_features, None

  return _wrapped_input_fn
def dummy_generator_fn(input_data, input_data_domain_label, mode):
  """Trivial generator: scales the input by one trainable variable."""
  del input_data_domain_label, mode  # unused by the dummy

  scale = variable_scope.get_variable('dummy_g', initializer=0.5)
  return scale * input_data


def dummy_discriminator_fn(input_data, num_domains, mode):
  """Trivial discriminator returning (source logits, domain logits)."""
  del mode  # unused by the dummy

  flat = layers.flatten(input_data)
  # Source ("real vs fake") score: mean over the flattened features.
  source_logits = math_ops.reduce_mean(flat, axis=1)
  # Domain classification head: a single fully connected layer.
  domain_logits = layers.fully_connected(
      inputs=flat, num_outputs=num_domains, scope='debug')
  return source_logits, domain_logits
test_get_gan_model(self, mode): + with ops.Graph().as_default(): + input_data = array_ops.ones([6, 4, 4, 3]) + input_data_domain_label = array_ops.one_hot([0] * 6, 5) + gan_model = estimator._get_gan_model( + mode, + dummy_generator_fn, + dummy_discriminator_fn, + input_data, + input_data_domain_label, + add_summaries=False) + + self.assertEqual(input_data, gan_model.input_data) + self.assertIsNotNone(gan_model.generated_data) + self.assertIsNotNone(gan_model.generated_data_domain_target) + self.assertEqual(1, len(gan_model.generator_variables)) + self.assertIsNotNone(gan_model.generator_scope) + self.assertIsNotNone(gan_model.generator_fn) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertIsNone(gan_model.input_data_domain_label) + self.assertEqual(input_data_domain_label, + gan_model.generated_data_domain_target) + self.assertIsNone(gan_model.reconstructed_data) + self.assertIsNone(gan_model.discriminator_input_data_source_predication) + self.assertIsNone( + gan_model.discriminator_generated_data_source_predication) + self.assertIsNone(gan_model.discriminator_input_data_domain_predication) + self.assertIsNone( + gan_model.discriminator_generated_data_domain_predication) + self.assertIsNone(gan_model.discriminator_variables) + self.assertIsNone(gan_model.discriminator_scope) + self.assertIsNone(gan_model.discriminator_fn) + else: + self.assertEqual(input_data_domain_label, + gan_model.input_data_domain_label) + self.assertIsNotNone(gan_model.reconstructed_data.shape) + self.assertIsNotNone( + gan_model.discriminator_input_data_source_predication) + self.assertIsNotNone( + gan_model.discriminator_generated_data_source_predication) + self.assertIsNotNone( + gan_model.discriminator_input_data_domain_predication) + self.assertIsNotNone( + gan_model.discriminator_generated_data_domain_predication) + self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.discriminator_scope) + 
def get_dummy_gan_model():
  """Similar to get_gan_model(): a minimal `StarGANModel` for spec tests."""
  # TODO(joelshor): Find a better way of creating a variable scope.
  with variable_scope.variable_scope('generator') as gen_scope:
    gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
  with variable_scope.variable_scope('discriminator') as dis_scope:
    dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)

  # The predictions depend on the dummy variables so gradients flow to both
  # sub-networks when a loss is applied.
  real_source_pred = array_ops.ones([1]) * dis_var
  fake_source_pred = array_ops.ones([1]) * gen_var * dis_var
  real_domain_pred = array_ops.ones([1, 2]) * dis_var
  fake_domain_pred = array_ops.ones([1, 2]) * gen_var * dis_var

  return tfgan_tuples.StarGANModel(
      input_data=array_ops.ones([1, 2, 2, 3]),
      input_data_domain_label=array_ops.ones([1, 2]),
      generated_data=array_ops.ones([1, 2, 2, 3]),
      generated_data_domain_target=array_ops.ones([1, 2]),
      reconstructed_data=array_ops.ones([1, 2, 2, 3]),
      discriminator_input_data_source_predication=real_source_pred,
      discriminator_generated_data_source_predication=fake_source_pred,
      discriminator_input_data_domain_predication=real_domain_pred,
      discriminator_generated_data_domain_predication=fake_domain_pred,
      generator_variables=[gen_var],
      generator_scope=gen_scope,
      generator_fn=None,
      discriminator_variables=[dis_var],
      discriminator_scope=dis_scope,
      discriminator_fn=None)


def dummy_loss_fn(gan_model):
  """Deterministic loss touching both domain predictions and the data."""
  loss = math_ops.reduce_sum(
      gan_model.discriminator_input_data_domain_predication -
      gan_model.discriminator_generated_data_domain_predication)
  loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data)
  return tfgan_tuples.GANLoss(loss, loss)


def get_metrics(gan_model):
  """Custom eval metric: MSE between input and generated data."""
  mse = metrics_lib.mean_squared_error(gan_model.input_data,
                                       gan_model.generated_data)
  return {'mse_custom_metric': mse}
cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) + + @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_estimator_spec(self, mode): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + self._gan_model, + loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer) + + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metric_ops) + + +# TODO(joelshor): Add pandas test. 
class StarGANEstimatorIntegrationTest(test.TestCase):

  def setUp(self):
    self._model_dir = tempfile.mkdtemp()

  def tearDown(self):
    if self._model_dir:
      writer_cache.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(self,
                          train_input_fn,
                          eval_input_fn,
                          predict_input_fn,
                          prediction_size,
                          lr_decay=False):
    """Train, evaluate, and predict with a StarGANEstimator end to end."""

    def make_opt():
      gstep = training_util.get_or_create_global_step()
      lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
      return training.GradientDescentOptimizer(lr)

    # Pass the optimizer as a callable when decay is requested so it is built
    # inside the Estimator's graph (where the global step lives).
    gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
    dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
    est = estimator.StarGANEstimator(
        generator_fn=dummy_generator_fn,
        discriminator_fn=dummy_discriminator_fn,
        loss_fn=dummy_loss_fn,
        generator_optimizer=gopt,
        discriminator_optimizer=dopt,
        get_eval_metric_ops_fn=get_metrics,
        model_dir=self._model_dir)

    # TRAIN
    num_steps = 10
    est.train(train_input_fn, steps=num_steps)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
    self.assertIn('loss', six.iterkeys(scores))
    self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
                     scores['loss'])
    self.assertIn('mse_custom_metric', six.iterkeys(scores))

    # PREDICT
    predictions = np.array([x for x in est.predict(predict_input_fn)])

    self.assertAllEqual(prediction_size, predictions.shape)

  @staticmethod
  def _numpy_input_fn_wrapper(numpy_input_fn, batch_size, label_size):
    """Wrapper to remove the dictionary in numpy_input_fn.

    NOTE:
      We create the domain_label here because the model expects a fully
      defined batch_size from the input.

    Args:
      numpy_input_fn: input_fn created from numpy_io
      batch_size: (int) number of items for each batch
      label_size: (int) number of domains

    Returns:
      a new input_fn
    """

    def new_input_fn():
      features = numpy_input_fn()
      return features['x'], array_ops.one_hot([0] * batch_size, label_size)

    return new_input_fn

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    batch_size = 5
    img_size = 8
    channel_size = 3
    label_size = 3
    image_data = np.zeros(
        [batch_size, img_size, img_size, channel_size], dtype=np.float32)
    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': image_data},
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': image_data}, batch_size=batch_size, shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': image_data}, shuffle=False)

    train_input_fn = self._numpy_input_fn_wrapper(train_input_fn, batch_size,
                                                  label_size)
    eval_input_fn = self._numpy_input_fn_wrapper(eval_input_fn, batch_size,
                                                 label_size)
    predict_input_fn = self._numpy_input_fn_wrapper(predict_input_fn,
                                                    batch_size, label_size)

    predict_input_fn = estimator.stargan_prediction_input_fn_wrapper(
        predict_input_fn)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        prediction_size=[batch_size, img_size, img_size, channel_size])


if __name__ == '__main__':
  test.main()