aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-01 11:48:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 11:52:02 -0700
commit09c8964acdeeb11634c43bd5ac0c68d7588f2c01 (patch)
treef4235185ea58bd7d13e3f82c8a1702c055cd8ddb /tensorflow/contrib/autograph
parent0d4d93a47c998c5e6aeef2d0db1ffbc331679208 (diff)
Fix for unspecified arguments in AutoGraph directives.
PiperOrigin-RevId: 206965028
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/converters/directives.py22
-rw-r--r--tensorflow/contrib/autograph/converters/directives_test.py19
-rw-r--r--tensorflow/contrib/autograph/core/converter.py2
3 files changed, 40 insertions, 3 deletions
diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py
index ccdf79d47b..77f625bac7 100644
--- a/tensorflow/contrib/autograph/converters/directives.py
+++ b/tensorflow/contrib/autograph/converters/directives.py
@@ -42,10 +42,30 @@ def _map_args(call_node, function):
Returns:
Dict[Text, ast.AST], mapping each of the function's argument names to
the respective AST node.
+ Raises:
+ ValueError: if the default arguments are not correctly set
"""
args = call_node.args
kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
- return tf_inspect.getcallargs(function, *args, **kwds)
+ call_args = tf_inspect.getcallargs(function, *args, **kwds)
+
+ # Keyword arguments not specified in kwds will be mapped to their defaults,
+ # which are Python values. Since we don't currently have a way to transform
+ # those into AST references, we simply remove them. By convention, directives
+ # use UNSPECIFIED as default value for for optional arguments. No other
+ # defaults should be present.
+ unexpected_defaults = []
+ for k in call_args:
+ if (k not in kwds
+ and call_args[k] not in args
+ and call_args[k] is not directives.UNSPECIFIED):
+ unexpected_defaults.append(k)
+ if unexpected_defaults:
+ raise ValueError('Unexpected keyword argument values, %s, for function %s'
+ % (zip(unexpected_defaults,
+ [call_args[k] for k in unexpected_defaults]),
+ function))
+ return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
class DirectivesTransformer(converter.Base):
diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py
index a573ba5850..a2d083b891 100644
--- a/tensorflow/contrib/autograph/converters/directives_test.py
+++ b/tensorflow/contrib/autograph/converters/directives_test.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.contrib.autograph.core.converter import AgAnno
from tensorflow.contrib.autograph.lang import directives
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
from tensorflow.python.platform import test
@@ -71,7 +72,23 @@ class DirectivesTest(converter_testing.TestCase):
d = d[directives.set_loop_options]
self.assertEqual(d['parallel_iterations'].n, 10)
self.assertEqual(d['back_prop'].id, 'a')
- self.assertEqual(d['swap_memory'], directives.UNSPECIFIED)
+ self.assertNotIn('swap_memory', d)
+
+ def test_invalid_default(self):
+
+ def invalid_directive(valid_arg, invalid_default=object()):
+ del valid_arg
+ del invalid_default
+ return
+
+ def call_invalid_directive():
+ invalid_directive(1)
+
+ node, _ = parser.parse_entity(call_invalid_directive)
+ # Find the call to the invalid directive
+ node = node.body[0].body[0].value
+ with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
+ directives_converter._map_args(node, invalid_directive)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py
index a93e4a8064..83a80c1f52 100644
--- a/tensorflow/contrib/autograph/core/converter.py
+++ b/tensorflow/contrib/autograph/core/converter.py
@@ -233,7 +233,7 @@ class Base(transformer.Base):
arg_values = []
for def_ in defs:
if (directive not in def_.directives or
- arg not in arg not in def_.directives[directive]):
+ arg not in def_.directives[directive]):
continue
arg_value = def_.directives[directive][arg]
for prev_value in arg_values: