diff options
author | Martin Wicke <577277+martinwicke@users.noreply.github.com> | 2018-06-27 11:58:44 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-27 11:58:44 -0700 |
commit | 85cb8d48b7ad8a7587bc9a52d2198b027677bd13 (patch) | |
tree | 870aea7ad35100231349a83359582a946f37bd3f /tensorflow/contrib/gan | |
parent | 4803f7b673d4ad20e7b3caa3d893e9f5ed6978f3 (diff) | |
parent | ef98fc4fb98f7df05b636d022297e2a708a7986b (diff) |
Merge pull request #18565 from alexpantyukhin/ganhead_constructor_validate
add checking for input values in GANHead constructor
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/estimator/python/head_impl.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 5b5557bd8f..d1441e1eb2 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -103,9 +103,20 @@ class GANHead(head._Head): # pylint: disable=protected-access name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + + if not callable(generator_loss_fn): + raise TypeError('generator_loss_fn must be callable.') + if not callable(discriminator_loss_fn): + raise TypeError('discriminator_loss_fn must be callable.') + if not use_loss_summaries in [True, False, None]: + raise ValueError('use_loss_summaries must be True, False or None.') + if get_hooks_fn is not None and not callable(get_hooks_fn): + raise TypeError('get_hooks_fn must be callable.') + if name is not None and not isinstance(name, str): + raise TypeError('name must be string.') + if get_hooks_fn is None: get_hooks_fn = tfgan_train.get_sequential_train_hooks() - # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: generator_loss_fn = functools.partial( |