aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-27 16:12:26 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-27 16:15:23 -0800
commitc17459a0acb5044fa415d11221a45bea619aa349 (patch)
treee3f57928baec6b67ada15c2552001bb003cdec40 /tensorflow/contrib/gan
parent2be93d0d543591ebee31bcddfa4b9c6c53e5c793 (diff)
[tfgan] Add option to pass MODE to generator_fn, for the purpose of things like prediction.
PiperOrigin-RevId: 177086828
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py33
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py4
2 files changed, 29 insertions, 8 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index 0824ecf616..058dc1d1f8 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
import enum
from tensorflow.contrib.framework.python.ops import variables as variable_lib
@@ -29,6 +30,7 @@ from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import tf_inspect as inspect
__all__ = [
@@ -116,7 +118,10 @@ class GANEstimator(estimator.Estimator):
to continue training a previously saved model.
generator_fn: A python function that takes a Tensor, Tensor list, or
Tensor dictionary as inputs and returns the outputs of the GAN
- generator. See `TFGAN` for more details and examples.
+ generator. See `TFGAN` for more details and examples. Additionally, if
+ it has an argument called `mode`, the Estimator's `mode` will be passed
+ in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
+ normalization.
discriminator_fn: A python function that takes the output of
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
@@ -225,9 +230,12 @@ def _gan_model_fn(
labels=None)
-def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for training."""
+def _make_gan_model(generator_fn, discriminator_fn, real_data,
+ generator_inputs, generator_scope, add_summaries, mode):
+ """Make a `GANModel`, and optionally pass in `mode`."""
+ # If `generator_fn` has an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(generator_fn, mode=mode)
gan_model = tfgan_train.gan_model(
generator_fn,
discriminator_fn,
@@ -245,15 +253,28 @@ def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
return gan_model
+def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
+ generator_inputs, generator_scope, add_summaries):
+ """Make a `GANModel` for training."""
+ return _make_gan_model(generator_fn, discriminator_fn, real_data,
+ generator_inputs, generator_scope, add_summaries,
+ model_fn_lib.ModeKeys.TRAIN)
+
+
def _make_eval_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries):
"""Make a `GANModel` for evaluation."""
- return _make_train_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries)
+ return _make_gan_model(generator_fn, discriminator_fn, real_data,
+ generator_inputs, generator_scope, add_summaries,
+ model_fn_lib.ModeKeys.EVAL)
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
"""Make a `GANModel` from just the generator."""
+ # If `generator_fn` has an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(generator_fn,
+ mode=model_fn_lib.ModeKeys.PREDICT)
with variable_scope.variable_scope(generator_scope) as gen_scope:
generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access
generated_data = generator_fn(generator_inputs)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 1bfdce9ee9..e752f0bccc 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -48,7 +48,8 @@ from tensorflow.python.training import training
from tensorflow.python.training import training_util
-def generator_fn(noise_dict):
+def generator_fn(noise_dict, mode):
+ del mode
noise = noise_dict['x']
return layers.fully_connected(noise, noise.shape[1].value)
@@ -90,7 +91,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data,
generator_var_names,
set([x.name for x in gan_model.generator_variables]))
testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name)
- testcase.assertEqual(generator_fn, gan_model.generator_fn)
testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data)
# TODO(joelshor): Add check on `discriminator_real_outputs`.
# TODO(joelshor): Add check on `discriminator_gen_outputs`.