diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-24 06:35:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-24 06:39:01 -0800 |
commit | 0220d128c78f4061595a13d40037aebc865239cb (patch) | |
tree | b1912ebf7d693458009b2dc1a5dba157526ad014 | |
parent | 917136b3bb7d83a1674bb24d3c0b0892ad77e056 (diff) |
Use the new inspect_utils API to to get the function's namespace.
PiperOrigin-RevId: 186884307
-rw-r--r-- | tensorflow/contrib/py2tf/impl/conversion.py | 22 |
1 files changed, 7 insertions, 15 deletions
diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py index 4bf698f207..044de33568 100644 --- a/tensorflow/contrib/py2tf/impl/conversion.py +++ b/tensorflow/contrib/py2tf/impl/conversion.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import gast -import six from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.converters import asserts @@ -36,6 +35,7 @@ from tensorflow.contrib.py2tf.converters import side_effect_guards from tensorflow.contrib.py2tf.impl import config from tensorflow.contrib.py2tf.impl import naming from tensorflow.contrib.py2tf.pyct import context +from tensorflow.contrib.py2tf.pyct import inspect_utils from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import qual_names from tensorflow.contrib.py2tf.pyct.static_analysis import activity @@ -155,7 +155,7 @@ def class_to_graph(c, conversion_map): if not members: raise ValueError('Cannot convert %s: it has no member methods.') - class_globals = None + class_namespace = None for _, m in members: node, _ = function_to_graph( m, @@ -164,10 +164,10 @@ def class_to_graph(c, conversion_map): arg_types={'self': (c.__name__, c)}, owner_type=c) # TODO(mdan): Do not assume all members have the same view of globals. - if class_globals is None: - class_globals = six.get_function_globals(m) + if class_namespace is None: + class_namespace = inspect_utils.getnamespace(m) converted_members[m] = node - namer = conversion_map.new_namer(class_globals) + namer = conversion_map.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) node = gast.ClassDef( class_name, @@ -202,19 +202,11 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] - namespace = six.get_function_globals(f) - - # This is needed for non-global functions. - closure = six.get_function_closure(f) - if closure: - for e in closure: - if callable(e.cell_contents): - fn = e.cell_contents - namespace[fn.__name__] = fn + namespace = inspect_utils.getnamespace(f) _add_self_references(namespace, conversion_map.api_module) - namer = conversion_map.new_namer(namespace) + ctx = context.EntityContext( namer=namer, source_code=source, |