diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-05 12:58:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-05 13:01:32 -0800 |
commit | 3a3feb207d8e138b7a468ae5d6e0d2daf4c8a49c (patch) | |
tree | 4fae87c7a7e1b6f4d9fa4430f9055c42f161aefe | |
parent | 620c8383123519fcf4d987efb9776d861901ccfa (diff) |
Basic templating code.
PiperOrigin-RevId: 180964100
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/templates.py | 112 | ||||
-rw-r--r-- | tensorflow/contrib/py2tf/pyct/templates_test.py | 77 |
3 files changed, 203 insertions, 0 deletions
diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD index 0601417335..dca380ceb1 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/py2tf/pyct/BUILD @@ -22,6 +22,7 @@ py_library( "compiler.py", "parser.py", "pretty_printer.py", + "templates.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], @@ -78,3 +79,16 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "templates_test", + srcs = ["templates_test.py"], + tags = [ + "manual", + "notap", + ], + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py new file mode 100644 index 0000000000..6acc03bfce --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/templates.py @@ -0,0 +1,112 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AST conversion templates. + +Adapted from Tangent. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast + +import gast + +from tensorflow.contrib.py2tf.pyct import parser + + +class ReplaceTransformer(gast.NodeTransformer): + """Replace AST nodes.""" + + def __init__(self, replacements): + """Create a new ReplaceTransformer. + + Args: + replacements: A mapping from placeholder names to (lists of) AST nodes + that these placeholders will be replaced by. + """ + self.replacements = replacements + + # TODO(mdan): Make a more detailed pass and clean up if needed. + + def visit_Expr(self, node): + if (isinstance(node.value, gast.Name) and + node.value.id in self.replacements): + return self.visit(node.value) + self.generic_visit(node) + return node + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + if node.name in self.replacements: + repl = self.replacements[node.name] + if not isinstance(repl, (gast.Name, ast.Name)): + raise ValueError( + 'A function name can only be replaced by a Name node. Found: %s', + repl) + node.name = repl.id + return node + + def visit_Name(self, node): + # Note: The caller is reposnsible with making sure the replacement + # Name nodes have the proper ctx set up. + # TODO(mdan): Is it possible to always infer the proper context here? + if node.id in self.replacements: + # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations. + new_nodes = self.replacements[node.id] + if isinstance(new_nodes, gast.AST): + new_nodes = [new_nodes] + if len(new_nodes) == 1: + new_nodes, = new_nodes + return new_nodes + else: + return node + + +def replace(template, **replacements): + """Replace placeholders in a Python template. + + Args: + template: A function to be used as a template. Any placeholder is expected + to also be a function argument. + **replacements: A mapping from placeholder names to (lists of) AST nodes + that these placeholders will be replaced by. + + Returns: + body: An AST node or list of AST nodes with the replacements made. If the + template was a function, a list will be returned. If the template was a + node, the same node will be returned. If the template was a string, an + AST node will be returned (a `Module` node in the case of a multi-line + string, an `Expr` node otherwise). + + Raises: + ValueError: If a function is used as a template and an incorrect set of + replacements was passed. + """ + tree = parser.parse_object(template).body[0] + placeholders = set(arg.id for arg in tree.args.args) + tree.args.args = [] + if tree.args.vararg: + placeholders.add(tree.args.vararg) + tree.args.vararg = None + if set(replacements.keys()) != placeholders: + raise ValueError( + 'too many or few replacements. replacements: %s; placeholders: %s' % + (replacements.keys(), placeholders)) + + # Perform the replacement, stripping the function into which the template was + # wrapped. + return ReplaceTransformer(replacements).visit(tree).body diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py new file mode 100644 index 0000000000..2ad8b9317b --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/templates_test.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for templates module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.python.platform import test + + +class TemplatesTest(test.TestCase): + + def test_replace_variable(self): + def template(a): # pylint:disable=unused-argument + def test_fn(a): # pylint:disable=unused-variable + a += 1 + a = 2 * a + 1 + return b # pylint:disable=undefined-variable + + node = templates.replace( + template, a=gast.Name('b', gast.Load(), None))[0] + result = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_replace_function_name(self): + def template(fname): # pylint:disable=unused-argument + def fname(a): # pylint:disable=function-redefined + a += 1 + a = 2 * a + 1 + return a + + node = templates.replace( + template, fname=gast.Name('test_fn', gast.Load(), None))[0] + result = compiler.ast_to_object(node) + self.assertEquals(7, result.test_fn(2)) + + def test_code_block(self): + def template(block): # pylint:disable=unused-argument + def test_fn(a): # pylint:disable=unused-variable + block # pylint:disable=pointless-statement + return a + + node = templates.replace( + template, + block=[ + gast.Assign( + [ + gast.Name('a', gast.Store(), None) + ], + gast.BinOp( + gast.Name('a', gast.Load(), None), + gast.Add(), + gast.Num(1))), + ] * 2)[0] + result = compiler.ast_to_object(node) + self.assertEquals(3, result.test_fn(1)) + + +if __name__ == '__main__': + test.main() |