diff options
author | Dan Moldovan <mdan@google.com> | 2018-07-30 19:46:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-30 19:49:48 -0700 |
commit | a6572d3d003cf7ef5b0fffd5ad7c5fc86919465c (patch) | |
tree | ab2ab87733ff7d631688de828c336c546605fd61 /tensorflow/contrib/autograph | |
parent | e62ac84640c898d49107d172abb8f7b6f8cb8cd3 (diff) |
Add a crude method to mark to mark converted entities so that dynamic conversion will not attempt to convert them again.
PiperOrigin-RevId: 206691438
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/impl/api.py | 12 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/impl/api_test.py | 14 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/impl/conversion.py | 18 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/impl/conversion_test.py | 12 |
4 files changed, 44 insertions, 12 deletions
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index ee71f4f9ac..0adff76a9f 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -258,25 +258,27 @@ def to_graph(e, # Avoid overwriting entities that have been transformed. if key not in compiled_module.__dict__: compiled_module.__dict__[key] = val - compiled_fn = getattr(compiled_module, name) + compiled = getattr(compiled_module, name) # Need this so the source_mapping attribute is available for the context # manager to access for runtime errors. # # Note that compiler.ast_to_object attaches the source map 'ag_source_map__' # symbol to the compiled module. + # TODO(mdan): Record this statically in the generated code. + # TODO(mdan): Rename this attribute to 'autograph_info__' source_map_attribute_name = 'ag_source_map' - if getattr(compiled_fn, source_map_attribute_name, None) is not None: + if getattr(compiled, source_map_attribute_name, None) is not None: raise ValueError('cannot convert %s because is has an attribute ' '"%s", which is reserved for AutoGraph.' % - (compiled_fn, source_map_attribute_name)) - setattr(compiled_fn, source_map_attribute_name, + (compiled, source_map_attribute_name)) + setattr(compiled, source_map_attribute_name, compiled_module.__dict__['ag_source_map__']) if verbose: logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src) - return compiled_fn + return compiled def to_code(e, diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 4de7df6572..754baa87b0 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -280,6 +280,20 @@ class ApiTest(test.TestCase): x = tc.test_method() self.assertEqual(1, sess.run(x)) + def test_converted_call_already_converted(self): + + def f(x): + return x == 0 + + with self.test_session() as sess: + x = api.converted_call(f, False, False, {}, constant_op.constant(0)) + self.assertTrue(sess.run(x)) + + converted_f = api.to_graph(f) + x = api.converted_call(converted_f, False, False, {}, + constant_op.constant(0)) + self.assertTrue(sess.run(x)) + def test_to_graph_basic(self): def test_fn(x, s): diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index 57ec739a80..afb10d4d8b 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -48,6 +48,7 @@ from tensorflow.contrib.autograph.pyct import inspect_utils from tensorflow.contrib.autograph.pyct import origin_info from tensorflow.contrib.autograph.pyct import parser from tensorflow.contrib.autograph.pyct import qual_names +from tensorflow.contrib.autograph.pyct import templates from tensorflow.contrib.autograph.pyct import transformer from tensorflow.python.util import tf_inspect @@ -70,6 +71,8 @@ def is_whitelisted_for_graph(o): for prefix, in config.DEFAULT_UNCOMPILED_MODULES: if m.__name__.startswith(prefix): return True + if hasattr(o, 'autograph_info__'): + return True return False @@ -120,7 +123,16 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 'supported for now.' % (o, type(o))) + # TODO(mdan): This is temporary. it should be created using a converter. + # TODO(mdan): The attribute should be added with a helper, not directly. + # The helper can ensure there are no collisions. + template = ''' + entity.autograph_info__ = {} + ''' + node.extend(templates.replace(template, entity=name)) + program_ctx.add_to_cache(o, node) + if program_ctx.recursive: while True: candidate = None @@ -268,18 +280,18 @@ def function_to_graph(f, context = converter.EntityContext(namer, entity_info, program_ctx) node = node_to_graph(node, context, rewrite_errors=rewrite_errors) - # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py + # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) if not did_rename: new_name = f.__name__ if node.name != f.__name__: raise NotImplementedError('Strange corner case. Send us offending code!') - node.name = new_name + program_ctx.update_name_map(namer) # TODO(mdan): Use this at compilation. - return (node,), new_name, namespace + return [node], new_name, namespace def node_to_graph(node, context, rewrite_errors=True): diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index bfc51365a3..1c5d4d09c4 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -61,7 +61,7 @@ class ConversionTest(test.TestCase): program_ctx = self._simple_program_ctx() nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None) - fn_node, = nodes + fn_node, _ = nodes self.assertIsInstance(fn_node, gast.FunctionDef) self.assertEqual('tf__f', name) self.assertIs(ns['b'], b) @@ -115,10 +115,12 @@ class ConversionTest(test.TestCase): self.assertTrue(TestBase in program_ctx.dependency_cache) self.assertTrue(TestSubclass in program_ctx.dependency_cache) + # The returned nodes will include: + # <import nodes>, <class node>, <assignment node> self.assertEqual('TfTestBase', - program_ctx.dependency_cache[TestBase][-1].name) + program_ctx.dependency_cache[TestBase][-2].name) self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass][-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_class_hierarchy_whitelisted(self): @@ -138,8 +140,10 @@ class ConversionTest(test.TestCase): self.assertFalse(training.Model in program_ctx.dependency_cache) self.assertEqual( 'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name) + # The returned nodes will include: + # <import nodes>, <class node>, <assignment node> self.assertEqual('TfTestSubclass', - program_ctx.dependency_cache[TestSubclass][-1].name) + program_ctx.dependency_cache[TestSubclass][-2].name) def test_entity_to_graph_lambda(self): f = lambda a: a |