diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-01 11:48:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-01 11:52:02 -0700 |
commit | 09c8964acdeeb11634c43bd5ac0c68d7588f2c01 (patch) | |
tree | f4235185ea58bd7d13e3f82c8a1702c055cd8ddb /tensorflow/contrib/autograph | |
parent | 0d4d93a47c998c5e6aeef2d0db1ffbc331679208 (diff) |
Fix for unspecified arguments in AutoGraph directives.
PiperOrigin-RevId: 206965028
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/converters/directives.py | 22 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/converters/directives_test.py | 19 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/core/converter.py | 2 |
3 files changed, 40 insertions, 3 deletions
diff --git a/tensorflow/contrib/autograph/converters/directives.py b/tensorflow/contrib/autograph/converters/directives.py index ccdf79d47b..77f625bac7 100644 --- a/tensorflow/contrib/autograph/converters/directives.py +++ b/tensorflow/contrib/autograph/converters/directives.py @@ -42,10 +42,30 @@ def _map_args(call_node, function): Returns: Dict[Text, ast.AST], mapping each of the function's argument names to the respective AST node. + Raises: + ValueError: if the default arguments are not correctly set """ args = call_node.args kwds = {kwd.arg: kwd.value for kwd in call_node.keywords} - return tf_inspect.getcallargs(function, *args, **kwds) + call_args = tf_inspect.getcallargs(function, *args, **kwds) + + # Keyword arguments not specified in kwds will be mapped to their defaults, + # which are Python values. Since we don't currently have a way to transform + # those into AST references, we simply remove them. By convention, directives + # use UNSPECIFIED as default value for for optional arguments. No other + # defaults should be present. + unexpected_defaults = [] + for k in call_args: + if (k not in kwds + and call_args[k] not in args + and call_args[k] is not directives.UNSPECIFIED): + unexpected_defaults.append(k) + if unexpected_defaults: + raise ValueError('Unexpected keyword argument values, %s, for function %s' + % (zip(unexpected_defaults, + [call_args[k] for k in unexpected_defaults]), + function)) + return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED} class DirectivesTransformer(converter.Base): diff --git a/tensorflow/contrib/autograph/converters/directives_test.py b/tensorflow/contrib/autograph/converters/directives_test.py index a573ba5850..a2d083b891 100644 --- a/tensorflow/contrib/autograph/converters/directives_test.py +++ b/tensorflow/contrib/autograph/converters/directives_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib.autograph.core import converter_testing from tensorflow.contrib.autograph.core.converter import AgAnno from tensorflow.contrib.autograph.lang import directives from tensorflow.contrib.autograph.pyct import anno +from tensorflow.contrib.autograph.pyct import parser from tensorflow.python.platform import test @@ -71,7 +72,23 @@ class DirectivesTest(converter_testing.TestCase): d = d[directives.set_loop_options] self.assertEqual(d['parallel_iterations'].n, 10) self.assertEqual(d['back_prop'].id, 'a') - self.assertEqual(d['swap_memory'], directives.UNSPECIFIED) + self.assertNotIn('swap_memory', d) + + def test_invalid_default(self): + + def invalid_directive(valid_arg, invalid_default=object()): + del valid_arg + del invalid_default + return + + def call_invalid_directive(): + invalid_directive(1) + + node, _ = parser.parse_entity(call_invalid_directive) + # Find the call to the invalid directive + node = node.body[0].body[0].value + with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'): + directives_converter._map_args(node, invalid_directive) if __name__ == '__main__': diff --git a/tensorflow/contrib/autograph/core/converter.py b/tensorflow/contrib/autograph/core/converter.py index a93e4a8064..83a80c1f52 100644 --- a/tensorflow/contrib/autograph/core/converter.py +++ b/tensorflow/contrib/autograph/core/converter.py @@ -233,7 +233,7 @@ class Base(transformer.Base): arg_values = [] for def_ in defs: if (directive not in def_.directives or - arg not in arg not in def_.directives[directive]): + arg not in def_.directives[directive]): continue arg_value = def_.directives[directive][arg] for prev_value in arg_values: |