aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/testing/codegen.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/testing/codegen.py')
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen.py234
1 files changed, 234 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/testing/codegen.py b/tensorflow/python/autograph/pyct/testing/codegen.py
new file mode 100644
index 0000000000..78b24390c3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/testing/codegen.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Random code generation for testing/fuzzing."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import string
+
+import gast
+import numpy as np
+
+from tensorflow.python.autograph.pyct import templates
+
+
+class NodeSampler(object):
+ sample_map = None
+
+ def sample(self):
+ nodes, magnitudes = zip(*self.sample_map.items())
+ return np.random.choice(
+ nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes))
+
+
+class StatementSampler(NodeSampler):
+ sample_map = dict((
+ (gast.Assign, 10),
+ (gast.Print, 1),
+ (gast.If, 2),
+ (gast.While, 2),
+ (gast.For, 0),
+ ))
+
+
+class ExpressionSampler(NodeSampler):
+ sample_map = dict((
+ (gast.UnaryOp, 1),
+ (gast.BinOp, 8),
+ (gast.Name, 1),
+ (gast.Call, 0),
+ ))
+
+
+class CompareSampler(NodeSampler):
+ sample_map = dict((
+ (gast.Eq, 1),
+ (gast.NotEq, 1),
+ (gast.Lt, 1),
+ (gast.LtE, 1),
+ (gast.Gt, 1),
+ (gast.GtE, 1),
+ (gast.Is, 1),
+ (gast.IsNot, 1),
+ ))
+
+
+class BinaryOpSampler(NodeSampler):
+ sample_map = dict((
+ (gast.Add, 1),
+ (gast.Sub, 1),
+ (gast.Mult, 1),
+ (gast.Div, 1),
+ (gast.FloorDiv, 1),
+ (gast.Mod, 1),
+ (gast.Pow, 1),
+ ))
+
+
+class UnaryOpSampler(NodeSampler):
+ sample_map = dict(((gast.USub, 1), (gast.UAdd, 0)))
+
+
+class NameSampler(NodeSampler):
+ sample_map = dict((
+ ('new', 1),
+ ('existing', 1),
+ ))
+
+
+N_CONTROLFLOW_STATEMENTS = 10
+N_FUNCTIONDEF_STATEMENTS = 10
+
+
+class CodeGenerator(object):
+ """Generate random syntactically-valid Python ASTs."""
+
+ def __init__(self, max_depth=3, depth=0):
+ self.max_depth = max_depth
+ self.depth = depth
+
+ def generate_statement(self):
+ """Generate a statement node, dispatching to the correct class method."""
+ desired_node = StatementSampler().sample()
+ self.depth += 1
+
+ # Enforce some constraints on generating statements.
+ # E.g., if statements need at least 3 readable variables.
+ # If we fail to satisfy our constraints, draw another sample.
+ if desired_node in (gast.While, gast.For, gast.If):
+ if self.depth > self.max_depth:
+ return self.generate_statement()
+
+ # Go get the generator method and run it
+ method = 'generate_' + desired_node.__name__
+ visitor = getattr(self, method)
+ node = visitor()
+ self.depth -= 1
+ return node
+
+ def sample_node_list(self, low, high, generator):
+ """Generate a list of statements of random length.
+
+ Args:
+ low: Fewest number of statements to generate.
+ high: Highest number of statements to generate.
+ generator: Function to call to generate nodes.
+
+ Returns:
+ A list of statements.
+ """
+ statements = []
+ for _ in range(np.random.randint(low, high)):
+ statements.append(generator())
+ return statements
+
+ def generate_Name(self, ctx=gast.Load()):
+ variable_name = '_' + ''.join(
+ random.choice(string.ascii_lowercase) for _ in range(4))
+ return gast.Name(variable_name, ctx=ctx, annotation=None)
+
+ def generate_BinOp(self):
+ # TODO(alexbw): convert to generate_expression when we get to limit
+ # expression depth.
+ op = BinaryOpSampler().sample()()
+ return gast.BinOp(self.generate_Name(), op, self.generate_Name())
+
+ def generate_Compare(self):
+ op = CompareSampler().sample()()
+ return gast.Compare(self.generate_Name(), [op], [self.generate_Name()])
+
+ def generate_UnaryOp(self):
+ operand = self.generate_Name()
+ op = UnaryOpSampler().sample()()
+ return gast.UnaryOp(op, operand)
+
+ def generate_expression(self):
+ desired_node = ExpressionSampler().sample()
+ # Go get the generator method and run it
+ method = 'generate_' + desired_node.__name__
+ generator = getattr(self, method)
+ return generator()
+
+ def generate_Assign(self):
+ """Generate an Assign node."""
+ # Generate left-hand side
+ target_node = self.generate_Name(gast.Store())
+ # Generate right-hand side
+ value_node = self.generate_expression()
+ # Put it all together
+ node = gast.Assign(targets=[target_node], value=value_node)
+ return node
+
+ def generate_If(self):
+ """Generate an If node."""
+ test = self.generate_Compare()
+
+ # Generate true branch statements
+ body = self.sample_node_list(
+ low=1,
+ high=N_CONTROLFLOW_STATEMENTS // 2,
+ generator=self.generate_statement)
+
+ # Generate false branch statements
+ orelse = self.sample_node_list(
+ low=1,
+ high=N_CONTROLFLOW_STATEMENTS // 2,
+ generator=self.generate_statement)
+
+ node = gast.If(test, body, orelse)
+ return node
+
+ def generate_While(self):
+ """Generate a While node."""
+
+ test = self.generate_Compare()
+ body = self.sample_node_list(
+ low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement)
+ orelse = [] # not generating else statements
+
+ node = gast.While(test, body, orelse)
+ return node
+
+ def generate_Call(self):
+ raise NotImplementedError
+
+ def generate_Return(self):
+ return gast.Return(self.generate_expression())
+
+ def generate_Print(self):
+ return templates.replace('print(x)', x=self.generate_expression())[0]
+
+ def generate_FunctionDef(self):
+ """Generate a FunctionDef node."""
+
+ # Generate the arguments, register them as available
+ arg_vars = self.sample_node_list(
+ low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
+ args = gast.arguments(arg_vars, None, [], [], None, [])
+
+ # Generate the function body
+ body = self.sample_node_list(
+ low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement)
+ body.append(self.generate_Return())
+ fn_name = self.generate_Name().id
+ node = gast.FunctionDef(fn_name, args, body, (), None)
+ return node
+
+
+def generate_random_functiondef():
+ return CodeGenerator().generate_FunctionDef()