diff options
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/static_analysis/type_info.py')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/static_analysis/type_info.py | 48 |
1 files changed, 6 insertions, 42 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py index a229c288a8..835d5199fa 100644 --- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py @@ -43,9 +43,8 @@ from __future__ import print_function import gast -from tensorflow.contrib.autograph import utils from tensorflow.contrib.autograph.pyct import anno -from tensorflow.contrib.autograph.pyct import parser +from tensorflow.contrib.autograph.pyct import ast_util from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -166,7 +165,6 @@ class TypeInfoResolver(transformer.Base): definition = self.scope.getval(qn) anno.copyanno(definition, node, 'type') anno.copyanno(definition, node, 'type_fqn') - anno.setanno(node, 'definition', definition) # TODO(mdan): Remove this when the directives module is in. anno.copyanno(definition, node, 'element_type') @@ -198,52 +196,18 @@ class TypeInfoResolver(transformer.Base): def visit_With(self, node): for item in node.items: if item.optional_vars is not None: - self.apply_to_single_assignments((item.optional_vars,), - item.context_expr, - self._process_variable_assignment) + ast_util.apply_to_single_assignments((item.optional_vars,), + item.context_expr, + self._process_variable_assignment) self.generic_visit(node) return node def visit_Assign(self, node): self.generic_visit(node) - self.apply_to_single_assignments( - node.targets, node.value, self._process_variable_assignment) + ast_util.apply_to_single_assignments(node.targets, node.value, + self._process_variable_assignment) return node - # TODO(mdan): Remove as soon as the new directives module is ready. - def visit_Call(self, node): - if anno.hasanno(node.func, 'live_val'): - # Symbols targeted by the "set_type" marker function are assigned the data - # type that it specified. - if anno.getanno(node.func, 'live_val') is utils.set_element_type: - - if len(node.args) < 2 or len(node.args) > 3: - raise ValueError('"%s" must have either two or three parameters' - % self.context.type_annotation_func) - if len(node.args) == 2: - target_arg, type_arg = node.args - shape_arg = parser.parse_expression('None') - else: - target_arg, type_arg, shape_arg = node.args - if not anno.hasanno(target_arg, anno.Basic.QN): - raise ValueError('the first argument of "%s" must by a symbol' % - utils.set_element_type) - # TODO(mdan): This is vulnerable to symbol renaming. - element_type = type_arg - element_shape = shape_arg - - target_symbol = anno.getanno(target_arg, anno.Basic.QN) - # Find the definition of this symbol and annotate it with the given - # data type. That in turn will cause future uses of the symbol - # to receive the same type annotation. - definition = self.scope.getval(target_symbol) - anno.setanno(node, 'element_type', element_type) - anno.setanno(node, 'element_shape', element_shape) - anno.setanno(definition, 'element_type', element_type) - anno.setanno(definition, 'element_shape', element_shape) - # TODO(mdan): Should we update references between definition and here? - return self.generic_visit(node) - def resolve(node, context): return TypeInfoResolver(context).visit(node) |