From c17459a0acb5044fa415d11221a45bea619aa349 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 27 Nov 2017 16:12:26 -0800 Subject: [tfgan] Add option to pass MODE to generator_fn, for the purpose of things like prediction. PiperOrigin-RevId: 177086828 --- .../python/estimator/python/gan_estimator_impl.py | 33 ++++++++++++++++++---- .../python/estimator/python/gan_estimator_test.py | 4 +-- 2 files changed, 29 insertions(+), 8 deletions(-) (limited to 'tensorflow/contrib/gan') diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 0824ecf616..058dc1d1f8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib @@ -29,6 +30,7 @@ from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_inspect as inspect __all__ = [ @@ -116,7 +118,10 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details @@ -225,9 +230,12 @@ def _gan_model_fn( labels=None) -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" +def _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, mode): + """Make a `GANModel`, and optionally pass in `mode`.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, @@ -245,15 +253,28 @@ def _make_train_gan_model(generator_fn, discriminator_fn, real_data, return gan_model +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.TRAIN) + + def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries): """Make a `GANModel` for evaluation.""" - return _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries) + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.EVAL) def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, + mode=model_fn_lib.ModeKeys.PREDICT) with variable_scope.variable_scope(generator_scope) as gen_scope: generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access generated_data = generator_fn(generator_inputs) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 1bfdce9ee9..e752f0bccc 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -48,7 +48,8 @@ from tensorflow.python.training import training from tensorflow.python.training import training_util -def generator_fn(noise_dict): +def generator_fn(noise_dict, mode): + del mode noise = noise_dict['x'] return layers.fully_connected(noise, noise.shape[1].value) @@ -90,7 +91,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, generator_var_names, set([x.name for x in gan_model.generator_variables])) testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(generator_fn, gan_model.generator_fn) testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) # TODO(joelshor): Add check on `discriminator_real_outputs`. # TODO(joelshor): Add check on `discriminator_gen_outputs`. -- cgit v1.2.3