aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-02 12:14:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 12:19:36 -0700
commit16b44d48d485dbb62b9922e172df4cc460174046 (patch)
tree6001c3c185e51957e68f66d5c3666950f4b44ae7 /tensorflow/python/autograph
parent7a0ce3c3a24a91c0bd17a681fa0833c9044e9256 (diff)
Fix the case when an object may have multiple directives with the same annotation.
PiperOrigin-RevId: 215435613
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/core/BUILD47
-rw-r--r--tensorflow/python/autograph/core/converter.py53
-rw-r--r--tensorflow/python/autograph/core/converter_test.py124
3 files changed, 184 insertions, 40 deletions
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 843e381f31..3ab2e7b1bc 100644
--- a/tensorflow/python/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -33,6 +33,35 @@ py_library(
],
)
+py_library(
+ name = "test_lib",
+ srcs = [
+ "converter_testing.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "converter_test",
+ srcs = ["converter_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ":test_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "errors_test",
srcs = ["errors_test.py"],
@@ -67,21 +96,3 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
-
-py_library(
- name = "test_lib",
- srcs = [
- "converter_testing.py",
- ],
- srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
- deps = [
- ":core",
- "//tensorflow/python/autograph/operators",
- "//tensorflow/python/autograph/pyct",
- "//tensorflow/python/autograph/pyct/static_analysis",
- "//tensorflow/python/autograph/utils",
- "@gast_archive//:gast",
- "@six_archive//:six",
- ],
-)
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 80928ae7f4..408a573ad0 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -210,14 +210,22 @@ class Base(transformer.Base):
self._ast_depth = 0
def get_definition_directive(self, node, directive, arg, default):
- """Returns the unique directive for a symbol, or a default if none exist.
+ """Returns the unique directive argument for a symbol.
See lang/directives.py for details on directives.
+ Example:
+ # Given a directive in the code:
+ ag.foo_directive(bar, baz=1)
+
+ # One can write for an AST node Name(id='bar'):
+ get_definition_directive(node, ag.foo_directive, 'baz')
+
Args:
- node: ast.AST
- directive: Callable[..., Any]
- arg: str
+ node: ast.AST, the node representing the symbol for which the directive
+ argument is needed.
+ directive: Callable[..., Any], the directive to search.
+ arg: str, the directive argument to return.
default: Any
Raises:
@@ -227,27 +235,28 @@ class Base(transformer.Base):
if not defs:
return default
- # TODO(mdan): Simplify this.
- arg_values = []
+ arg_values_found = []
for def_ in defs:
- if (directive not in def_.directives or
- arg not in def_.directives[directive]):
- continue
- arg_value = def_.directives[directive][arg]
- for prev_value in arg_values:
- if not ast_util.matches(arg_value, prev_value):
- qn = anno.getanno(node, anno.Basic.QN)
- raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
- (qn, directive.__name__, arg,
- compiler.ast_to_source(arg_value).strip(),
- compiler.ast_to_source(prev_value).strip()))
- arg_values.append(arg_value)
-
- if not arg_values:
+ if (directive in def_.directives and arg in def_.directives[directive]):
+ arg_values_found.append(def_.directives[directive][arg])
+
+ if not arg_values_found:
return default
- arg_value, = arg_values
- return arg_value
+ if len(arg_values_found) == 1:
+ return arg_values_found[0]
+
+ # If multiple annotations reach the symbol, they must all match. If they do,
+ # return any of them.
+ first_value = arg_values_found[0]
+ for other_value in arg_values_found[1:]:
+ if not ast_util.matches(first_value, other_value):
+ qn = anno.getanno(node, anno.Basic.QN)
+ raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
+ (qn, directive.__name__, arg,
+ compiler.ast_to_source(other_value).strip(),
+ compiler.ast_to_source(first_value).strip()))
+ return first_value
def visit(self, node):
if not self._ast_depth:
diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py
new file mode 100644
index 0000000000..b73c67e337
--- /dev/null
+++ b/tensorflow/python/autograph/core/converter_test.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lists module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class TestConverter(converter.Base):
+ pass
+
+
+class ConverterBaseTest(converter_testing.TestCase):
+
+ def test_get_definition_directive_basic(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[1].value
+ defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs.directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('bar'),
+ }
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ None)
+ self.assertEqual(value.id, 'foo')
+
+ def test_get_definition_directive_default(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[1].value
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ parser.parse_expression('default'))
+ self.assertEqual(value.id, 'default')
+
+ def test_get_definition_directive_multiple_consistent(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ if a:
+ a = 2
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[2].value
+ defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs[0].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('bar'),
+ }
+ defs[1].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('baz'),
+ }
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ None)
+ self.assertEqual(value.id, 'foo')
+
+ def test_get_definition_directive_multiple_inconsistent(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ if a:
+ a = 2
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[2].value
+ defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs[0].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ }
+ defs[1].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('bar'),
+ }
+ c = TestConverter(ctx)
+ with self.assertRaises(ValueError):
+ c.get_definition_directive(symbol_a, directive_key, 'test_arg', None)
+
+
+if __name__ == '__main__':
+ test.main()