diff options
author | 2017-09-12 09:28:51 -0700 | |
---|---|---|
committer | 2017-09-12 09:36:30 -0700 | |
commit | bc6b60f1bc79c2753cea087cf0eba1d76c5702df (patch) | |
tree | f4c67cef7b4fc1351c0a78144bdf49c73f76a0e5 | |
parent | 7a8c63da365106048dc96affddb39e2fdc33da89 (diff) |
Fix tuple_losses bug caused by Python bug.
PiperOrigin-RevId: 168386341
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py | 16 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py | 15 |
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 8805633dee..fca8063891 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -73,7 +73,21 @@ def _args_to_gan_model(loss_fn): default_args_dict = dict(zip(args_with_defaults, defaults)) def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring - gan_model_dict = gan_model._asdict() + def _asdict(namedtuple): + """Returns a namedtuple as a dictionary. + + This is required because `_asdict()` in Python 3.x.x is broken in classes + that inherit from `collections.namedtuple`. See + https://bugs.python.org/issue24931 for more details. + + Args: + namedtuple: An object that inherits from `collections.namedtuple`. + + Returns: + A dictionary version of the tuple. + """ + return {k: getattr(namedtuple, k) for k in namedtuple._fields} + gan_model_dict = _asdict(gan_model) # Make sure non-tuple required args are supplied. args_from_tuple = set(argspec.args).intersection(set(gan_model._fields)) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index f65b20d0b5..215b15ef69 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -79,6 +79,21 @@ class ArgsToGanModelTest(test.TestCase): # If `arg3` were not set properly, this value would be different. self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) + def test_works_with_child_classes(self): + """`args_to_gan_model` should work with classes derived from namedtuple.""" + tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2']) + + class InheritedType(tuple_type): + pass + def args_loss(arg1, arg2, arg3=3): + return arg1 + 2 * arg2 + 3 * arg3 + + loss_fn = tfgan_losses._args_to_gan_model(args_loss) + loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4) + + # If `arg3` were not set properly, this value would be different. + self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) + class ConsistentLossesTest(test.TestCase): |