aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-24 06:35:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-24 06:39:01 -0800
commit0220d128c78f4061595a13d40037aebc865239cb (patch)
treeb1912ebf7d693458009b2dc1a5dba157526ad014
parent917136b3bb7d83a1674bb24d3c0b0892ad77e056 (diff)
Use the new inspect_utils API to to get the function's namespace.
PiperOrigin-RevId: 186884307
-rw-r--r--tensorflow/contrib/py2tf/impl/conversion.py22
1 files changed, 7 insertions, 15 deletions
diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py
index 4bf698f207..044de33568 100644
--- a/tensorflow/contrib/py2tf/impl/conversion.py
+++ b/tensorflow/contrib/py2tf/impl/conversion.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import gast
-import six
from tensorflow.contrib.py2tf import utils
from tensorflow.contrib.py2tf.converters import asserts
@@ -36,6 +35,7 @@ from tensorflow.contrib.py2tf.converters import side_effect_guards
from tensorflow.contrib.py2tf.impl import config
from tensorflow.contrib.py2tf.impl import naming
from tensorflow.contrib.py2tf.pyct import context
+from tensorflow.contrib.py2tf.pyct import inspect_utils
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct.static_analysis import activity
@@ -155,7 +155,7 @@ def class_to_graph(c, conversion_map):
if not members:
raise ValueError('Cannot convert %s: it has no member methods.')
- class_globals = None
+ class_namespace = None
for _, m in members:
node, _ = function_to_graph(
m,
@@ -164,10 +164,10 @@ def class_to_graph(c, conversion_map):
arg_types={'self': (c.__name__, c)},
owner_type=c)
# TODO(mdan): Do not assume all members have the same view of globals.
- if class_globals is None:
- class_globals = six.get_function_globals(m)
+ if class_namespace is None:
+ class_namespace = inspect_utils.getnamespace(m)
converted_members[m] = node
- namer = conversion_map.new_namer(class_globals)
+ namer = conversion_map.new_namer(class_namespace)
class_name = namer.compiled_class_name(c.__name__, c)
node = gast.ClassDef(
class_name,
@@ -202,19 +202,11 @@ def function_to_graph(f, conversion_map, arg_values, arg_types,
"""Specialization of `entity_to_graph` for callable functions."""
node, source = parser.parse_entity(f)
node = node.body[0]
- namespace = six.get_function_globals(f)
-
- # This is needed for non-global functions.
- closure = six.get_function_closure(f)
- if closure:
- for e in closure:
- if callable(e.cell_contents):
- fn = e.cell_contents
- namespace[fn.__name__] = fn
+ namespace = inspect_utils.getnamespace(f)
_add_self_references(namespace, conversion_map.api_module)
-
namer = conversion_map.new_namer(namespace)
+
ctx = context.EntityContext(
namer=namer,
source_code=source,