diff options
author | Dan Moldovan <mdan@google.com> | 2018-10-02 12:14:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 12:19:36 -0700 |
commit | 16b44d48d485dbb62b9922e172df4cc460174046 (patch) | |
tree | 6001c3c185e51957e68f66d5c3666950f4b44ae7 /tensorflow/python/autograph | |
parent | 7a0ce3c3a24a91c0bd17a681fa0833c9044e9256 (diff) |
Fix the case when an object may have multiple directives with the same annotation.
PiperOrigin-RevId: 215435613
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/core/BUILD | 47 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/converter.py | 53 | ||||
-rw-r--r-- | tensorflow/python/autograph/core/converter_test.py | 124 |
3 files changed, 184 insertions, 40 deletions
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD index 843e381f31..3ab2e7b1bc 100644 --- a/tensorflow/python/autograph/core/BUILD +++ b/tensorflow/python/autograph/core/BUILD @@ -33,6 +33,35 @@ py_library( ], ) +py_library( + name = "test_lib", + srcs = [ + "converter_testing.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":core", + "//tensorflow/python/autograph/operators", + "//tensorflow/python/autograph/pyct", + "//tensorflow/python/autograph/pyct/static_analysis", + "//tensorflow/python/autograph/utils", + "@gast_archive//:gast", + "@six_archive//:six", + ], +) + +py_test( + name = "converter_test", + srcs = ["converter_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":core", + ":test_lib", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "errors_test", srcs = ["errors_test.py"], @@ -67,21 +96,3 @@ py_test( "//tensorflow/python:client_testlib", ], ) - -py_library( - name = "test_lib", - srcs = [ - "converter_testing.py", - ], - srcs_version = "PY2AND3", - visibility = ["//tensorflow:__subpackages__"], - deps = [ - ":core", - "//tensorflow/python/autograph/operators", - "//tensorflow/python/autograph/pyct", - "//tensorflow/python/autograph/pyct/static_analysis", - "//tensorflow/python/autograph/utils", - "@gast_archive//:gast", - "@six_archive//:six", - ], -) diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py index 80928ae7f4..408a573ad0 100644 --- a/tensorflow/python/autograph/core/converter.py +++ b/tensorflow/python/autograph/core/converter.py @@ -210,14 +210,22 @@ class Base(transformer.Base): self._ast_depth = 0 def get_definition_directive(self, node, directive, arg, default): - """Returns the unique directive for a symbol, or a default if none exist. + """Returns the unique directive argument for a symbol. See lang/directives.py for details on directives. + Example: + # Given a directive in the code: + ag.foo_directive(bar, baz=1) + + # One can write for an AST node Name(id='bar'): + get_definition_directive(node, ag.foo_directive, 'baz') + Args: - node: ast.AST - directive: Callable[..., Any] - arg: str + node: ast.AST, the node representing the symbol for which the directive + argument is needed. + directive: Callable[..., Any], the directive to search. + arg: str, the directive argument to return. default: Any Raises: @@ -227,27 +235,28 @@ class Base(transformer.Base): if not defs: return default - # TODO(mdan): Simplify this. - arg_values = [] + arg_values_found = [] for def_ in defs: - if (directive not in def_.directives or - arg not in def_.directives[directive]): - continue - arg_value = def_.directives[directive][arg] - for prev_value in arg_values: - if not ast_util.matches(arg_value, prev_value): - qn = anno.getanno(node, anno.Basic.QN) - raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % - (qn, directive.__name__, arg, - compiler.ast_to_source(arg_value).strip(), - compiler.ast_to_source(prev_value).strip())) - arg_values.append(arg_value) - - if not arg_values: + if (directive in def_.directives and arg in def_.directives[directive]): + arg_values_found.append(def_.directives[directive][arg]) + + if not arg_values_found: return default - arg_value, = arg_values - return arg_value + if len(arg_values_found) == 1: + return arg_values_found[0] + + # If multiple annotations reach the symbol, they must all match. If they do, + # return any of them. + first_value = arg_values_found[0] + for other_value in arg_values_found[1:]: + if not ast_util.matches(first_value, other_value): + qn = anno.getanno(node, anno.Basic.QN) + raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % + (qn, directive.__name__, arg, + compiler.ast_to_source(other_value).strip(), + compiler.ast_to_source(first_value).strip())) + return first_value def visit(self, node): if not self._ast_depth: diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py new file mode 100644 index 0000000000..b73c67e337 --- /dev/null +++ b/tensorflow/python/autograph/core/converter_test.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for lists module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.core import converter +from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.autograph.pyct import anno +from tensorflow.python.autograph.pyct import parser +from tensorflow.python.platform import test + + +class TestConverter(converter.Base): + pass + + +class ConverterBaseTest(converter_testing.TestCase): + + def test_get_definition_directive_basic(self): + + directive_key = object + + def test_fn(): + a = 1 + return a + + ns = {} + node, ctx = self.prepare(test_fn, ns) + symbol_a = node.body[1].value + defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) + defs.directives[directive_key] = { + 'test_arg': parser.parse_expression('foo'), + 'other_arg': parser.parse_expression('bar'), + } + c = TestConverter(ctx) + value = c.get_definition_directive(symbol_a, directive_key, 'test_arg', + None) + self.assertEqual(value.id, 'foo') + + def test_get_definition_directive_default(self): + + directive_key = object + + def test_fn(): + a = 1 + return a + + ns = {} + node, ctx = self.prepare(test_fn, ns) + symbol_a = node.body[1].value + c = TestConverter(ctx) + value = c.get_definition_directive(symbol_a, directive_key, 'test_arg', + parser.parse_expression('default')) + self.assertEqual(value.id, 'default') + + def test_get_definition_directive_multiple_consistent(self): + + directive_key = object + + def test_fn(): + a = 1 + if a: + a = 2 + return a + + ns = {} + node, ctx = self.prepare(test_fn, ns) + symbol_a = node.body[2].value + defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) + defs[0].directives[directive_key] = { + 'test_arg': parser.parse_expression('foo'), + 'other_arg': parser.parse_expression('bar'), + } + defs[1].directives[directive_key] = { + 'test_arg': parser.parse_expression('foo'), + 'other_arg': parser.parse_expression('baz'), + } + c = TestConverter(ctx) + value = c.get_definition_directive(symbol_a, directive_key, 'test_arg', + None) + self.assertEqual(value.id, 'foo') + + def test_get_definition_directive_multiple_inconsistent(self): + + directive_key = object + + def test_fn(): + a = 1 + if a: + a = 2 + return a + + ns = {} + node, ctx = self.prepare(test_fn, ns) + symbol_a = node.body[2].value + defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS) + defs[0].directives[directive_key] = { + 'test_arg': parser.parse_expression('foo'), + } + defs[1].directives[directive_key] = { + 'test_arg': parser.parse_expression('bar'), + } + c = TestConverter(ctx) + with self.assertRaises(ValueError): + c.get_definition_directive(symbol_a, directive_key, 'test_arg', None) + + +if __name__ == '__main__': + test.main() |