aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-05 12:58:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-05 13:01:32 -0800
commit3a3feb207d8e138b7a468ae5d6e0d2daf4c8a49c (patch)
tree4fae87c7a7e1b6f4d9fa4430f9055c42f161aefe
parent620c8383123519fcf4d987efb9776d861901ccfa (diff)
Basic templating code.
PiperOrigin-RevId: 180964100
-rw-r--r--tensorflow/contrib/py2tf/pyct/BUILD14
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates.py112
-rw-r--r--tensorflow/contrib/py2tf/pyct/templates_test.py77
3 files changed, 203 insertions, 0 deletions
diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD
index 0601417335..dca380ceb1 100644
--- a/tensorflow/contrib/py2tf/pyct/BUILD
+++ b/tensorflow/contrib/py2tf/pyct/BUILD
@@ -22,6 +22,7 @@ py_library(
"compiler.py",
"parser.py",
"pretty_printer.py",
+ "templates.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
@@ -78,3 +79,16 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "templates_test",
+ srcs = ["templates_test.py"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py
new file mode 100644
index 0000000000..6acc03bfce
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/templates.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST conversion templates.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import parser
+
+
+class ReplaceTransformer(gast.NodeTransformer):
+ """Replace AST nodes."""
+
+ def __init__(self, replacements):
+ """Create a new ReplaceTransformer.
+
+ Args:
+ replacements: A mapping from placeholder names to (lists of) AST nodes
+ that these placeholders will be replaced by.
+ """
+ self.replacements = replacements
+
+ # TODO(mdan): Make a more detailed pass and clean up if needed.
+
+ def visit_Expr(self, node):
+ if (isinstance(node.value, gast.Name) and
+ node.value.id in self.replacements):
+ return self.visit(node.value)
+ self.generic_visit(node)
+ return node
+
+ def visit_FunctionDef(self, node):
+ node = self.generic_visit(node)
+ if node.name in self.replacements:
+ repl = self.replacements[node.name]
+ if not isinstance(repl, (gast.Name, ast.Name)):
+ raise ValueError(
+ 'A function name can only be replaced by a Name node. Found: %s',
+ repl)
+ node.name = repl.id
+ return node
+
+ def visit_Name(self, node):
+ # Note: The caller is reposnsible with making sure the replacement
+ # Name nodes have the proper ctx set up.
+ # TODO(mdan): Is it possible to always infer the proper context here?
+ if node.id in self.replacements:
+ # TODO(mdan): Sanitize the nodes by erasing scope-dependent annotations.
+ new_nodes = self.replacements[node.id]
+ if isinstance(new_nodes, gast.AST):
+ new_nodes = [new_nodes]
+ if len(new_nodes) == 1:
+ new_nodes, = new_nodes
+ return new_nodes
+ else:
+ return node
+
+
+def replace(template, **replacements):
+ """Replace placeholders in a Python template.
+
+ Args:
+ template: A function to be used as a template. Any placeholder is expected
+ to also be a function argument.
+ **replacements: A mapping from placeholder names to (lists of) AST nodes
+ that these placeholders will be replaced by.
+
+ Returns:
+ body: An AST node or list of AST nodes with the replacements made. If the
+ template was a function, a list will be returned. If the template was a
+ node, the same node will be returned. If the template was a string, an
+ AST node will be returned (a `Module` node in the case of a multi-line
+ string, an `Expr` node otherwise).
+
+ Raises:
+ ValueError: If a function is used as a template and an incorrect set of
+ replacements was passed.
+ """
+ tree = parser.parse_object(template).body[0]
+ placeholders = set(arg.id for arg in tree.args.args)
+ tree.args.args = []
+ if tree.args.vararg:
+ placeholders.add(tree.args.vararg)
+ tree.args.vararg = None
+ if set(replacements.keys()) != placeholders:
+ raise ValueError(
+ 'too many or few replacements. replacements: %s; placeholders: %s' %
+ (replacements.keys(), placeholders))
+
+ # Perform the replacement, stripping the function into which the template was
+ # wrapped.
+ return ReplaceTransformer(replacements).visit(tree).body
diff --git a/tensorflow/contrib/py2tf/pyct/templates_test.py b/tensorflow/contrib/py2tf/pyct/templates_test.py
new file mode 100644
index 0000000000..2ad8b9317b
--- /dev/null
+++ b/tensorflow/contrib/py2tf/pyct/templates_test.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for templates module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.python.platform import test
+
+
+class TemplatesTest(test.TestCase):
+
+ def test_replace_variable(self):
+ def template(a): # pylint:disable=unused-argument
+ def test_fn(a): # pylint:disable=unused-variable
+ a += 1
+ a = 2 * a + 1
+ return b # pylint:disable=undefined-variable
+
+ node = templates.replace(
+ template, a=gast.Name('b', gast.Load(), None))[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals(7, result.test_fn(2))
+
+ def test_replace_function_name(self):
+ def template(fname): # pylint:disable=unused-argument
+ def fname(a): # pylint:disable=function-redefined
+ a += 1
+ a = 2 * a + 1
+ return a
+
+ node = templates.replace(
+ template, fname=gast.Name('test_fn', gast.Load(), None))[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals(7, result.test_fn(2))
+
+ def test_code_block(self):
+ def template(block): # pylint:disable=unused-argument
+ def test_fn(a): # pylint:disable=unused-variable
+ block # pylint:disable=pointless-statement
+ return a
+
+ node = templates.replace(
+ template,
+ block=[
+ gast.Assign(
+ [
+ gast.Name('a', gast.Store(), None)
+ ],
+ gast.BinOp(
+ gast.Name('a', gast.Load(), None),
+ gast.Add(),
+ gast.Num(1))),
+ ] * 2)[0]
+ result = compiler.ast_to_object(node)
+ self.assertEquals(3, result.test_fn(1))
+
+
+if __name__ == '__main__':
+ test.main()