aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-27 02:01:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 02:05:57 -0700
commit86abbaa083beaca05ee32675ac7bfafb58a4557d (patch)
tree022e6d08f295c1c6378e4185af033d1261f4f7c7 /tensorflow/contrib/gan
parent632e3d66334ac3718a0fd41524c7dfc499363cab (diff)
[TFGAN] StarGAN Estimator Implementation
PiperOrigin-RevId: 210334354
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD52
-rw-r--r--tensorflow/contrib/gan/python/estimator/__init__.py5
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py28
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py363
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py306
5 files changed, 753 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 9866fccfba..9d0e6e1335 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -105,6 +105,7 @@ py_library(
deps = [
":gan_estimator",
":head",
+ ":stargan_estimator",
"//tensorflow/python:util",
],
)
@@ -534,6 +535,57 @@ py_test(
)
py_library(
+ name = "stargan_estimator",
+ srcs = [
+ "python/estimator/python/stargan_estimator.py",
+ "python/estimator/python/stargan_estimator_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":namedtuples",
+ ":summaries",
+ ":train",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "stargan_estimator_test",
+ srcs = ["python/estimator/python/stargan_estimator_test.py"],
+ shard_count = 1,
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":namedtuples",
+ ":stargan_estimator",
+ ":tuple_losses",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/learn",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "sliced_wasserstein",
srcs = [
"python/eval/python/sliced_wasserstein.py",
diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py
index c9f7bc61b2..99d38011ba 100644
--- a/tensorflow/contrib/gan/python/estimator/__init__.py
+++ b/tensorflow/contrib/gan/python/estimator/__init__.py
@@ -26,15 +26,18 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
from tensorflow.contrib.gan.python.estimator.python import head
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
from tensorflow.contrib.gan.python.estimator.python.head import *
+from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'gan_estimator',
+ 'stargan_estimator',
'head',
-] + gan_estimator.__all__ + head.__all__
+] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py
new file mode 100644
index 0000000000..341bdf9fbb
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""`tf.Learn` components for `GANEstimator`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.estimator.python.stargan_estimator_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+__all__ = stargan_estimator_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py
new file mode 100644
index 0000000000..f60e16bc04
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py
@@ -0,0 +1,363 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A TFGAN-backed StarGAN Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import enum
+
+from tensorflow.contrib.framework.python.ops import variables as variable_lib
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
+from tensorflow.contrib.gan.python import train as tfgan_train
+from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import tf_inspect as inspect
+
+__all__ = ['StarGANEstimator', 'SummaryType']
+
+
+class SummaryType(enum.IntEnum):
+ NONE = 0
+ VARIABLES = 1
+ IMAGES = 2
+ IMAGE_COMPARISON = 3
+
+
+_summary_type_map = {
+ SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries,
+ SummaryType.IMAGES: tfgan_summaries.add_stargan_image_summaries,
+}
+
+
+class StarGANEstimator(estimator.Estimator):
+ """An estimator for Generative Adversarial Networks (GANs).
+
+ This Estimator is backed by TFGAN. The network functions follow the TFGAN API
+ except for one exception: if either `generator_fn` or `discriminator_fn` have
+ an argument called `mode`, then the tf.Estimator mode is passed in for that
+ argument. This helps with operations like batch normalization, which have
+ different train and evaluation behavior.
+
+ Example:
+
+ ```python
+ import tensorflow as tf
+ tfgan = tf.contrib.gan
+
+ # See TFGAN's `train.py` for a description of the generator and
+ # discriminator API.
+ def generator_fn(generator_inputs):
+ ...
+ return generated_data
+
+ def discriminator_fn(data, conditioning):
+ ...
+ return logits
+
+ # Create GAN estimator.
+ stargan_estimator = tfgan.estimator.StarGANEstimator(
+ model_dir,
+ generator_fn=generator_fn,
+ discriminator_fn=discriminator_fn,
+ loss_fn=loss_fn,
+ generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
+ discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))
+
+ # Train estimator.
+ stargan_estimator.train(train_input_fn, steps)
+
+ # Evaluate resulting estimator.
+ stargan_estimator.evaluate(eval_input_fn)
+
+ # Generate samples from generator.
+ stargan_estimator = np.array([
+ x for x in stargan_estimator.predict(predict_input_fn)])
+ ```
+ """
+
+ def __init__(self,
+ model_dir=None,
+ generator_fn=None,
+ discriminator_fn=None,
+ loss_fn=None,
+ generator_optimizer=None,
+ discriminator_optimizer=None,
+ get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
+ add_summaries=None,
+ use_loss_summaries=True,
+ config=None):
+ """Initializes a StarGANEstimator instance.
+
+ Args:
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ generator_fn: A python function that takes a Tensor, Tensor list, or
+ Tensor dictionary as inputs and returns the outputs of the GAN
+ generator. See `TFGAN` for more details and examples. Additionally, if
+ it has an argument called `mode`, the Estimator's `mode` will be passed
+ in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
+ normalization.
+ discriminator_fn: A python function that takes the output of
+ `generator_fn` or real data in the GAN setup, and `input_data`. Outputs
+ a Tensor in the range [-inf, inf]. See `TFGAN` for more details and
+ examples.
+ loss_fn: The loss function on the generator. Takes a `StarGANModel`
+ namedtuple and return a `GANLoss` namedtuple.
+ generator_optimizer: The optimizer for generator updates, or a function
+ that takes no arguments and returns an optimizer. This function will be
+ called when the default graph is the `StarGANEstimator`'s graph, so
+ utilities like `tf.contrib.framework.get_or_create_global_step` will
+ work.
+ discriminator_optimizer: Same as `generator_optimizer`, but for the
+ discriminator updates.
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks. These hooks are run on the generator and discriminator
+ train ops, and can be used to implement the GAN training scheme.
+ Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
+ add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
+ use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
+ If `None`, uses defaults.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: If loss functions aren't callable.
+ ValueError: If `use_loss_summaries` isn't boolean or `None`.
+ ValueError: If `get_hooks_fn` isn't callable or `None`.
+ """
+ if not callable(loss_fn):
+ raise ValueError('loss_fn must be callable.')
+ if use_loss_summaries not in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
+
+ def _model_fn(features, labels, mode):
+ """StarGANEstimator model function."""
+ if mode not in [
+ model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT
+ ]:
+ raise ValueError('Mode not recognized: %s' % mode)
+
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ input_data = features[0]
+ input_data_domain_label = features[1]
+ else:
+ input_data = features # rename inputs for clarity
+ input_data_domain_label = labels # rename inputs for clarity
+
+ # Make StarGANModel, which encapsulates the GAN model architectures.
+ gan_model = _get_gan_model(mode, generator_fn, discriminator_fn,
+ input_data, input_data_domain_label,
+ add_summaries)
+
+ # Make the EstimatorSpec, which incorporates the StarGANModel, losses,
+ # eval, metrics, and optimizers (if required).
+ return _get_estimator_spec(mode, gan_model, loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer,
+ discriminator_optimizer, get_hooks_fn)
+
+ super(StarGANEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+def _get_gan_model(mode,
+ generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ add_summaries,
+ generator_scope='Generator'):
+ """Makes the StarGANModel tuple."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ gan_model = _make_prediction_gan_model(input_data, input_data_domain_label,
+ generator_fn, generator_scope)
+ else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL
+ gan_model = _make_gan_model(generator_fn, discriminator_fn, input_data,
+ input_data_domain_label, generator_scope,
+ add_summaries, mode)
+
+ return gan_model
+
+
+def _get_estimator_spec(mode,
+ gan_model,
+ loss_fn,
+ get_eval_metric_ops_fn,
+ generator_optimizer,
+ discriminator_optimizer,
+ get_hooks_fn=None):
+ """Get the EstimatorSpec for the current mode."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=mode, predictions=gan_model.generated_data)
+ else:
+ gan_loss = loss_fn(gan_model)
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss,
+ get_eval_metric_ops_fn)
+ else: # model_fn_lib.ModeKeys.TRAIN:
+ gopt = (
+ generator_optimizer()
+ if callable(generator_optimizer) else generator_optimizer)
+ dopt = (
+ discriminator_optimizer()
+ if callable(discriminator_optimizer) else discriminator_optimizer)
+ get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
+ estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt,
+ dopt, get_hooks_fn)
+
+ return estimator_spec
+
+
+def _make_gan_model(generator_fn, discriminator_fn, input_data,
+ input_data_domain_label, generator_scope, add_summaries,
+ mode):
+ """Construct a `StarGANModel`, and optionally pass in `mode`."""
+ # If network functions have an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(generator_fn, mode=mode)
+ if 'mode' in inspect.getargspec(discriminator_fn).args:
+ discriminator_fn = functools.partial(discriminator_fn, mode=mode)
+ gan_model = tfgan_train.stargan_model(
+ generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ generator_scope=generator_scope)
+ if add_summaries:
+ if not isinstance(add_summaries, (tuple, list)):
+ add_summaries = [add_summaries]
+ with ops.name_scope(None):
+ for summary_type in add_summaries:
+ _summary_type_map[summary_type](gan_model)
+
+ return gan_model
+
+
+def _make_prediction_gan_model(input_data, input_data_domain_label,
+ generator_fn, generator_scope):
+ """Make a `StarGANModel` from just the generator."""
+ # If `generator_fn` has an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(
+ generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
+ with variable_scope.variable_scope(generator_scope) as gen_scope:
+ # pylint:disable=protected-access
+ input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
+ input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
+ input_data_domain_label)
+ # pylint:enable=protected-access
+ generated_data = generator_fn(input_data, input_data_domain_label)
+ generator_variables = variable_lib.get_trainable_variables(gen_scope)
+
+ return tfgan_tuples.StarGANModel(
+ input_data=input_data,
+ input_data_domain_label=None,
+ generated_data=generated_data,
+ generated_data_domain_target=input_data_domain_label,
+ reconstructed_data=None,
+ discriminator_input_data_source_predication=None,
+ discriminator_generated_data_source_predication=None,
+ discriminator_input_data_domain_predication=None,
+ discriminator_generated_data_domain_predication=None,
+ generator_variables=generator_variables,
+ generator_scope=generator_scope,
+ generator_fn=generator_fn,
+ discriminator_variables=None,
+ discriminator_scope=None,
+ discriminator_fn=None)
+
+
+def _get_eval_estimator_spec(gan_model,
+ gan_loss,
+ get_eval_metric_ops_fn=None,
+ name=None):
+ """Return an EstimatorSpec for the eval case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss, gan_loss.discriminator_loss]):
+
+ def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+
+ eval_metric_ops = {
+ _summary_key(name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL,
+ predictions=gan_model.generated_data,
+ loss=scalar_loss,
+ eval_metric_ops=eval_metric_ops)
+
+
+def _get_train_estimator_spec(gan_model,
+ gan_loss,
+ generator_optimizer,
+ discriminator_optimizer,
+ get_hooks_fn,
+ train_op_fn=tfgan_train.gan_train_ops):
+ """Return an EstimatorSpec for the train case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
+ discriminator_optimizer)
+ training_hooks = get_hooks_fn(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ loss=scalar_loss,
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ train_op=train_ops.global_step_inc_op,
+ training_hooks=training_hooks)
+
+
+def stargan_prediction_input_fn_wrapper(fn):
+ """StarGAN Estimator prediction input_fn wrapper.
+
+ Since estimator will disregard the "label" variable pass to the model, we will
+ use a wrapper to pack the (feature, label) tuple as feature passed to the
+ model.
+
+ Args:
+ fn: input_fn for the prediction.
+
+ Returns:
+ A tuple ((feature, label), None) where the second element is the dummy label
+ to be disregarded and the first element is the true input to the estimator.
+ """
+
+ def new_fn():
+ return fn(), None
+
+ return new_fn
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py
new file mode 100644
index 0000000000..2ec7938c7c
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py
@@ -0,0 +1,306 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TFGAN's stargan_estimator.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+from absl.testing import parameterized
+import numpy as np
+import six
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl as estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import training
+from tensorflow.python.training import training_util
+
+
+def dummy_generator_fn(input_data, input_data_domain_label, mode):
+ del input_data_domain_label, mode
+
+ return variable_scope.get_variable('dummy_g', initializer=0.5) * input_data
+
+
+def dummy_discriminator_fn(input_data, num_domains, mode):
+ del mode
+
+ hidden = layers.flatten(input_data)
+ output_src = math_ops.reduce_mean(hidden, axis=1)
+ output_cls = layers.fully_connected(
+ inputs=hidden, num_outputs=num_domains, scope='debug')
+
+ return output_src, output_cls
+
+
+class StarGetGANModelTest(test.TestCase, parameterized.TestCase):
+ """Tests that `StarGetGANModel` produces the correct model."""
+
+ @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_gan_model(self, mode):
+ with ops.Graph().as_default():
+ input_data = array_ops.ones([6, 4, 4, 3])
+ input_data_domain_label = array_ops.one_hot([0] * 6, 5)
+ gan_model = estimator._get_gan_model(
+ mode,
+ dummy_generator_fn,
+ dummy_discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ add_summaries=False)
+
+ self.assertEqual(input_data, gan_model.input_data)
+ self.assertIsNotNone(gan_model.generated_data)
+ self.assertIsNotNone(gan_model.generated_data_domain_target)
+ self.assertEqual(1, len(gan_model.generator_variables))
+ self.assertIsNotNone(gan_model.generator_scope)
+ self.assertIsNotNone(gan_model.generator_fn)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertIsNone(gan_model.input_data_domain_label)
+ self.assertEqual(input_data_domain_label,
+ gan_model.generated_data_domain_target)
+ self.assertIsNone(gan_model.reconstructed_data)
+ self.assertIsNone(gan_model.discriminator_input_data_source_predication)
+ self.assertIsNone(
+ gan_model.discriminator_generated_data_source_predication)
+ self.assertIsNone(gan_model.discriminator_input_data_domain_predication)
+ self.assertIsNone(
+ gan_model.discriminator_generated_data_domain_predication)
+ self.assertIsNone(gan_model.discriminator_variables)
+ self.assertIsNone(gan_model.discriminator_scope)
+ self.assertIsNone(gan_model.discriminator_fn)
+ else:
+ self.assertEqual(input_data_domain_label,
+ gan_model.input_data_domain_label)
+ self.assertIsNotNone(gan_model.reconstructed_data.shape)
+ self.assertIsNotNone(
+ gan_model.discriminator_input_data_source_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_generated_data_source_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_input_data_domain_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_generated_data_domain_predication)
+ self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.discriminator_scope)
+ self.assertIsNotNone(gan_model.discriminator_fn)
+
+
+def get_dummy_gan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ return tfgan_tuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=array_ops.ones([1, 2, 2, 3]),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]) * dis_var,
+ discriminator_generated_data_source_predication=array_ops.ones(
+ [1]) * gen_var * dis_var,
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2
+ ]) * dis_var,
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]) *
+ gen_var * dis_var,
+ generator_variables=[gen_var],
+ generator_scope=gen_scope,
+ generator_fn=None,
+ discriminator_variables=[dis_var],
+ discriminator_scope=dis_scope,
+ discriminator_fn=None)
+
+
+def dummy_loss_fn(gan_model):
+ loss = math_ops.reduce_sum(
+ gan_model.discriminator_input_data_domain_predication -
+ gan_model.discriminator_generated_data_domain_predication)
+ loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data)
+ return tfgan_tuples.GANLoss(loss, loss)
+
+
+def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric':
+ metrics_lib.mean_squared_error(gan_model.input_data,
+ gan_model.generated_data)
+ }
+
+
+class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
+ """Tests that the EstimatorSpec is constructed appropriately."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
+ cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
+
+ @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_estimator_spec(self, mode):
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ spec = estimator._get_estimator_spec(
+ mode,
+ self._gan_model,
+ loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=self._generator_optimizer,
+ discriminator_optimizer=self._discriminator_optimizer)
+
+ self.assertEqual(mode, spec.mode)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNotNone(spec.training_hooks)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.eval_metric_ops)
+
+
+# TODO(joelshor): Add pandas test.
+class StarGANEstimatorIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self,
+ train_input_fn,
+ eval_input_fn,
+ predict_input_fn,
+ prediction_size,
+ lr_decay=False):
+
+ def make_opt():
+ gstep = training_util.get_or_create_global_step()
+ lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
+ return training.GradientDescentOptimizer(lr)
+
+ gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
+ dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
+ est = estimator.StarGANEstimator(
+ generator_fn=dummy_generator_fn,
+ discriminator_fn=dummy_discriminator_fn,
+ loss_fn=dummy_loss_fn,
+ generator_optimizer=gopt,
+ discriminator_optimizer=dopt,
+ get_eval_metric_ops_fn=get_metrics,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+ self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
+ scores['loss'])
+ self.assertIn('mse_custom_metric', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([x for x in est.predict(predict_input_fn)])
+
+ self.assertAllEqual(prediction_size, predictions.shape)
+
+ @staticmethod
+ def _numpy_input_fn_wrapper(numpy_input_fn, batch_size, label_size):
+ """Wrapper to remove the dictionary in numpy_input_fn.
+
+ NOTE:
+ We create the domain_label here because the model expect a fully define
+ batch_size from the input.
+
+ Args:
+ numpy_input_fn: input_fn created from numpy_io
+ batch_size: (int) number of items for each batch
+ label_size: (int) number of domains
+
+ Returns:
+ a new input_fn
+ """
+
+ def new_input_fn():
+ features = numpy_input_fn()
+ return features['x'], array_ops.one_hot([0] * batch_size, label_size)
+
+ return new_input_fn
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ batch_size = 5
+ img_size = 8
+ channel_size = 3
+ label_size = 3
+ image_data = np.zeros(
+ [batch_size, img_size, img_size, channel_size], dtype=np.float32)
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data},
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data}, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data}, shuffle=False)
+
+ train_input_fn = self._numpy_input_fn_wrapper(train_input_fn, batch_size,
+ label_size)
+ eval_input_fn = self._numpy_input_fn_wrapper(eval_input_fn, batch_size,
+ label_size)
+ predict_input_fn = self._numpy_input_fn_wrapper(predict_input_fn,
+ batch_size, label_size)
+
+ predict_input_fn = estimator.stargan_prediction_input_fn_wrapper(
+ predict_input_fn)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ prediction_size=[batch_size, img_size, img_size, channel_size])
+
+
+if __name__ == '__main__':
+ test.main()