aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/estimator/python/head_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/estimator/python/head_impl.py')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index 5b5557bd8f..d1441e1eb2 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -103,9 +103,20 @@ class GANHead(head._Head): # pylint: disable=protected-access
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
+
+ if not callable(generator_loss_fn):
+ raise TypeError('generator_loss_fn must be callable.')
+ if not callable(discriminator_loss_fn):
+ raise TypeError('discriminator_loss_fn must be callable.')
+ if not use_loss_summaries in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
+ if name is not None and not isinstance(name, str):
+ raise TypeError('name must be string.')
+
if get_hooks_fn is None:
get_hooks_fn = tfgan_train.get_sequential_train_hooks()
- # TODO(joelshor): Validate inputs.
if use_loss_summaries in [True, False]:
generator_loss_fn = functools.partial(