aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-13 03:02:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-13 03:05:03 -0700
commite6d00acfd8e4539291a087a6c3e0799253ba9d6f (patch)
treef8e78063d153a3a310e9e14f350d1d501acbe163
parent97d5bfed6c8a42ea6d8779309e9eb64a1e488d07 (diff)
Remove GANHead from GANEstimator.
PiperOrigin-RevId: 200362771
-rw-r--r--tensorflow/contrib/gan/BUILD50
-rw-r--r--tensorflow/contrib/gan/python/estimator/__init__.py5
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py186
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py227
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head.py28
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py235
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py90
7 files changed, 218 insertions, 603 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index b305f37791..d38d770bc5 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -45,6 +45,7 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/losses",
@@ -59,6 +60,7 @@ py_test(
deps = [
":features",
":namedtuples",
+ ":random_tensor_pool",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/slim:learning",
@@ -70,6 +72,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
+ "//tensorflow/python:training_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/ops/distributions",
@@ -96,7 +99,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":gan_estimator",
- ":head",
"//tensorflow/python:util",
],
)
@@ -188,6 +190,7 @@ py_test(
srcs = ["python/losses/python/tuple_losses_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":namedtuples",
":tuple_losses",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -344,9 +347,11 @@ py_library(
"//tensorflow/python:image_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "@six_archive//:six",
],
)
@@ -429,40 +434,6 @@ py_test(
)
py_library(
- name = "head",
- srcs = [
- "python/estimator/python/head.py",
- "python/estimator/python/head_impl.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":namedtuples",
- ":train",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/estimator:head",
- "//tensorflow/python/estimator:model_fn",
- ],
-)
-
-py_test(
- name = "head_test",
- srcs = ["python/estimator/python/head_test.py"],
- shard_count = 1,
- srcs_version = "PY2AND3",
- deps = [
- ":head",
- ":namedtuples",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/estimator:model_fn",
- ],
-)
-
-py_library(
name = "gan_estimator",
srcs = [
"python/estimator/python/gan_estimator.py",
@@ -470,12 +441,12 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":head",
":namedtuples",
":summaries",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator",
@@ -498,16 +469,19 @@ py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
- "//tensorflow/python/estimator:head",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py
index c9f7bc61b2..04dddb4b55 100644
--- a/tensorflow/contrib/gan/python/estimator/__init__.py
+++ b/tensorflow/contrib/gan/python/estimator/__init__.py
@@ -25,16 +25,13 @@ from __future__ import print_function
# Collapse `estimator` into a single namespace.
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
-from tensorflow.contrib.gan.python.estimator.python import head
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
-from tensorflow.contrib.gan.python.estimator.python.head import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'gan_estimator',
- 'head',
-] + gan_estimator.__all__ + head.__all__
+] + gan_estimator.__all__
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index 4092b32004..7104c8aa61 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -24,11 +24,11 @@ import enum
from tensorflow.contrib.framework.python.ops import variables as variable_lib
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
-from tensorflow.contrib.gan.python.estimator.python import head as head_lib
from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import tf_inspect as inspect
@@ -158,90 +158,77 @@ class GANEstimator(estimator.Estimator):
# TODO(joelshor): Explicitly validate inputs.
def _model_fn(features, labels, mode):
- gopt = (generator_optimizer() if callable(generator_optimizer) else
- generator_optimizer)
- dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
- else discriminator_optimizer)
- gan_head = head_lib.gan_head(
- generator_loss_fn, discriminator_loss_fn, gopt, dopt,
- use_loss_summaries, get_hooks_fn=get_hooks_fn,
- get_eval_metric_ops_fn=get_eval_metric_ops_fn)
- return _gan_model_fn(
- features, labels, mode, generator_fn, discriminator_fn, gan_head,
+ """GANEstimator model function."""
+ if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT]:
+ raise ValueError('Mode not recognized: %s' % mode)
+ real_data = labels # rename inputs for clarity
+ generator_inputs = features # rename inputs for clarity
+
+ # Make GANModel, which encapsulates the GAN model architectures.
+ gan_model = _get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
add_summaries)
+ # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
+ # metrics, and optimizers (if required).
+ return _get_estimator_spec(
+ mode, gan_model, generator_loss_fn, discriminator_loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn)
+
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
-def _gan_model_fn(
- features,
- labels,
- mode,
- generator_fn,
- discriminator_fn,
- head,
- add_summaries=None,
- generator_scope_name='Generator'):
- """The `model_fn` for the GAN estimator.
-
- We make the following convention:
- features -> TFGAN's `generator_inputs`
- labels -> TFGAN's `real_data`
-
- Args:
- features: A dictionary to feed to generator. In the unconditional case,
- this might be just `noise`. In the conditional GAN case, this
- might be the generator's conditioning. The `generator_fn` determines
- what the required keys are.
- labels: Real data. Can be any structure, as long as `discriminator_fn`
- can accept it for the first argument.
- mode: Defines whether this is training, evaluation or prediction.
- See `ModeKeys`.
- generator_fn: A python lambda that takes `generator_inputs` as inputs and
- returns the outputs of the GAN generator.
- discriminator_fn: A python lambda that takes `real_data`/`generated data`
- and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
- head: A `Head` instance suitable for GANs.
- add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
- generator_scope_name: The name of the generator scope. We need this to be
- the same for GANModels produced by TFGAN's `train.gan_model` and the
- manually constructed ones for predictions.
-
- Returns:
- `ModelFnOps`
-
- Raises:
- ValueError: If `labels` isn't `None` during prediction.
- """
- real_data = labels
- generator_inputs = features
-
- if mode == model_fn_lib.ModeKeys.TRAIN:
- gan_model = _make_train_gan_model(
- generator_fn, discriminator_fn, real_data, generator_inputs,
- generator_scope_name, add_summaries)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- gan_model = _make_eval_gan_model(
- generator_fn, discriminator_fn, real_data, generator_inputs,
- generator_scope_name, add_summaries)
- else:
+def _get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
+ add_summaries, generator_scope='Generator'):
+ """Makes the GANModel tuple, which encapsulates the GAN model architecture."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
if real_data is not None:
raise ValueError('`labels` must be `None` when mode is `predict`. '
'Instead, found %s' % real_data)
gan_model = _make_prediction_gan_model(
- generator_inputs, generator_fn, generator_scope_name)
+ generator_inputs, generator_fn, generator_scope)
+ else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL
+ gan_model = _make_gan_model(
+ generator_fn, discriminator_fn, real_data, generator_inputs,
+ generator_scope, add_summaries, mode)
+
+ return gan_model
- return head.create_estimator_spec(
- features=None,
- mode=mode,
- logits=gan_model,
- labels=None)
+
+def _get_estimator_spec(
+ mode, gan_model, generator_loss_fn, discriminator_loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn=None):
+ """Get the EstimatorSpec for the current mode."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=mode, predictions=gan_model.generated_data)
+ else:
+ gan_loss = tfgan_tuples.GANLoss(
+ generator_loss=generator_loss_fn(gan_model),
+ discriminator_loss=discriminator_loss_fn(gan_model))
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ estimator_spec = _get_eval_estimator_spec(
+ gan_model, gan_loss, get_eval_metric_ops_fn)
+ else: # model_fn_lib.ModeKeys.TRAIN:
+ gopt = (generator_optimizer() if callable(generator_optimizer) else
+ generator_optimizer)
+ dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
+ else discriminator_optimizer)
+ get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
+ estimator_spec = _get_train_estimator_spec(
+ gan_model, gan_loss, gopt, dopt, get_hooks_fn)
+
+ return estimator_spec
def _make_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries, mode):
- """Make a `GANModel`, and optionally pass in `mode`."""
+ """Construct a `GANModel`, and optionally pass in `mode`."""
# If network functions have an argument `mode`, pass mode to it.
if 'mode' in inspect.getargspec(generator_fn).args:
generator_fn = functools.partial(generator_fn, mode=mode)
@@ -264,22 +251,6 @@ def _make_gan_model(generator_fn, discriminator_fn, real_data,
return gan_model
-def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for training."""
- return _make_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries,
- model_fn_lib.ModeKeys.TRAIN)
-
-
-def _make_eval_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries):
- """Make a `GANModel` for evaluation."""
- return _make_gan_model(generator_fn, discriminator_fn, real_data,
- generator_inputs, generator_scope, add_summaries,
- model_fn_lib.ModeKeys.EVAL)
-
-
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
"""Make a `GANModel` from just the generator."""
# If `generator_fn` has an argument `mode`, pass mode to it.
@@ -303,3 +274,46 @@ def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
discriminator_variables=None,
discriminator_scope=None,
discriminator_fn=None)
+
+
+def _get_eval_estimator_spec(gan_model, gan_loss, get_eval_metric_ops_fn=None,
+ name=None):
+ """Return an EstimatorSpec for the eval case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss,
+ gan_loss.discriminator_loss]):
+ def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+ eval_metric_ops = {
+ _summary_key(name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL,
+ predictions=gan_model.generated_data,
+ loss=scalar_loss,
+ eval_metric_ops=eval_metric_ops)
+
+
+def _get_train_estimator_spec(
+ gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
+ get_hooks_fn, train_op_fn=tfgan_train.gan_train_ops):
+ """Return an EstimatorSpec for the train case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
+ discriminator_optimizer)
+ training_hooks = get_hooks_fn(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ loss=scalar_loss,
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ train_op=train_ops.global_step_inc_op,
+ training_hooks=training_hooks)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 955482599b..9ac9c6ca9c 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -21,30 +21,30 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
from tensorflow.contrib import layers
-from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator
from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import learning_rate_decay
-from tensorflow.python.training import monitored_session
from tensorflow.python.training import training
from tensorflow.python.training import training_util
@@ -60,120 +60,109 @@ def discriminator_fn(data, unused_conditioning, mode):
return layers.fully_connected(data, 1)
-def mock_head(testcase, expected_generator_inputs, expected_real_data,
- generator_scope_name):
- """Returns a mock head that validates logits values and variable names."""
- discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults
- generator_var_names = set([
- '%s/fully_connected/weights:0' % generator_scope_name,
- '%s/fully_connected/biases:0' % generator_scope_name])
- discriminator_var_names = set([
- '%s/fully_connected/weights:0' % discriminator_scope_name,
- '%s/fully_connected/biases:0' % discriminator_scope_name])
-
- def _create_estimator_spec(features, mode, logits, labels):
- gan_model = logits # renaming for clarity
- is_predict = mode == model_fn_lib.ModeKeys.PREDICT
- testcase.assertIsNone(features)
- testcase.assertIsNone(labels)
- testcase.assertIsInstance(gan_model, namedtuples.GANModel)
-
- trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- expected_var_names = (generator_var_names if is_predict else
- generator_var_names | discriminator_var_names)
- testcase.assertItemsEqual(expected_var_names,
- [var.name for var in trainable_vars])
-
- assertions = []
- def _or_none(x):
- return None if is_predict else x
- testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs)
- # TODO(joelshor): Add check on `generated_data`.
- testcase.assertItemsEqual(
- generator_var_names,
- set([x.name for x in gan_model.generator_variables]))
- testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name)
- testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data)
- # TODO(joelshor): Add check on `discriminator_real_outputs`.
- # TODO(joelshor): Add check on `discriminator_gen_outputs`.
- if is_predict:
- testcase.assertIsNone(gan_model.discriminator_scope)
- else:
- testcase.assertEqual(discriminator_scope_name,
- gan_model.discriminator_scope.name)
-
- with ops.control_dependencies(assertions):
- if mode == model_fn_lib.ModeKeys.TRAIN:
- return model_fn_lib.EstimatorSpec(
- mode=mode, loss=array_ops.zeros([]),
- train_op=control_flow_ops.no_op(), training_hooks=[])
- elif mode == model_fn_lib.ModeKeys.EVAL:
- return model_fn_lib.EstimatorSpec(
- mode=mode, predictions=gan_model.generated_data,
- loss=array_ops.zeros([]))
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- return model_fn_lib.EstimatorSpec(
- mode=mode, predictions=gan_model.generated_data)
- else:
- testcase.fail('Invalid mode: {}'.format(mode))
-
- head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
- head.create_estimator_spec = test.mock.MagicMock(
- wraps=_create_estimator_spec)
-
- return head
-
-
-class GANModelFnTest(test.TestCase):
- """Tests that _gan_model_fn passes expected logits to mock head."""
-
- def setUp(self):
- self._model_dir = tempfile.mkdtemp()
-
- def tearDown(self):
- if self._model_dir:
- writer_cache.FileWriterCache.clear()
- shutil.rmtree(self._model_dir)
+class GetGANModelTest(test.TestCase, parameterized.TestCase):
+ """Tests that `GetGANModel` produces the correct model."""
- def _test_logits_helper(self, mode):
- """Tests that the expected logits are passed to mock head."""
+ @parameterized.named_parameters(
+ ('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_gan_model(self, mode):
with ops.Graph().as_default():
- training_util.get_or_create_global_step()
- generator_inputs = {'x': array_ops.zeros([5, 4])}
- real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
- array_ops.zeros([5, 4]))
- generator_scope_name = 'generator'
- head = mock_head(self,
- expected_generator_inputs=generator_inputs,
- expected_real_data=real_data,
- generator_scope_name=generator_scope_name)
- estimator_spec = estimator._gan_model_fn(
- features=generator_inputs,
- labels=real_data,
- mode=mode,
- generator_fn=generator_fn,
- discriminator_fn=discriminator_fn,
- generator_scope_name=generator_scope_name,
- head=head)
- with monitored_session.MonitoredTrainingSession(
- checkpoint_dir=self._model_dir) as sess:
- if mode == model_fn_lib.ModeKeys.TRAIN:
- sess.run(estimator_spec.train_op)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- sess.run(estimator_spec.loss)
- elif mode == model_fn_lib.ModeKeys.PREDICT:
- sess.run(estimator_spec.predictions)
- else:
- self.fail('Invalid mode: {}'.format(mode))
-
- def test_logits_predict(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT)
-
- def test_logits_eval(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.EVAL)
-
- def test_logits_train(self):
- self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN)
+ generator_inputs = {'x': array_ops.ones([3, 4])}
+ real_data = (array_ops.zeros([3, 4]) if
+ mode != model_fn_lib.ModeKeys.PREDICT else None)
+ gan_model = estimator._get_gan_model(
+ mode, generator_fn, discriminator_fn, real_data, generator_inputs,
+ add_summaries=False)
+
+ self.assertEqual(generator_inputs, gan_model.generator_inputs)
+ self.assertIsNotNone(gan_model.generated_data)
+ self.assertEqual(2, len(gan_model.generator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.generator_fn)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertIsNone(gan_model.real_data)
+ self.assertIsNone(gan_model.discriminator_real_outputs)
+ self.assertIsNone(gan_model.discriminator_gen_outputs)
+ self.assertIsNone(gan_model.discriminator_variables)
+ self.assertIsNone(gan_model.discriminator_scope)
+ self.assertIsNone(gan_model.discriminator_fn)
+ else:
+ self.assertIsNotNone(gan_model.real_data)
+ self.assertIsNotNone(gan_model.discriminator_real_outputs)
+ self.assertIsNotNone(gan_model.discriminator_gen_outputs)
+ self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.discriminator_scope)
+ self.assertIsNotNone(gan_model.discriminator_fn)
+
+
+def get_dummy_gan_model():
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ return tfgan_tuples.GANModel(
+ generator_inputs=None,
+ generated_data=array_ops.ones([3, 4]),
+ generator_variables=[gen_var],
+ generator_scope=gen_scope,
+ generator_fn=None,
+ real_data=array_ops.zeros([3, 4]),
+ discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var,
+ discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var,
+ discriminator_variables=[dis_var],
+ discriminator_scope=dis_scope,
+ discriminator_fn=None)
+
+
+def dummy_loss_fn(gan_model):
+ return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
+ gan_model.discriminator_gen_outputs)
+
+
+def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric': metrics_lib.mean_squared_error(
+ gan_model.real_data, gan_model.generated_data)
+ }
+
+
+class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
+ """Tests that the EstimatorSpec is constructed appropriately."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
+ cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
+
+ @parameterized.named_parameters(
+ ('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_estimator_spec(self, mode):
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ spec = estimator._get_estimator_spec(
+ mode,
+ self._gan_model,
+ generator_loss_fn=dummy_loss_fn,
+ discriminator_loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=self._generator_optimizer,
+ discriminator_optimizer=self._discriminator_optimizer)
+
+ self.assertEqual(mode, spec.mode)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNotNone(spec.training_hooks)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.eval_metric_ops)
# TODO(joelshor): Add pandas test.
@@ -195,12 +184,6 @@ class GANEstimatorIntegrationTest(test.TestCase):
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
- def get_metrics(gan_model):
- return {
- 'mse_custom_metric': metrics_lib.mean_squared_error(
- gan_model.real_data, gan_model.generated_data)
- }
-
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
diff --git a/tensorflow/contrib/gan/python/estimator/python/head.py b/tensorflow/contrib/gan/python/estimator/python/head.py
deleted file mode 100644
index 3225d6f41a..0000000000
--- a/tensorflow/contrib/gan/python/estimator/python/head.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""`tf.Learn` components for `GANEstimator`'s loss."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.gan.python.estimator.python import head_impl
-# pylint: disable=wildcard-import
-from tensorflow.contrib.gan.python.estimator.python.head_impl import *
-# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-__all__ = head_impl.__all__
-remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
deleted file mode 100644
index ff903a78cc..0000000000
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ /dev/null
@@ -1,235 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""A TFGAN-backed GAN Estimator."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import functools
-
-from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
-from tensorflow.contrib.gan.python import train as tfgan_train
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.canned import head
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import metrics as metrics_lib
-
-__all__ = [
- 'GANHead',
- 'gan_head',
-]
-
-def _summary_key(head_name, val):
- return '%s/%s' % (val, head_name) if head_name else val
-
-
-def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
- discriminator_optimizer, use_loss_summaries=True,
- get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
- get_eval_metric_ops_fn=None, name=None):
- """Creates a `GANHead`.
-
- Args:
- generator_loss_fn: A TFGAN loss function for the generator. Takes a
- `GANModel` and returns a scalar.
- discriminator_loss_fn: Same as `generator_loss_fn`, but for the
- discriminator.
- generator_optimizer: The optimizer for generator updates.
- discriminator_optimizer: Same as `generator_optimizer`, but for the
- discriminator updates.
- use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
- If `None`, uses defaults.
- get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
- list of hooks.
- get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
- dict of metric results keyed by name. The output of this function is
- passed into `tf.estimator.EstimatorSpec` during evaluation.
- name: name of the head. If provided, summary and metrics keys will be
- suffixed by `"/" + name`.
-
- Returns:
- An instance of `GANHead`.
- """
- return GANHead(generator_loss_fn=generator_loss_fn,
- discriminator_loss_fn=discriminator_loss_fn,
- generator_optimizer=generator_optimizer,
- discriminator_optimizer=discriminator_optimizer,
- use_loss_summaries=use_loss_summaries,
- get_hooks_fn=get_hooks_fn,
- get_eval_metric_ops_fn=get_eval_metric_ops_fn,
- name=name)
-
-
-class GANHead(head._Head): # pylint: disable=protected-access
- """`Head` for a GAN."""
-
- def __init__(self, generator_loss_fn, discriminator_loss_fn,
- generator_optimizer, discriminator_optimizer,
- use_loss_summaries=True,
- get_hooks_fn=None,
- get_eval_metric_ops_fn=None,
- name=None):
- """`Head` for GAN training.
-
- Args:
- generator_loss_fn: A TFGAN loss function for the generator. Takes a
- `GANModel` and returns a scalar.
- discriminator_loss_fn: Same as `generator_loss_fn`, but for the
- discriminator.
- generator_optimizer: The optimizer for generator updates.
- discriminator_optimizer: Same as `generator_optimizer`, but for the
- discriminator updates.
- use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
- If `None`, uses defaults.
- get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
- list of hooks. Defaults to `train.get_sequential_train_hooks()`
- get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
- dict of metric results keyed by name. The output of this function is
- passed into `tf.estimator.EstimatorSpec` during evaluation.
- name: name of the head. If provided, summary and metrics keys will be
- suffixed by `"/" + name`.
- """
- if get_hooks_fn is None:
- get_hooks_fn = tfgan_train.get_sequential_train_hooks()
- # TODO(joelshor): Validate inputs.
-
- if use_loss_summaries in [True, False]:
- generator_loss_fn = functools.partial(
- generator_loss_fn, add_summaries=use_loss_summaries)
- discriminator_loss_fn = functools.partial(
- discriminator_loss_fn, add_summaries=use_loss_summaries)
- self._generator_loss_fn = generator_loss_fn
- self._discriminator_loss_fn = discriminator_loss_fn
- self._generator_optimizer = generator_optimizer
- self._discriminator_optimizer = discriminator_optimizer
- self._get_hooks_fn = get_hooks_fn
- self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
- self._name = name
-
- @property
- def name(self):
- return self._name
-
- @property
- def logits_dimension(self):
- return None
-
- def create_loss(self, features, mode, logits, labels):
- """Returns a GANLoss tuple from the provided GANModel.
-
- See `Head` for more details.
-
- Args:
- features: Input `dict` of `Tensor` objects. Unused.
- mode: Estimator's `ModeKeys`.
- logits: A GANModel tuple.
- labels: Must be `None`.
-
- Returns:
- A GANLoss tuple.
-
- """
- _validate_logits_and_labels(logits, labels)
- del mode, labels, features # unused for this head.
- gan_model = logits # rename variable for clarity
- return tfgan_tuples.GANLoss(
- generator_loss=self._generator_loss_fn(gan_model),
- discriminator_loss=self._discriminator_loss_fn(gan_model))
-
- def create_estimator_spec(
- self, features, mode, logits, labels=None,
- train_op_fn=tfgan_train.gan_train_ops):
- """Returns `EstimatorSpec` that a model_fn can return.
-
- See `Head` for more details.
-
- Args:
- features: Must be `None`.
- mode: Estimator's `ModeKeys`.
- logits: A GANModel tuple.
- labels: Must be `None`.
- train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer,
- and discriminator optimizer, and returns a `GANTrainOps` tuple. For
- example, this function can come from TFGAN's `train.py` library, or can
- be custom.
-
- Returns:
- `EstimatorSpec`.
-
- Raises:
- ValueError: If `features` isn't `None`.
- ValueError: If `train_op_fn` isn't provided in train mode.
- """
- _validate_logits_and_labels(logits, labels)
- if features is not None:
- raise ValueError('`features` should be `None`. Instead, found: %s' %
- features)
- gan_model = logits # rename variable for clarity
- with ops.name_scope('GANHead'):
- if mode == model_fn_lib.ModeKeys.PREDICT:
- return model_fn_lib.EstimatorSpec(
- mode=model_fn_lib.ModeKeys.PREDICT,
- predictions=gan_model.generated_data)
- elif mode == model_fn_lib.ModeKeys.EVAL:
- gan_loss = self.create_loss(
- features=None, mode=mode, logits=gan_model, labels=None)
- scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
- with ops.name_scope(None, 'metrics',
- [gan_loss.generator_loss,
- gan_loss.discriminator_loss]):
- eval_metric_ops = {
- _summary_key(self._name, 'generator_loss'):
- metrics_lib.mean(gan_loss.generator_loss),
- _summary_key(self._name, 'discriminator_loss'):
- metrics_lib.mean(gan_loss.discriminator_loss)
- }
- if self._get_eval_metric_ops_fn is not None:
- custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model)
- if not isinstance(custom_eval_metric_ops, dict):
- raise TypeError('get_eval_metric_ops_fn must return a dict, '
- 'received: {}'.format(custom_eval_metric_ops))
- eval_metric_ops.update(custom_eval_metric_ops)
- return model_fn_lib.EstimatorSpec(
- mode=model_fn_lib.ModeKeys.EVAL,
- predictions=gan_model.generated_data,
- loss=scalar_loss,
- eval_metric_ops=eval_metric_ops)
- elif mode == model_fn_lib.ModeKeys.TRAIN:
- if train_op_fn is None:
- raise ValueError('train_op_fn can not be None.')
- gan_loss = self.create_loss(None, mode, gan_model, None)
- scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
- train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer,
- self._discriminator_optimizer)
- training_hooks = self._get_hooks_fn(train_ops)
- return model_fn_lib.EstimatorSpec(
- loss=scalar_loss,
- mode=model_fn_lib.ModeKeys.TRAIN,
- train_op=train_ops.global_step_inc_op,
- training_hooks=training_hooks)
- else:
- raise ValueError('Mode not recognized: %s' % mode)
-
-
-def _validate_logits_and_labels(logits, labels):
- if labels is not None:
- raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must '
- 'be `None`. Instead, found: %s' % labels)
-
- if not isinstance(logits, tfgan_tuples.GANModel):
- raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must '
- 'be an instnace of a `GANModel`. Instead, found: %s' %
- logits)
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
deleted file mode 100644
index 6587f1fc60..0000000000
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for TFGAN's head.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
-from tensorflow.contrib.gan.python.estimator.python import head
-
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-from tensorflow.python.training import training
-
-
-def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
- return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
- gan_model.discriminator_gen_outputs)
-
-
-def get_gan_model():
- # TODO(joelshor): Find a better way of creating a variable scope.
- with variable_scope.variable_scope('generator') as gen_scope:
- gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
- with variable_scope.variable_scope('discriminator') as dis_scope:
- dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
- return tfgan_tuples.GANModel(
- generator_inputs=None,
- generated_data=array_ops.ones([3, 4]),
- generator_variables=[gen_var],
- generator_scope=gen_scope,
- generator_fn=None,
- real_data=None,
- discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var,
- discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var,
- discriminator_variables=[dis_var],
- discriminator_scope=dis_scope,
- discriminator_fn=None)
-
-
-class GANHeadTest(test.TestCase):
-
- def setUp(self):
- super(GANHeadTest, self).setUp()
- self.gan_head = head.gan_head(
- generator_loss_fn=dummy_loss,
- discriminator_loss_fn=dummy_loss,
- generator_optimizer=training.GradientDescentOptimizer(1.0),
- discriminator_optimizer=training.GradientDescentOptimizer(1.0),
- get_eval_metric_ops_fn=self.get_metrics)
- self.assertTrue(isinstance(self.gan_head, head.GANHead))
-
- def get_metrics(self, gan_model):
- self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
- return {}
-
- def _test_modes_helper(self, mode):
- self.gan_head.create_estimator_spec(
- features=None,
- mode=mode,
- logits=get_gan_model())
-
- def test_modes_predict(self):
- self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
-
- def test_modes_eval(self):
- self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)
-
- def test_modes_train(self):
- self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN)
-
-
-if __name__ == '__main__':
- test.main()