diff options
author | 2018-06-13 03:02:11 -0700 | |
---|---|---|
committer | 2018-06-13 03:05:03 -0700 | |
commit | e6d00acfd8e4539291a087a6c3e0799253ba9d6f (patch) | |
tree | f8e78063d153a3a310e9e14f350d1d501acbe163 | |
parent | 97d5bfed6c8a42ea6d8779309e9eb64a1e488d07 (diff) |
Remove GANHead from GANEstimator.
PiperOrigin-RevId: 200362771
7 files changed, 218 insertions, 603 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index b305f37791..d38d770bc5 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -45,6 +45,7 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:training", + "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", "//tensorflow/python/ops/distributions", "//tensorflow/python/ops/losses", @@ -59,6 +60,7 @@ py_test( deps = [ ":features", ":namedtuples", + ":random_tensor_pool", ":train", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/slim:learning", @@ -70,6 +72,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:training", + "//tensorflow/python:training_util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", @@ -96,7 +99,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gan_estimator", - ":head", "//tensorflow/python:util", ], ) @@ -188,6 +190,7 @@ py_test( srcs = ["python/losses/python/tuple_losses_test.py"], srcs_version = "PY2AND3", deps = [ + ":namedtuples", ":tuple_losses", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -344,9 +347,11 @@ py_library( "//tensorflow/python:image_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:util", + "@six_archive//:six", ], ) @@ -429,40 +434,6 @@ py_test( ) py_library( - name = "head", - srcs = [ - "python/estimator/python/head.py", - "python/estimator/python/head_impl.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":namedtuples", - ":train", - "//tensorflow/python:framework_ops", - "//tensorflow/python:util", - "//tensorflow/python/estimator:head", - "//tensorflow/python/estimator:model_fn", - ], -) - -py_test( - name = "head_test", - srcs = ["python/estimator/python/head_test.py"], - shard_count = 1, - srcs_version = "PY2AND3", - deps = [ - ":head", - ":namedtuples", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:math_ops", - "//tensorflow/python:training", - "//tensorflow/python:variable_scope", - "//tensorflow/python/estimator:model_fn", - ], -) - -py_library( name = "gan_estimator", srcs = [ "python/estimator/python/gan_estimator.py", @@ -470,12 +441,12 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":head", ":namedtuples", ":summaries", ":train", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:framework_ops", + "//tensorflow/python:metrics", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator", @@ -498,16 +469,19 @@ py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:parsing_ops", "//tensorflow/python:summary", "//tensorflow/python:training", - "//tensorflow/python/estimator:head", + "//tensorflow/python:training_util", + "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:numpy_io", "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index c9f7bc61b2..04dddb4b55 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -25,16 +25,13 @@ from __future__ import print_function # Collapse `estimator` into a single namespace. # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.estimator.python import gan_estimator -from tensorflow.contrib.gan.python.estimator.python import head from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * -from tensorflow.contrib.gan.python.estimator.python.head import * # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'gan_estimator', - 'head', -] + gan_estimator.__all__ + head.__all__ +] + gan_estimator.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 4092b32004..7104c8aa61 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -24,11 +24,11 @@ import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.contrib.gan.python.estimator.python import head as head_lib from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variable_scope from tensorflow.python.util import tf_inspect as inspect @@ -158,90 +158,77 @@ class GANEstimator(estimator.Estimator): # TODO(joelshor): Explicitly validate inputs. def _model_fn(features, labels, mode): - gopt = (generator_optimizer() if callable(generator_optimizer) else - generator_optimizer) - dopt = (discriminator_optimizer() if callable(discriminator_optimizer) - else discriminator_optimizer) - gan_head = head_lib.gan_head( - generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries, get_hooks_fn=get_hooks_fn, - get_eval_metric_ops_fn=get_eval_metric_ops_fn) - return _gan_model_fn( - features, labels, mode, generator_fn, discriminator_fn, gan_head, + """GANEstimator model function.""" + if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL, + model_fn_lib.ModeKeys.PREDICT]: + raise ValueError('Mode not recognized: %s' % mode) + real_data = labels # rename inputs for clarity + generator_inputs = features # rename inputs for clarity + + # Make GANModel, which encapsulates the GAN model architectures. + gan_model = _get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, add_summaries) + # Make the EstimatorSpec, which incorporates the GANModel, losses, eval + # metrics, and optimizers (if required). + return _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + get_hooks_fn) + super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) -def _gan_model_fn( - features, - labels, - mode, - generator_fn, - discriminator_fn, - head, - add_summaries=None, - generator_scope_name='Generator'): - """The `model_fn` for the GAN estimator. - - We make the following convention: - features -> TFGAN's `generator_inputs` - labels -> TFGAN's `real_data` - - Args: - features: A dictionary to feed to generator. In the unconditional case, - this might be just `noise`. In the conditional GAN case, this - might be the generator's conditioning. The `generator_fn` determines - what the required keys are. - labels: Real data. Can be any structure, as long as `discriminator_fn` - can accept it for the first argument. - mode: Defines whether this is training, evaluation or prediction. - See `ModeKeys`. - generator_fn: A python lambda that takes `generator_inputs` as inputs and - returns the outputs of the GAN generator. - discriminator_fn: A python lambda that takes `real_data`/`generated data` - and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. - head: A `Head` instance suitable for GANs. - add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. - generator_scope_name: The name of the generator scope. We need this to be - the same for GANModels produced by TFGAN's `train.gan_model` and the - manually constructed ones for predictions. - - Returns: - `ModelFnOps` - - Raises: - ValueError: If `labels` isn't `None` during prediction. - """ - real_data = labels - generator_inputs = features - - if mode == model_fn_lib.ModeKeys.TRAIN: - gan_model = _make_train_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope_name, add_summaries) - elif mode == model_fn_lib.ModeKeys.EVAL: - gan_model = _make_eval_gan_model( - generator_fn, discriminator_fn, real_data, generator_inputs, - generator_scope_name, add_summaries) - else: +def _get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries, generator_scope='Generator'): + """Makes the GANModel tuple, which encapsulates the GAN model architecture.""" + if mode == model_fn_lib.ModeKeys.PREDICT: if real_data is not None: raise ValueError('`labels` must be `None` when mode is `predict`. ' 'Instead, found %s' % real_data) gan_model = _make_prediction_gan_model( - generator_inputs, generator_fn, generator_scope_name) + generator_inputs, generator_fn, generator_scope) + else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL + gan_model = _make_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope, add_summaries, mode) + + return gan_model - return head.create_estimator_spec( - features=None, - mode=mode, - logits=gan_model, - labels=None) + +def _get_estimator_spec( + mode, gan_model, generator_loss_fn, discriminator_loss_fn, + get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, + get_hooks_fn=None): + """Get the EstimatorSpec for the current mode.""" + if mode == model_fn_lib.ModeKeys.PREDICT: + estimator_spec = model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + gan_loss = tfgan_tuples.GANLoss( + generator_loss=generator_loss_fn(gan_model), + discriminator_loss=discriminator_loss_fn(gan_model)) + if mode == model_fn_lib.ModeKeys.EVAL: + estimator_spec = _get_eval_estimator_spec( + gan_model, gan_loss, get_eval_metric_ops_fn) + else: # model_fn_lib.ModeKeys.TRAIN: + gopt = (generator_optimizer() if callable(generator_optimizer) else + generator_optimizer) + dopt = (discriminator_optimizer() if callable(discriminator_optimizer) + else discriminator_optimizer) + get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks() + estimator_spec = _get_train_estimator_spec( + gan_model, gan_loss, gopt, dopt, get_hooks_fn) + + return estimator_spec def _make_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries, mode): - """Make a `GANModel`, and optionally pass in `mode`.""" + """Construct a `GANModel`, and optionally pass in `mode`.""" # If network functions have an argument `mode`, pass mode to it. if 'mode' in inspect.getargspec(generator_fn).args: generator_fn = functools.partial(generator_fn, mode=mode) @@ -264,22 +251,6 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data, return gan_model -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" - return _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, - model_fn_lib.ModeKeys.TRAIN) - - -def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for evaluation.""" - return _make_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries, - model_fn_lib.ModeKeys.EVAL) - - def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" # If `generator_fn` has an argument `mode`, pass mode to it. @@ -303,3 +274,46 @@ def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): discriminator_variables=None, discriminator_scope=None, discriminator_fn=None) + + +def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None, + name=None): + """Return an EstimatorSpec for the eval case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + eval_metric_ops = { + _summary_key(name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } + if get_eval_metric_ops_fn is not None: + custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model) + if not isinstance(custom_eval_metric_ops, dict): + raise TypeError('get_eval_metric_ops_fn must return a dict, ' + 'received: {}'.format(custom_eval_metric_ops)) + eval_metric_ops.update(custom_eval_metric_ops) + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + eval_metric_ops=eval_metric_ops) + + +def _get_train_estimator_spec( + gan_model, gan_loss, generator_optimizer, discriminator_optimizer, + get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops): + """Return an EstimatorSpec for the train case.""" + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer, + discriminator_optimizer) + training_hooks = get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 955482599b..9ac9c6ca9c 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -21,30 +21,30 @@ from __future__ import print_function import shutil import tempfile +from absl.testing import parameterized import numpy as np import six from tensorflow.contrib import layers -from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses from tensorflow.contrib.learn.python.learn.learn_io import graph_io from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.inputs import numpy_io from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import input as input_lib from tensorflow.python.training import learning_rate_decay -from tensorflow.python.training import monitored_session from tensorflow.python.training import training from tensorflow.python.training import training_util @@ -60,120 +60,109 @@ def discriminator_fn(data, unused_conditioning, mode): return layers.fully_connected(data, 1) -def mock_head(testcase, expected_generator_inputs, expected_real_data, - generator_scope_name): - """Returns a mock head that validates logits values and variable names.""" - discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults - generator_var_names = set([ - '%s/fully_connected/weights:0' % generator_scope_name, - '%s/fully_connected/biases:0' % generator_scope_name]) - discriminator_var_names = set([ - '%s/fully_connected/weights:0' % discriminator_scope_name, - '%s/fully_connected/biases:0' % discriminator_scope_name]) - - def _create_estimator_spec(features, mode, logits, labels): - gan_model = logits # renaming for clarity - is_predict = mode == model_fn_lib.ModeKeys.PREDICT - testcase.assertIsNone(features) - testcase.assertIsNone(labels) - testcase.assertIsInstance(gan_model, namedtuples.GANModel) - - trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) - expected_var_names = (generator_var_names if is_predict else - generator_var_names | discriminator_var_names) - testcase.assertItemsEqual(expected_var_names, - [var.name for var in trainable_vars]) - - assertions = [] - def _or_none(x): - return None if is_predict else x - testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs) - # TODO(joelshor): Add check on `generated_data`. - testcase.assertItemsEqual( - generator_var_names, - set([x.name for x in gan_model.generator_variables])) - testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) - # TODO(joelshor): Add check on `discriminator_real_outputs`. - # TODO(joelshor): Add check on `discriminator_gen_outputs`. - if is_predict: - testcase.assertIsNone(gan_model.discriminator_scope) - else: - testcase.assertEqual(discriminator_scope_name, - gan_model.discriminator_scope.name) - - with ops.control_dependencies(assertions): - if mode == model_fn_lib.ModeKeys.TRAIN: - return model_fn_lib.EstimatorSpec( - mode=mode, loss=array_ops.zeros([]), - train_op=control_flow_ops.no_op(), training_hooks=[]) - elif mode == model_fn_lib.ModeKeys.EVAL: - return model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data, - loss=array_ops.zeros([])) - elif mode == model_fn_lib.ModeKeys.PREDICT: - return model_fn_lib.EstimatorSpec( - mode=mode, predictions=gan_model.generated_data) - else: - testcase.fail('Invalid mode: {}'.format(mode)) - - head = test.mock.NonCallableMagicMock(spec=head_lib._Head) - head.create_estimator_spec = test.mock.MagicMock( - wraps=_create_estimator_spec) - - return head - - -class GANModelFnTest(test.TestCase): - """Tests that _gan_model_fn passes expected logits to mock head.""" - - def setUp(self): - self._model_dir = tempfile.mkdtemp() - - def tearDown(self): - if self._model_dir: - writer_cache.FileWriterCache.clear() - shutil.rmtree(self._model_dir) +class GetGANModelTest(test.TestCase, parameterized.TestCase): + """Tests that `GetGANModel` produces the correct model.""" - def _test_logits_helper(self, mode): - """Tests that the expected logits are passed to mock head.""" + @parameterized.named_parameters( + ('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_gan_model(self, mode): with ops.Graph().as_default(): - training_util.get_or_create_global_step() - generator_inputs = {'x': array_ops.zeros([5, 4])} - real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else - array_ops.zeros([5, 4])) - generator_scope_name = 'generator' - head = mock_head(self, - expected_generator_inputs=generator_inputs, - expected_real_data=real_data, - generator_scope_name=generator_scope_name) - estimator_spec = estimator._gan_model_fn( - features=generator_inputs, - labels=real_data, - mode=mode, - generator_fn=generator_fn, - discriminator_fn=discriminator_fn, - generator_scope_name=generator_scope_name, - head=head) - with monitored_session.MonitoredTrainingSession( - checkpoint_dir=self._model_dir) as sess: - if mode == model_fn_lib.ModeKeys.TRAIN: - sess.run(estimator_spec.train_op) - elif mode == model_fn_lib.ModeKeys.EVAL: - sess.run(estimator_spec.loss) - elif mode == model_fn_lib.ModeKeys.PREDICT: - sess.run(estimator_spec.predictions) - else: - self.fail('Invalid mode: {}'.format(mode)) - - def test_logits_predict(self): - self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT) - - def test_logits_eval(self): - self._test_logits_helper(model_fn_lib.ModeKeys.EVAL) - - def test_logits_train(self): - self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN) + generator_inputs = {'x': array_ops.ones([3, 4])} + real_data = (array_ops.zeros([3, 4]) if + mode != model_fn_lib.ModeKeys.PREDICT else None) + gan_model = estimator._get_gan_model( + mode, generator_fn, discriminator_fn, real_data, generator_inputs, + add_summaries=False) + + self.assertEqual(generator_inputs, gan_model.generator_inputs) + self.assertIsNotNone(gan_model.generated_data) + self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.generator_fn) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertIsNone(gan_model.real_data) + self.assertIsNone(gan_model.discriminator_real_outputs) + self.assertIsNone(gan_model.discriminator_gen_outputs) + self.assertIsNone(gan_model.discriminator_variables) + self.assertIsNone(gan_model.discriminator_scope) + self.assertIsNone(gan_model.discriminator_fn) + else: + self.assertIsNotNone(gan_model.real_data) + self.assertIsNotNone(gan_model.discriminator_real_outputs) + self.assertIsNotNone(gan_model.discriminator_gen_outputs) + self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer + self.assertIsNotNone(gan_model.discriminator_scope) + self.assertIsNotNone(gan_model.discriminator_fn) + + +def get_dummy_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=array_ops.zeros([3, 4]), + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +def dummy_loss_fn(gan_model): + return math_ops.reduce_sum(gan_model.discriminator_real_outputs - + gan_model.discriminator_gen_outputs) + + +def get_metrics(gan_model): + return { + 'mse_custom_metric': metrics_lib.mean_squared_error( + gan_model.real_data, gan_model.generated_data) + } + + +class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase): + """Tests that the EstimatorSpec is constructed appropriately.""" + + @classmethod + def setUpClass(cls): + cls._generator_optimizer = training.GradientDescentOptimizer(1.0) + cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0) + + @parameterized.named_parameters( + ('train', model_fn_lib.ModeKeys.TRAIN), + ('eval', model_fn_lib.ModeKeys.EVAL), + ('predict', model_fn_lib.ModeKeys.PREDICT)) + def test_get_estimator_spec(self, mode): + with ops.Graph().as_default(): + self._gan_model = get_dummy_gan_model() + spec = estimator._get_estimator_spec( + mode, + self._gan_model, + generator_loss_fn=dummy_loss_fn, + discriminator_loss_fn=dummy_loss_fn, + get_eval_metric_ops_fn=get_metrics, + generator_optimizer=self._generator_optimizer, + discriminator_optimizer=self._discriminator_optimizer) + + self.assertEqual(mode, spec.mode) + if mode == model_fn_lib.ModeKeys.PREDICT: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + elif mode == model_fn_lib.ModeKeys.TRAIN: + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.train_op) + self.assertIsNotNone(spec.training_hooks) + elif mode == model_fn_lib.ModeKeys.EVAL: + self.assertEqual(self._gan_model.generated_data, spec.predictions) + self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar + self.assertIsNotNone(spec.eval_metric_ops) # TODO(joelshor): Add pandas test. @@ -195,12 +184,6 @@ class GANEstimatorIntegrationTest(test.TestCase): lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) return training.GradientDescentOptimizer(lr) - def get_metrics(gan_model): - return { - 'mse_custom_metric': metrics_lib.mean_squared_error( - gan_model.real_data, gan_model.generated_data) - } - gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) est = estimator.GANEstimator( diff --git a/tensorflow/contrib/gan/python/estimator/python/head.py b/tensorflow/contrib/gan/python/estimator/python/head.py deleted file mode 100644 index 3225d6f41a..0000000000 --- a/tensorflow/contrib/gan/python/estimator/python/head.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""`tf.Learn` components for `GANEstimator`'s loss.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python.estimator.python import head_impl -# pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.estimator.python.head_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented - -__all__ = head_impl.__all__ -remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py deleted file mode 100644 index ff903a78cc..0000000000 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A TFGAN-backed GAN Estimator.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python import train as tfgan_train -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator.canned import head -from tensorflow.python.framework import ops -from tensorflow.python.ops import metrics as metrics_lib - -__all__ = [ - 'GANHead', - 'gan_head', -] - -def _summary_key(head_name, val): - return '%s/%s' % (val, head_name) if head_name else val - - -def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, - discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), - get_eval_metric_ops_fn=None, name=None): - """Creates a `GANHead`. - - Args: - generator_loss_fn: A TFGAN loss function for the generator. Takes a - `GANModel` and returns a scalar. - discriminator_loss_fn: Same as `generator_loss_fn`, but for the - discriminator. - generator_optimizer: The optimizer for generator updates. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a - list of hooks. - get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a - dict of metric results keyed by name. The output of this function is - passed into `tf.estimator.EstimatorSpec` during evaluation. - name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. - - Returns: - An instance of `GANHead`. - """ - return GANHead(generator_loss_fn=generator_loss_fn, - discriminator_loss_fn=discriminator_loss_fn, - generator_optimizer=generator_optimizer, - discriminator_optimizer=discriminator_optimizer, - use_loss_summaries=use_loss_summaries, - get_hooks_fn=get_hooks_fn, - get_eval_metric_ops_fn=get_eval_metric_ops_fn, - name=name) - - -class GANHead(head._Head): # pylint: disable=protected-access - """`Head` for a GAN.""" - - def __init__(self, generator_loss_fn, discriminator_loss_fn, - generator_optimizer, discriminator_optimizer, - use_loss_summaries=True, - get_hooks_fn=None, - get_eval_metric_ops_fn=None, - name=None): - """`Head` for GAN training. - - Args: - generator_loss_fn: A TFGAN loss function for the generator. Takes a - `GANModel` and returns a scalar. - discriminator_loss_fn: Same as `generator_loss_fn`, but for the - discriminator. - generator_optimizer: The optimizer for generator updates. - discriminator_optimizer: Same as `generator_optimizer`, but for the - discriminator updates. - use_loss_summaries: If `True`, add loss summaries. If `False`, does not. - If `None`, uses defaults. - get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a - list of hooks. Defaults to `train.get_sequential_train_hooks()` - get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a - dict of metric results keyed by name. The output of this function is - passed into `tf.estimator.EstimatorSpec` during evaluation. - name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. - """ - if get_hooks_fn is None: - get_hooks_fn = tfgan_train.get_sequential_train_hooks() - # TODO(joelshor): Validate inputs. - - if use_loss_summaries in [True, False]: - generator_loss_fn = functools.partial( - generator_loss_fn, add_summaries=use_loss_summaries) - discriminator_loss_fn = functools.partial( - discriminator_loss_fn, add_summaries=use_loss_summaries) - self._generator_loss_fn = generator_loss_fn - self._discriminator_loss_fn = discriminator_loss_fn - self._generator_optimizer = generator_optimizer - self._discriminator_optimizer = discriminator_optimizer - self._get_hooks_fn = get_hooks_fn - self._get_eval_metric_ops_fn = get_eval_metric_ops_fn - self._name = name - - @property - def name(self): - return self._name - - @property - def logits_dimension(self): - return None - - def create_loss(self, features, mode, logits, labels): - """Returns a GANLoss tuple from the provided GANModel. - - See `Head` for more details. - - Args: - features: Input `dict` of `Tensor` objects. Unused. - mode: Estimator's `ModeKeys`. - logits: A GANModel tuple. - labels: Must be `None`. - - Returns: - A GANLoss tuple. - - """ - _validate_logits_and_labels(logits, labels) - del mode, labels, features # unused for this head. - gan_model = logits # rename variable for clarity - return tfgan_tuples.GANLoss( - generator_loss=self._generator_loss_fn(gan_model), - discriminator_loss=self._discriminator_loss_fn(gan_model)) - - def create_estimator_spec( - self, features, mode, logits, labels=None, - train_op_fn=tfgan_train.gan_train_ops): - """Returns `EstimatorSpec` that a model_fn can return. - - See `Head` for more details. - - Args: - features: Must be `None`. - mode: Estimator's `ModeKeys`. - logits: A GANModel tuple. - labels: Must be `None`. - train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, - and discriminator optimizer, and returns a `GANTrainOps` tuple. For - example, this function can come from TFGAN's `train.py` library, or can - be custom. - - Returns: - `EstimatorSpec`. - - Raises: - ValueError: If `features` isn't `None`. - ValueError: If `train_op_fn` isn't provided in train mode. - """ - _validate_logits_and_labels(logits, labels) - if features is not None: - raise ValueError('`features` should be `None`. Instead, found: %s' % - features) - gan_model = logits # rename variable for clarity - with ops.name_scope('GANHead'): - if mode == model_fn_lib.ModeKeys.PREDICT: - return model_fn_lib.EstimatorSpec( - mode=model_fn_lib.ModeKeys.PREDICT, - predictions=gan_model.generated_data) - elif mode == model_fn_lib.ModeKeys.EVAL: - gan_loss = self.create_loss( - features=None, mode=mode, logits=gan_model, labels=None) - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - with ops.name_scope(None, 'metrics', - [gan_loss.generator_loss, - gan_loss.discriminator_loss]): - eval_metric_ops = { - _summary_key(self._name, 'generator_loss'): - metrics_lib.mean(gan_loss.generator_loss), - _summary_key(self._name, 'discriminator_loss'): - metrics_lib.mean(gan_loss.discriminator_loss) - } - if self._get_eval_metric_ops_fn is not None: - custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model) - if not isinstance(custom_eval_metric_ops, dict): - raise TypeError('get_eval_metric_ops_fn must return a dict, ' - 'received: {}'.format(custom_eval_metric_ops)) - eval_metric_ops.update(custom_eval_metric_ops) - return model_fn_lib.EstimatorSpec( - mode=model_fn_lib.ModeKeys.EVAL, - predictions=gan_model.generated_data, - loss=scalar_loss, - eval_metric_ops=eval_metric_ops) - elif mode == model_fn_lib.ModeKeys.TRAIN: - if train_op_fn is None: - raise ValueError('train_op_fn can not be None.') - gan_loss = self.create_loss(None, mode, gan_model, None) - scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss - train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer, - self._discriminator_optimizer) - training_hooks = self._get_hooks_fn(train_ops) - return model_fn_lib.EstimatorSpec( - loss=scalar_loss, - mode=model_fn_lib.ModeKeys.TRAIN, - train_op=train_ops.global_step_inc_op, - training_hooks=training_hooks) - else: - raise ValueError('Mode not recognized: %s' % mode) - - -def _validate_logits_and_labels(logits, labels): - if labels is not None: - raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must ' - 'be `None`. Instead, found: %s' % labels) - - if not isinstance(logits, tfgan_tuples.GANModel): - raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must ' - 'be an instnace of a `GANModel`. Instead, found: %s' % - logits) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py deleted file mode 100644 index 6587f1fc60..0000000000 --- a/tensorflow/contrib/gan/python/estimator/python/head_test.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for TFGAN's head.py.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples -from tensorflow.contrib.gan.python.estimator.python import head - -from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import test -from tensorflow.python.training import training - - -def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument - return math_ops.reduce_sum(gan_model.discriminator_real_outputs - - gan_model.discriminator_gen_outputs) - - -def get_gan_model(): - # TODO(joelshor): Find a better way of creating a variable scope. - with variable_scope.variable_scope('generator') as gen_scope: - gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) - with variable_scope.variable_scope('discriminator') as dis_scope: - dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) - return tfgan_tuples.GANModel( - generator_inputs=None, - generated_data=array_ops.ones([3, 4]), - generator_variables=[gen_var], - generator_scope=gen_scope, - generator_fn=None, - real_data=None, - discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, - discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, - discriminator_variables=[dis_var], - discriminator_scope=dis_scope, - discriminator_fn=None) - - -class GANHeadTest(test.TestCase): - - def setUp(self): - super(GANHeadTest, self).setUp() - self.gan_head = head.gan_head( - generator_loss_fn=dummy_loss, - discriminator_loss_fn=dummy_loss, - generator_optimizer=training.GradientDescentOptimizer(1.0), - discriminator_optimizer=training.GradientDescentOptimizer(1.0), - get_eval_metric_ops_fn=self.get_metrics) - self.assertTrue(isinstance(self.gan_head, head.GANHead)) - - def get_metrics(self, gan_model): - self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel)) - return {} - - def _test_modes_helper(self, mode): - self.gan_head.create_estimator_spec( - features=None, - mode=mode, - logits=get_gan_model()) - - def test_modes_predict(self): - self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) - - def test_modes_eval(self): - self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) - - def test_modes_train(self): - self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN) - - -if __name__ == '__main__': - test.main() |