aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-12 09:28:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-12 09:36:30 -0700
commitbc6b60f1bc79c2753cea087cf0eba1d76c5702df (patch)
treef4c67cef7b4fc1351c0a78144bdf49c73f76a0e5
parent7a8c63da365106048dc96affddb39e2fdc33da89 (diff)
Fix tuple_losses bug caused by Python bug.
PiperOrigin-RevId: 168386341
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py16
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py15
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
index 8805633dee..fca8063891 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
@@ -73,7 +73,21 @@ def _args_to_gan_model(loss_fn):
default_args_dict = dict(zip(args_with_defaults, defaults))
def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring
- gan_model_dict = gan_model._asdict()
+ def _asdict(namedtuple):
+ """Returns a namedtuple as a dictionary.
+
+ This is required because `_asdict()` in Python 3.x.x is broken in classes
+ that inherit from `collections.namedtuple`. See
+ https://bugs.python.org/issue24931 for more details.
+
+ Args:
+ namedtuple: An object that inherits from `collections.namedtuple`.
+
+ Returns:
+ A dictionary version of the tuple.
+ """
+ return {k: getattr(namedtuple, k) for k in namedtuple._fields}
+ gan_model_dict = _asdict(gan_model)
# Make sure non-tuple required args are supplied.
args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index f65b20d0b5..215b15ef69 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -79,6 +79,21 @@ class ArgsToGanModelTest(test.TestCase):
# If `arg3` were not set properly, this value would be different.
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
+ def test_works_with_child_classes(self):
+ """`args_to_gan_model` should work with classes derived from namedtuple."""
+ tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])
+
+ class InheritedType(tuple_type):
+ pass
+ def args_loss(arg1, arg2, arg3=3):
+ return arg1 + 2 * arg2 + 3 * arg3
+
+ loss_fn = tfgan_losses._args_to_gan_model(args_loss)
+ loss = loss_fn(InheritedType(arg1=-1, arg2=2), arg3=4)
+
+ # If `arg3` were not set properly, this value would be different.
+ self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
+
class ConsistentLossesTest(test.TestCase):