aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-07-30 19:46:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 19:49:48 -0700
commita6572d3d003cf7ef5b0fffd5ad7c5fc86919465c (patch)
treeab2ab87733ff7d631688de828c336c546605fd61 /tensorflow/contrib/autograph
parente62ac84640c898d49107d172abb8f7b6f8cb8cd3 (diff)
Add a crude method to mark to mark converted entities so that dynamic conversion will not attempt to convert them again.
PiperOrigin-RevId: 206691438
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/impl/api.py12
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py14
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py18
-rw-r--r--tensorflow/contrib/autograph/impl/conversion_test.py12
4 files changed, 44 insertions, 12 deletions
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index ee71f4f9ac..0adff76a9f 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -258,25 +258,27 @@ def to_graph(e,
# Avoid overwriting entities that have been transformed.
if key not in compiled_module.__dict__:
compiled_module.__dict__[key] = val
- compiled_fn = getattr(compiled_module, name)
+ compiled = getattr(compiled_module, name)
# Need this so the source_mapping attribute is available for the context
# manager to access for runtime errors.
#
# Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
# symbol to the compiled module.
+ # TODO(mdan): Record this statically in the generated code.
+ # TODO(mdan): Rename this attribute to 'autograph_info__'
source_map_attribute_name = 'ag_source_map'
- if getattr(compiled_fn, source_map_attribute_name, None) is not None:
+ if getattr(compiled, source_map_attribute_name, None) is not None:
raise ValueError('cannot convert %s because is has an attribute '
'"%s", which is reserved for AutoGraph.' %
- (compiled_fn, source_map_attribute_name))
- setattr(compiled_fn, source_map_attribute_name,
+ (compiled, source_map_attribute_name))
+ setattr(compiled, source_map_attribute_name,
compiled_module.__dict__['ag_source_map__'])
if verbose:
logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
- return compiled_fn
+ return compiled
def to_code(e,
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 4de7df6572..754baa87b0 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -280,6 +280,20 @@ class ApiTest(test.TestCase):
x = tc.test_method()
self.assertEqual(1, sess.run(x))
+ def test_converted_call_already_converted(self):
+
+ def f(x):
+ return x == 0
+
+ with self.test_session() as sess:
+ x = api.converted_call(f, False, False, {}, constant_op.constant(0))
+ self.assertTrue(sess.run(x))
+
+ converted_f = api.to_graph(f)
+ x = api.converted_call(converted_f, False, False, {},
+ constant_op.constant(0))
+ self.assertTrue(sess.run(x))
+
def test_to_graph_basic(self):
def test_fn(x, s):
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index 57ec739a80..afb10d4d8b 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -48,6 +48,7 @@ from tensorflow.contrib.autograph.pyct import inspect_utils
from tensorflow.contrib.autograph.pyct import origin_info
from tensorflow.contrib.autograph.pyct import parser
from tensorflow.contrib.autograph.pyct import qual_names
+from tensorflow.contrib.autograph.pyct import templates
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.util import tf_inspect
@@ -70,6 +71,8 @@ def is_whitelisted_for_graph(o):
for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
if m.__name__.startswith(prefix):
return True
+ if hasattr(o, 'autograph_info__'):
+ return True
return False
@@ -120,7 +123,16 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types):
'Entity "%s" has unsupported type "%s". Only functions and classes are '
'supported for now.' % (o, type(o)))
+ # TODO(mdan): This is temporary. it should be created using a converter.
+ # TODO(mdan): The attribute should be added with a helper, not directly.
+ # The helper can ensure there are no collisions.
+ template = '''
+ entity.autograph_info__ = {}
+ '''
+ node.extend(templates.replace(template, entity=name))
+
program_ctx.add_to_cache(o, node)
+
if program_ctx.recursive:
while True:
candidate = None
@@ -268,18 +280,18 @@ def function_to_graph(f,
context = converter.EntityContext(namer, entity_info, program_ctx)
node = node_to_graph(node, context, rewrite_errors=rewrite_errors)
- # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py
+ # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
if not did_rename:
new_name = f.__name__
if node.name != f.__name__:
raise NotImplementedError('Strange corner case. Send us offending code!')
-
node.name = new_name
+
program_ctx.update_name_map(namer)
# TODO(mdan): Use this at compilation.
- return (node,), new_name, namespace
+ return [node], new_name, namespace
def node_to_graph(node, context, rewrite_errors=True):
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index bfc51365a3..1c5d4d09c4 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -61,7 +61,7 @@ class ConversionTest(test.TestCase):
program_ctx = self._simple_program_ctx()
nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
- fn_node, = nodes
+ fn_node, _ = nodes
self.assertIsInstance(fn_node, gast.FunctionDef)
self.assertEqual('tf__f', name)
self.assertIs(ns['b'], b)
@@ -115,10 +115,12 @@ class ConversionTest(test.TestCase):
self.assertTrue(TestBase in program_ctx.dependency_cache)
self.assertTrue(TestSubclass in program_ctx.dependency_cache)
+ # The returned nodes will include:
+ # <import nodes>, <class node>, <assignment node>
self.assertEqual('TfTestBase',
- program_ctx.dependency_cache[TestBase][-1].name)
+ program_ctx.dependency_cache[TestBase][-2].name)
self.assertEqual('TfTestSubclass',
- program_ctx.dependency_cache[TestSubclass][-1].name)
+ program_ctx.dependency_cache[TestSubclass][-2].name)
def test_entity_to_graph_class_hierarchy_whitelisted(self):
@@ -138,8 +140,10 @@ class ConversionTest(test.TestCase):
self.assertFalse(training.Model in program_ctx.dependency_cache)
self.assertEqual(
'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
+ # The returned nodes will include:
+ # <import nodes>, <class node>, <assignment node>
self.assertEqual('TfTestSubclass',
- program_ctx.dependency_cache[TestSubclass][-1].name)
+ program_ctx.dependency_cache[TestSubclass][-2].name)
def test_entity_to_graph_lambda(self):
f = lambda a: a