diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-09 12:42:10 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-09 12:46:08 -0800 |
commit | 55cd506ab8220c6a1075965eb7839cac4af1db3e (patch) | |
tree | 5805cb43e86854749a959cf08c8d922a1ef6557a | |
parent | edb0bea1109b64fe1f1d45360c77bbab9e855c2e (diff) |
Extend the type info analyzer to cover variables declared using with statements.
This allows constructs of the kind:
with tfe.GradientTape() as tape:
tape.gradients(...)
PiperOrigin-RevId: 181358791
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py | 37 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py | 21 |
2 files changed, 44 insertions, 14 deletions
diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py index c1ad30815e..4a9730a1ec 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info.py @@ -99,29 +99,38 @@ class TypeInfoResolver(gast.NodeTransformer): self.scope.setval(node.id, type_holder) return node - def visit_Assign(self, node): - self.generic_visit(node) - if isinstance(node.value, gast.Call): - target = node.value.func - if anno.hasanno(target, 'live_val'): - target_obj = anno.getanno(target, 'live_val') - if tf_inspect.isclass(target_obj): + def _process_variable_assignment(self, source, targets): + if isinstance(source, gast.Call): + func = source.func + if anno.hasanno(func, 'live_val'): + func_obj = anno.getanno(func, 'live_val') + if tf_inspect.isclass(func_obj): # This is then a constructor. - anno.setanno(node.value, 'type', target_obj) - anno.setanno(node.value, 'type_fqn', anno.getanno(target, 'fqn')) + anno.setanno(source, 'type', func_obj) + anno.setanno(source, 'type_fqn', anno.getanno(func, 'fqn')) # TODO(mdan): Raise an error if constructor has side effects. # We can have a whitelist of no-side-effects constructors. # We can also step inside the constructor and further analyze. - for n in node.targets: - if isinstance(n, gast.Tuple): - for i, e in enumerate(n.elts): + for t in targets: + if isinstance(t, gast.Tuple): + for i, e in enumerate(t.elts): self.scope.setval(e.id, gast.Subscript( - node.value, gast.Index(i), ctx=gast.Store())) + source, gast.Index(i), ctx=gast.Store())) else: - self.scope.setval(n.id, node.value) + self.scope.setval(t.id, source) + def visit_With(self, node): + for wi in node.items: + if wi.optional_vars is not None: + self._process_variable_assignment(wi.context_expr, (wi.optional_vars,)) + self.generic_visit(node) + return node + + def visit_Assign(self, node): + self.generic_visit(node) + self._process_variable_assignment(node.value, node.targets) return node def visit_Call(self, node): diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py index 66abde71a8..0748be48f8 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py @@ -23,6 +23,7 @@ from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct.static_analysis import access from tensorflow.contrib.py2tf.pyct.static_analysis import live_values from tensorflow.contrib.py2tf.pyct.static_analysis import type_info +from tensorflow.python.client import session from tensorflow.python.platform import test from tensorflow.python.training import training @@ -85,6 +86,26 @@ class TypeInfoResolverTest(test.TestCase): self.assertEquals((training.__name__, 'GradientDescentOptimizer'), anno.getanno(attr_call_node, 'type_fqn')) + def test_class_members_in_with_stmt(self): + + def test_fn(x): + with session.Session() as sess: + sess.run(x) + + node = parser.parse_object(test_fn) + node = access.resolve(node) + node = live_values.resolve(node, {'session': session}, {}) + node = type_info.resolve(node, None) + + constructor_call = node.body[0].body[0].items[0].context_expr + self.assertEquals(session.Session, anno.getanno(constructor_call, 'type')) + self.assertEquals((session.__name__, 'Session'), + anno.getanno(constructor_call, 'type_fqn')) + + member_call = node.body[0].body[0].body[0].value.func + self.assertEquals((session.__name__, 'Session'), + anno.getanno(member_call, 'type_fqn')) + def test_parameter_class_members(self): def test_fn(opt): |