aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-30 10:01:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-30 10:08:58 -0700
commit1ece2e8e96be2eb39922951619ea99208df93284 (patch)
tree828199666ccb74b3fc315066df4278ebd9903a75 /tensorflow/contrib/autograph
parent0301533e35aac90387ef3aba71add1f2da35e3ef (diff)
Random (usually non-functional) code generation for testing/fuzzing.
Doesn't generate a large space of programs currently, but will spit out combinations of BinOp, Compare, UnaryOp, If and While nodes currently. PiperOrigin-RevId: 206599818
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/pyct/testing/BUILD57
-rw-r--r--tensorflow/contrib/autograph/pyct/testing/codegen.py234
-rw-r--r--tensorflow/contrib/autograph/pyct/testing/codegen_test.py40
3 files changed, 331 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD
new file mode 100644
index 0000000000..b89affdc98
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/testing/BUILD
@@ -0,0 +1,57 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "testing",
+ srcs = [
+ "codegen.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/contrib/autograph/utils",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "codegen_test",
+ size = "large",
+ srcs = ["codegen_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":testing",
+ "//tensorflow/contrib/autograph/pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+# py_test(
+# name = "dataflow_test",
+# size = "large",
+# srcs = ["dataflow_test.py"],
+# srcs_version = "PY2AND3",
+# tags = ["no_windows"],
+# deps = [
+# ":testing",
+# "@gast_archive//:gast",
+# "//tensorflow/contrib/autograph/pyct",
+# "//tensorflow/python:client_testlib",
+# ],
+# )
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen.py b/tensorflow/contrib/autograph/pyct/testing/codegen.py
new file mode 100644
index 0000000000..279e7c09dc
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/testing/codegen.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Random code generation for testing/fuzzing."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import string
+
+import gast
+import numpy as np
+
+from tensorflow.contrib.autograph.pyct import templates
+
+
+class NodeSampler(object):
+ sample_map = None
+
+ def sample(self):
+ nodes, magnitudes = zip(*self.sample_map.items())
+ return np.random.choice(
+ nodes, p=np.array(magnitudes, dtype='float32') / np.sum(magnitudes))
+
+
+class StatementSampler(NodeSampler):
+ sample_map = dict((
+ (gast.Assign, 10),
+ (gast.Print, 1),
+ (gast.If, 2),
+ (gast.While, 2),
+ (gast.For, 0),
+ ))
+
+
+class ExpressionSampler(NodeSampler):
+ sample_map = dict((
+ (gast.UnaryOp, 1),
+ (gast.BinOp, 8),
+ (gast.Name, 1),
+ (gast.Call, 0),
+ ))
+
+
+class CompareSampler(NodeSampler):
+ sample_map = dict((
+ (gast.Eq, 1),
+ (gast.NotEq, 1),
+ (gast.Lt, 1),
+ (gast.LtE, 1),
+ (gast.Gt, 1),
+ (gast.GtE, 1),
+ (gast.Is, 1),
+ (gast.IsNot, 1),
+ ))
+
+
+class BinaryOpSampler(NodeSampler):
+ sample_map = dict((
+ (gast.Add, 1),
+ (gast.Sub, 1),
+ (gast.Mult, 1),
+ (gast.Div, 1),
+ (gast.FloorDiv, 1),
+ (gast.Mod, 1),
+ (gast.Pow, 1),
+ ))
+
+
+class UnaryOpSampler(NodeSampler):
+ sample_map = dict(((gast.USub, 1), (gast.UAdd, 0)))
+
+
+class NameSampler(NodeSampler):
+ sample_map = dict((
+ ('new', 1),
+ ('existing', 1),
+ ))
+
+
+N_CONTROLFLOW_STATEMENTS = 10
+N_FUNCTIONDEF_STATEMENTS = 10
+
+
+class CodeGenerator(object):
+ """Generate random syntactically-valid Python ASTs."""
+
+ def __init__(self, max_depth=3, depth=0):
+ self.max_depth = max_depth
+ self.depth = depth
+
+ def generate_statement(self):
+ """Generate a statement node, dispatching to the correct class method."""
+ desired_node = StatementSampler().sample()
+ self.depth += 1
+
+ # Enforce some constraints on generating statements.
+ # E.g., if statements need at least 3 readable variables.
+ # If we fail to satisfy our constraints, draw another sample.
+ if desired_node in (gast.While, gast.For, gast.If):
+ if self.depth > self.max_depth:
+ return self.generate_statement()
+
+ # Go get the generator method and run it
+ method = 'generate_' + desired_node.__name__
+ visitor = getattr(self, method)
+ node = visitor()
+ self.depth -= 1
+ return node
+
+ def sample_node_list(self, low, high, generator):
+ """Generate a list of statements of random length.
+
+ Args:
+ low: Fewest number of statements to generate.
+ high: Highest number of statements to generate.
+ generator: Function to call to generate nodes.
+
+ Returns:
+ A list of statements.
+ """
+ statements = []
+ for _ in range(np.random.randint(low, high)):
+ statements.append(generator())
+ return statements
+
+ def generate_Name(self, ctx=gast.Load()):
+ variable_name = '_' + ''.join(
+ random.choice(string.ascii_lowercase) for _ in range(4))
+ return gast.Name(variable_name, ctx=ctx, annotation=None)
+
+ def generate_BinOp(self):
+ # TODO(alexbw): convert to generate_expression when we get to limit
+ # expression depth.
+ op = BinaryOpSampler().sample()()
+ return gast.BinOp(self.generate_Name(), op, self.generate_Name())
+
+ def generate_Compare(self):
+ op = CompareSampler().sample()()
+ return gast.Compare(self.generate_Name(), [op], [self.generate_Name()])
+
+ def generate_UnaryOp(self):
+ operand = self.generate_Name()
+ op = UnaryOpSampler().sample()()
+ return gast.UnaryOp(op, operand)
+
+ def generate_expression(self):
+ desired_node = ExpressionSampler().sample()
+ # Go get the generator method and run it
+ method = 'generate_' + desired_node.__name__
+ generator = getattr(self, method)
+ return generator()
+
+ def generate_Assign(self):
+ """Generate an Assign node."""
+ # Generate left-hand side
+ target_node = self.generate_Name(gast.Store())
+ # Generate right-hand side
+ value_node = self.generate_expression()
+ # Put it all together
+ node = gast.Assign(targets=[target_node], value=value_node)
+ return node
+
+ def generate_If(self):
+ """Generate an If node."""
+ test = self.generate_Compare()
+
+ # Generate true branch statements
+ body = self.sample_node_list(
+ low=1,
+ high=N_CONTROLFLOW_STATEMENTS // 2,
+ generator=self.generate_statement)
+
+ # Generate false branch statements
+ orelse = self.sample_node_list(
+ low=1,
+ high=N_CONTROLFLOW_STATEMENTS // 2,
+ generator=self.generate_statement)
+
+ node = gast.If(test, body, orelse)
+ return node
+
+ def generate_While(self):
+ """Generate a While node."""
+
+ test = self.generate_Compare()
+ body = self.sample_node_list(
+ low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement)
+ orelse = [] # not generating else statements
+
+ node = gast.While(test, body, orelse)
+ return node
+
+ def generate_Call(self):
+ raise NotImplementedError
+
+ def generate_Return(self):
+ return gast.Return(self.generate_expression())
+
+ def generate_Print(self):
+ return templates.replace('print(x)', x=self.generate_expression())[0]
+
+ def generate_FunctionDef(self):
+ """Generate a FunctionDef node."""
+
+ # Generate the arguments, register them as available
+ arg_vars = self.sample_node_list(
+ low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
+ args = gast.arguments(arg_vars, None, [], [], None, [])
+
+ # Generate the function body
+ body = self.sample_node_list(
+ low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement)
+ body.append(self.generate_Return())
+ fn_name = self.generate_Name().id
+ node = gast.FunctionDef(fn_name, args, body, (), None)
+ return node
+
+
+def generate_random_functiondef():
+ return CodeGenerator().generate_FunctionDef()
diff --git a/tensorflow/contrib/autograph/pyct/testing/codegen_test.py b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py
new file mode 100644
index 0000000000..255c3b2a2e
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/testing/codegen_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for type_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.autograph.pyct import compiler
+from tensorflow.contrib.autograph.pyct.testing import codegen
+from tensorflow.python.platform import test
+
+
+class CodeGenTest(test.TestCase):
+
+ def test_codegen_gens(self):
+ np.random.seed(0)
+ for _ in range(1000):
+ node = codegen.generate_random_functiondef()
+ fn = compiler.ast_to_object(node)
+ self.assertIsNotNone(
+ fn, 'Generated invalid AST that could not convert to source.')
+
+
+if __name__ == '__main__':
+ test.main()