aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/static_analysis/type_info.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/type_info.py48
1 files changed, 6 insertions, 42 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
index a229c288a8..835d5199fa 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/type_info.py
@@ -43,9 +43,8 @@ from __future__ import print_function
import gast
-from tensorflow.contrib.autograph import utils
from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import parser
+from tensorflow.contrib.autograph.pyct import ast_util
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -166,7 +165,6 @@ class TypeInfoResolver(transformer.Base):
definition = self.scope.getval(qn)
anno.copyanno(definition, node, 'type')
anno.copyanno(definition, node, 'type_fqn')
- anno.setanno(node, 'definition', definition)
# TODO(mdan): Remove this when the directives module is in.
anno.copyanno(definition, node, 'element_type')
@@ -198,52 +196,18 @@ class TypeInfoResolver(transformer.Base):
def visit_With(self, node):
for item in node.items:
if item.optional_vars is not None:
- self.apply_to_single_assignments((item.optional_vars,),
- item.context_expr,
- self._process_variable_assignment)
+ ast_util.apply_to_single_assignments((item.optional_vars,),
+ item.context_expr,
+ self._process_variable_assignment)
self.generic_visit(node)
return node
def visit_Assign(self, node):
self.generic_visit(node)
- self.apply_to_single_assignments(
- node.targets, node.value, self._process_variable_assignment)
+ ast_util.apply_to_single_assignments(node.targets, node.value,
+ self._process_variable_assignment)
return node
- # TODO(mdan): Remove as soon as the new directives module is ready.
- def visit_Call(self, node):
- if anno.hasanno(node.func, 'live_val'):
- # Symbols targeted by the "set_type" marker function are assigned the data
- # type that it specified.
- if anno.getanno(node.func, 'live_val') is utils.set_element_type:
-
- if len(node.args) < 2 or len(node.args) > 3:
- raise ValueError('"%s" must have either two or three parameters'
- % self.context.type_annotation_func)
- if len(node.args) == 2:
- target_arg, type_arg = node.args
- shape_arg = parser.parse_expression('None')
- else:
- target_arg, type_arg, shape_arg = node.args
- if not anno.hasanno(target_arg, anno.Basic.QN):
- raise ValueError('the first argument of "%s" must by a symbol' %
- utils.set_element_type)
- # TODO(mdan): This is vulnerable to symbol renaming.
- element_type = type_arg
- element_shape = shape_arg
-
- target_symbol = anno.getanno(target_arg, anno.Basic.QN)
- # Find the definition of this symbol and annotate it with the given
- # data type. That in turn will cause future uses of the symbol
- # to receive the same type annotation.
- definition = self.scope.getval(target_symbol)
- anno.setanno(node, 'element_type', element_type)
- anno.setanno(node, 'element_shape', element_shape)
- anno.setanno(definition, 'element_type', element_type)
- anno.setanno(definition, 'element_shape', element_shape)
- # TODO(mdan): Should we update references between definition and here?
- return self.generic_visit(node)
-
def resolve(node, context):
return TypeInfoResolver(context).visit(node)