aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-09 12:42:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-09 12:46:08 -0800
commit55cd506ab8220c6a1075965eb7839cac4af1db3e (patch)
tree5805cb43e86854749a959cf08c8d922a1ef6557a
parentedb0bea1109b64fe1f1d45360c77bbab9e855c2e (diff)
Extend the type info analyzer to cover variables declared using with statements.
This allows constructs of the kind: with tfe.GradientTape() as tape: tape.gradients(...) PiperOrigin-RevId: 181358791
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py37
-rw-r--r--tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py21
2 files changed, 44 insertions, 14 deletions
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
index c1ad30815e..4a9730a1ec 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py
@@ -99,29 +99,38 @@ class TypeInfoResolver(gast.NodeTransformer):
self.scope.setval(node.id, type_holder)
return node
- def visit_Assign(self, node):
- self.generic_visit(node)
- if isinstance(node.value, gast.Call):
- target = node.value.func
- if anno.hasanno(target, 'live_val'):
- target_obj = anno.getanno(target, 'live_val')
- if tf_inspect.isclass(target_obj):
+ def _process_variable_assignment(self, source, targets):
+ if isinstance(source, gast.Call):
+ func = source.func
+ if anno.hasanno(func, 'live_val'):
+ func_obj = anno.getanno(func, 'live_val')
+ if tf_inspect.isclass(func_obj):
# This is then a constructor.
- anno.setanno(node.value, 'type', target_obj)
- anno.setanno(node.value, 'type_fqn', anno.getanno(target, 'fqn'))
+ anno.setanno(source, 'type', func_obj)
+ anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn'))
# TODO(mdan): Raise an error if constructor has side effects.
# We can have a whitelist of no-side-effects constructors.
# We can also step inside the constructor and further analyze.
- for n in node.targets:
- if isinstance(n, gast.Tuple):
- for i, e in enumerate(n.elts):
+ for t in targets:
+ if isinstance(t, gast.Tuple):
+ for i, e in enumerate(t.elts):
self.scope.setval(e.id,
gast.Subscript(
- node.value, gast.Index(i), ctx=gast.Store()))
+ source, gast.Index(i), ctx=gast.Store()))
else:
- self.scope.setval(n.id, node.value)
+ self.scope.setval(t.id, source)
+ def visit_With(self, node):
+ for wi in node.items:
+ if wi.optional_vars is not None:
+ self._process_variable_assignment(wi.context_expr, (wi.optional_vars,))
+ self.generic_visit(node)
+ return node
+
+ def visit_Assign(self, node):
+ self.generic_visit(node)
+ self._process_variable_assignment(node.value, node.targets)
return node
def visit_Call(self, node):
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
index 66abde71a8..0748be48f8 100644
--- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
+++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py
@@ -23,6 +23,7 @@ from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct.static_analysis import access
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.client import session
from tensorflow.python.platform import test
from tensorflow.python.training import training
@@ -85,6 +86,26 @@ class TypeInfoResolverTest(test.TestCase):
self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
anno.getanno(attr_call_node, 'type_fqn'))
+ def test_class_members_in_with_stmt(self):
+
+ def test_fn(x):
+ with session.Session() as sess:
+ sess.run(x)
+
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, {'session': session}, {})
+ node = type_info.resolve(node, None)
+
+ constructor_call = node.body[0].body[0].items[0].context_expr
+ self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
+ self.assertEquals((session.__name__, 'Session'),
+ anno.getanno(constructor_call, 'type_fqn'))
+
+ member_call = node.body[0].body[0].body[0].value.func
+ self.assertEquals((session.__name__, 'Session'),
+ anno.getanno(member_call, 'type_fqn'))
+
def test_parameter_class_members(self):
def test_fn(opt):