aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-11 16:20:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-11 16:32:19 -0700
commit668c079f4e6020131978b7a812c3b92eea9c47b9 (patch)
tree269836fd98f37b3a099e6b4cceeb3256416705fa /tensorflow/python/autograph
parentefd9e0d073a6632f7632f7fe43ae4364cc2c834b (diff)
Move AutoGraph to core. This CL moves the entirety of the code base, keeping the frontend autograph module in contrib for backward compatibility. Certain files, like notebooks and the readme file may be referenced from the outside, so a copy of those is kept as well. In addition, the notebooks subdirectory of examples is also kept in contrib because the extension the build file relies on is not available in the PIP package.
PiperOrigin-RevId: 212543067
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/BUILD31
-rw-r--r--tensorflow/python/autograph/CONTRIBUTING.md104
-rw-r--r--tensorflow/python/autograph/LIMITATIONS.md50
-rw-r--r--tensorflow/python/autograph/README.md143
-rw-r--r--tensorflow/python/autograph/STYLE_GUIDE.md85
-rw-r--r--tensorflow/python/autograph/__init__.py68
-rw-r--r--tensorflow/python/autograph/converters/BUILD249
-rw-r--r--tensorflow/python/autograph/converters/__init__.py32
-rw-r--r--tensorflow/python/autograph/converters/asserts.py49
-rw-r--r--tensorflow/python/autograph/converters/asserts_test.py42
-rw-r--r--tensorflow/python/autograph/converters/break_statements.py146
-rw-r--r--tensorflow/python/autograph/converters/break_statements_test.py137
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions.py65
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions_test.py74
-rw-r--r--tensorflow/python/autograph/converters/call_trees.py330
-rw-r--r--tensorflow/python/autograph/converters/call_trees_test.py138
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions.py129
-rw-r--r--tensorflow/python/autograph/converters/conditional_expressions_test.py53
-rw-r--r--tensorflow/python/autograph/converters/continue_statements.py139
-rw-r--r--tensorflow/python/autograph/converters/continue_statements_test.py94
-rw-r--r--tensorflow/python/autograph/converters/control_flow.py339
-rw-r--r--tensorflow/python/autograph/converters/control_flow_test.py247
-rw-r--r--tensorflow/python/autograph/converters/decorators.py105
-rw-r--r--tensorflow/python/autograph/converters/decorators_test.py152
-rw-r--r--tensorflow/python/autograph/converters/directives.py128
-rw-r--r--tensorflow/python/autograph/converters/directives_test.py95
-rw-r--r--tensorflow/python/autograph/converters/error_handlers.py53
-rw-r--r--tensorflow/python/autograph/converters/error_handlers_test.py59
-rw-r--r--tensorflow/python/autograph/converters/list_comprehensions.py82
-rw-r--r--tensorflow/python/autograph/converters/list_comprehensions_test.py61
-rw-r--r--tensorflow/python/autograph/converters/lists.py239
-rw-r--r--tensorflow/python/autograph/converters/lists_test.py132
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions.py132
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions_test.py61
-rw-r--r--tensorflow/python/autograph/converters/name_scopes.py74
-rw-r--r--tensorflow/python/autograph/converters/name_scopes_test.py101
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py317
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py167
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards.py183
-rw-r--r--tensorflow/python/autograph/converters/side_effect_guards_test.py163
-rw-r--r--tensorflow/python/autograph/converters/slices.py85
-rw-r--r--tensorflow/python/autograph/converters/slices_test.py76
-rw-r--r--tensorflow/python/autograph/core/BUILD75
-rw-r--r--tensorflow/python/autograph/core/config.py49
-rw-r--r--tensorflow/python/autograph/core/converter.py330
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py166
-rw-r--r--tensorflow/python/autograph/core/errors.py258
-rw-r--r--tensorflow/python/autograph/core/errors_test.py105
-rw-r--r--tensorflow/python/autograph/core/naming.py130
-rw-r--r--tensorflow/python/autograph/core/naming_test.py77
-rw-r--r--tensorflow/python/autograph/docs/pyfunc_dtypes.md33
-rw-r--r--tensorflow/python/autograph/impl/BUILD62
-rw-r--r--tensorflow/python/autograph/impl/api.py328
-rw-r--r--tensorflow/python/autograph/impl/api_test.py329
-rw-r--r--tensorflow/python/autograph/impl/conversion.py351
-rw-r--r--tensorflow/python/autograph/impl/conversion_test.py172
-rw-r--r--tensorflow/python/autograph/lang/BUILD40
-rw-r--r--tensorflow/python/autograph/lang/directives.py68
-rw-r--r--tensorflow/python/autograph/lang/special_functions.py96
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py70
-rw-r--r--tensorflow/python/autograph/operators/BUILD84
-rw-r--r--tensorflow/python/autograph/operators/__init__.py55
-rw-r--r--tensorflow/python/autograph/operators/control_flow.py227
-rw-r--r--tensorflow/python/autograph/operators/control_flow_test.py100
-rw-r--r--tensorflow/python/autograph/operators/data_structures.py338
-rw-r--r--tensorflow/python/autograph/operators/data_structures_test.py158
-rw-r--r--tensorflow/python/autograph/operators/dispatch_context.py41
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py225
-rw-r--r--tensorflow/python/autograph/operators/py_builtins_test.py131
-rw-r--r--tensorflow/python/autograph/operators/slices.py142
-rw-r--r--tensorflow/python/autograph/operators/slices_test.py66
-rw-r--r--tensorflow/python/autograph/pyct/BUILD163
-rw-r--r--tensorflow/python/autograph/pyct/__init__.py19
-rw-r--r--tensorflow/python/autograph/pyct/anno.py157
-rw-r--r--tensorflow/python/autograph/pyct/anno_test.py84
-rw-r--r--tensorflow/python/autograph/pyct/ast_util.py313
-rw-r--r--tensorflow/python/autograph/pyct/ast_util_test.py196
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py815
-rw-r--r--tensorflow/python/autograph/pyct/cfg_test.py969
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/BUILD41
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/__init__.py0
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/anf.py424
-rw-r--r--tensorflow/python/autograph/pyct/common_transformers/anf_test.py443
-rw-r--r--tensorflow/python/autograph/pyct/compiler.py141
-rw-r--r--tensorflow/python/autograph/pyct/compiler_test.py108
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py161
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py277
-rw-r--r--tensorflow/python/autograph/pyct/origin_info.py186
-rw-r--r--tensorflow/python/autograph/pyct/origin_info_test.py104
-rw-r--r--tensorflow/python/autograph/pyct/parser.py59
-rw-r--r--tensorflow/python/autograph/pyct/parser_test.py52
-rw-r--r--tensorflow/python/autograph/pyct/pretty_printer.py113
-rw-r--r--tensorflow/python/autograph/pyct/pretty_printer_test.py52
-rw-r--r--tensorflow/python/autograph/pyct/qual_names.py257
-rw-r--r--tensorflow/python/autograph/pyct/qual_names_test.py255
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/BUILD94
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/__init__.py33
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py398
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py508
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/annos.py55
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py137
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values_test.py132
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness.py200
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/liveness_test.py149
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py301
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py263
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info.py213
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/type_info_test.py207
-rw-r--r--tensorflow/python/autograph/pyct/templates.py277
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py213
-rw-r--r--tensorflow/python/autograph/pyct/testing/BUILD48
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen.py234
-rw-r--r--tensorflow/python/autograph/pyct/testing/codegen_test.py40
-rw-r--r--tensorflow/python/autograph/pyct/transformer.py487
-rw-r--r--tensorflow/python/autograph/pyct/transformer_test.py369
-rw-r--r--tensorflow/python/autograph/utils/BUILD114
-rw-r--r--tensorflow/python/autograph/utils/__init__.py29
-rw-r--r--tensorflow/python/autograph/utils/context_managers.py49
-rw-r--r--tensorflow/python/autograph/utils/context_managers_test.py47
-rw-r--r--tensorflow/python/autograph/utils/misc.py50
-rw-r--r--tensorflow/python/autograph/utils/misc_test.py54
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch.py66
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch_test.py75
-rw-r--r--tensorflow/python/autograph/utils/py_func.py131
-rw-r--r--tensorflow/python/autograph/utils/py_func_test.py103
-rw-r--r--tensorflow/python/autograph/utils/tensor_list.py68
-rw-r--r--tensorflow/python/autograph/utils/tensor_list_test.py117
-rw-r--r--tensorflow/python/autograph/utils/tensors.py41
-rw-r--r--tensorflow/python/autograph/utils/tensors_test.py57
-rw-r--r--tensorflow/python/autograph/utils/testing.py35
-rw-r--r--tensorflow/python/autograph/utils/type_check.py33
-rw-r--r--tensorflow/python/autograph/utils/type_check_test.py43
132 files changed, 20374 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/BUILD b/tensorflow/python/autograph/BUILD
new file mode 100644
index 0000000000..3289b447e7
--- /dev/null
+++ b/tensorflow/python/autograph/BUILD
@@ -0,0 +1,31 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "autograph",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/autograph/impl",
+ "//tensorflow/python/autograph/lang",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/utils",
+ ],
+)
diff --git a/tensorflow/python/autograph/CONTRIBUTING.md b/tensorflow/python/autograph/CONTRIBUTING.md
new file mode 100644
index 0000000000..1ded5ba5f6
--- /dev/null
+++ b/tensorflow/python/autograph/CONTRIBUTING.md
@@ -0,0 +1,104 @@
+# How to contribute
+
+We'd love to have your patches and contributions! Here are some guidelines. In general, we follow the [TensorFlow contributing guidelines](../../CONTRIBUTING.md), but have some [AutoGraph-specific style guidelines](STYLE_GUIDE.md). More details below.
+
+### Note to active contributors
+
+In preparation for TF 2.0, we moved the code base of AutoGraph from
+`tensorflow/contrib/autograph` to `tensorflow/python/autograph`. The move
+does not impact functionality, and AutoGraph will remain accessible under
+`tensorflow.contrib.autograph` until `tensorflow.contrib` is retired.
+
+
+
+## TensorFlow Code of Conduct
+Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult [GitHub
+Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+After a pull request is approved, we merge it. Note our merging process differs
+from GitHub in that we pull and submit the change into an internal version
+control system. This system automatically pushes a git commit to the GitHub
+repository (with credit to the original author) and closes the pull request.
+
+## Style
+
+See the [AutoGraph style guide](STYLE_GUIDE.md).
+
+## Unit tests
+
+Please include unit tests when contributing new features ([example here](converters/continue_statements_test.py)), as they help to a) prove that your code works correctly, and b) guard against future breaking
+changes to lower the maintenance cost.
+It's also helpful to check that any
+changes you propose do not break existing unit tests. You can run tests using the command,
+
+```shell
+bazel test --config=opt --copt=-O3 --copt=-march=native \
+    //tensorflow/python/autograph/...
+```
+
+from the root of the `tensorflow` repository. For more details see the [main TensorFlow Contributing File](../../CONTRIBUTING.md)
+
+## Developer info
+
+### Module structure
+
+The graph below describes the dependencies between AutoGraph modules (not to be mistaken with the directory structure for these modules, which is flat):
+
+```dot
+digraph d_modules {
+ autograph [style=filled];
+ converters;
+ core;
+ impl;
+ lang;
+ operators;
+
+ autograph -> impl
+ autograph -> lang
+
+ impl -> converters
+ impl -> core
+ impl -> operators
+
+ lang -> operators
+
+ converters -> core
+ converters -> lang
+}
+```
+
+`autograph` is the sole user-visible module.
+
+A short description of the modules:
+
+ * `autograph`: the main module imported by the user and by the generated code; only contains declarations
+ * `impl`: high level code and the implementation of the api frontend
+ * `core`: base classes for the AutoGraph source code transformation logic; see in particular `converter.py`
+ * `lang`: special user-visible functions that serve as extensions to the Python language
+ * `converters`: collection of source code transformation modules specialized for particular AutoGraph features
+ * `operators`: collection of operators that AutoGraph overloads; these correspond to Python operators as well as Python syntactic structures, like control flow
+
+There are two additional modules, `pyct` and `utils`. These are independent of AutoGraph:
+
+ * `pyct`: a general purpose Python source code transformation library
+ * `utils`: the kitchen sink; deprecated
+
+Note: we have a long term plan to factor out an implementation of `impl` and `converters` that is independent of autograph, into a general purpose Python operator overloading library.
diff --git a/tensorflow/python/autograph/LIMITATIONS.md b/tensorflow/python/autograph/LIMITATIONS.md
new file mode 100644
index 0000000000..d8b1cb7616
--- /dev/null
+++ b/tensorflow/python/autograph/LIMITATIONS.md
@@ -0,0 +1,50 @@
+# Capabilities and Limitations
+
+TF AutoGraph converts Eager Python code into TensorFlow graph-mode code. For example, users write code with `if` and `while` and AutoGraph automatically converts it into the equivalent `tf.cond`, and `tf.while_loop`.
+
+Python is a large language, so hoping to convert arbitrary Python code directly to TF graphs is overly ambitious. However, the Python code written to metaprogram TF graphs is in practice a restricted subset. We aim to support as much of this subset as possible. The table below lays out what we currently handle, what we hope to support, and what we have no plans to support.
+
+# Python Language Support Status
+
+Note: as more complex features in TensorFlow are made more accessible using AutoGraph, we expect to come across use cases that haven't been tried before, some of which might reveal rare bugs. If we do find any such bugs, we may add additional restrictions for the affected configurations, until those bugs are resolved.
+
+ Construct | Supported now? | Plan to support? | Notes
+ :--------- | :--------------: | :----------------: | :-----
+If statement | Yes | | Converts to `tf.cond`. If variables are created in one branch that don’t exist in another, which is inexpressible in TF, we throw a clear error.
+For statement | Yes | | We will specialize `for` loops with unknown and known lengths, as well as for loops over TF datasets. Converts to `tf.while_loop`, with an additional `maximum_iterations` hint, if that is known. Creating variables inside the loop that are used later outside the loop is not supported, as the loop may have no iterations.
+While statement | Yes | | Converts to `tf.while_loop`. Creating variables inside the loop is not supported, as the loop may have no iterations.
+Continue and break | Yes | | Converts to boolean flags and extra predicates in loop tests.
+Composition of control flow | Yes | | Arbitrary composition of `if`, `while`, `for`, `break`, and `continue`, along with other supported language elements, is supported and tested.
+Iterators | Some | Yes | Not all iterators supported, but we plan to support everything that can be desugared, such as `enumerate` and `zip`.
+Multiple return values | Yes | | We desugar them into variables, boolean flags and conditionals so that the function has a single return value at the end, and provide a clear error if we are unable to do so.
+Print expression | Yes | | Wrapped in `PyFunc`, and given proper control dependencies. Optional support for using tf.Log when py_func is undesirable exists.
+Static function calls | Yes | | Non-recursive function calls
+Nested call trees | Yes | | For example, `f` calls `g` which calls `h`, all of which need conversion.
+Recursive function calls | No | Maybe | Based on available support in TF. Currently `function.Defun` is the best candidate, but it is not reentrant.
+Python built-ins | Some | Yes | `print`, `len`, `range`, `xrange`, `int`, `float` are supported, and we plan to support or clearly error on all [Python built-ins](https://docs.python.org/3/library/functions.html).
+List operations | Yes | | We convert list creation, append, pop and indexing to their TF TensorArray equivalents. However, we do need some extra type hints to fully convert correctly. We hope to remove this limitation.
+Function variables | Yes | | e.g. `f_new = f_orig; f_new()`
+Lambda functions | No | Yes | Planned feature.
+Classes | Yes | | Classes can be converted all at once, or method-by-method. Some limitations exist around static and class methods.
+Subclasses | Yes | | Subclassing library objects like tf.keras.Model is also supported.
+Dynamic types | Some | | `o = C1() if foo else C2(); o.bar()`. Some scenarios where types are data-dependent may not be supported. We will raise a meaningful error in that case.
+Dynamic code / exec | No | |
+Reflection | No | |
+Try / Except | No | No | No current sane TF equivalent.
+Global variables | Restricted | | In general, we only support read-only access to arguments or variables defined outside the converted code. A few exceptions include TensorFlow library code.
+Functions with side effects | Some | | Side effects are allowed, under certain circumstances.
+Collections | Some | Yes | We currently support lists. There are currently no TF equivalents of dictionaries or tuples.
+List Comprehensions | Yes | | We desugar `ListComp` into the appropriate combination of `For` and `If` statements. Other comprehensions are currently very low priority.
+Custom context managers | No | Yes | Currently low priority. Left unconverted currently.
+Generators | No | Maybe | Could be achievable using queues; very low priority.
+Assertions | Yes | | As `tf.Assert`
+Deletion | Yes | Maybe | Currently unconverted. If new semantics are required for `del`, we are able to add it in.
+Inline imports | No | Yes | For example, `import numpy as np; np.eye(3)`. Currently low priority.
+Async | No | No |
+
+## Extra capabilities
+
+ - We liberally add name scopes to generated functions
+ - Operations get decent default names everywhere (planned)
+ - Statements that have no output values are given correct control dependencies. For example, `for i in range(n): print(i)` will have control dependencies to ensure the `print` statements are executed serially.
+
diff --git a/tensorflow/python/autograph/README.md b/tensorflow/python/autograph/README.md
new file mode 100644
index 0000000000..cc54da4daa
--- /dev/null
+++ b/tensorflow/python/autograph/README.md
@@ -0,0 +1,143 @@
+# AutoGraph
+
+IMPORTANT: AutoGraph is beta software, and under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback! We'd also love contributions ([please see our contributing guidelines](CONTRIBUTING.md) and our [style guide](STYLE_GUIDE.md)).
+
+AutoGraph is a Python to TensorFlow compiler.
+
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops. [Please see this file for which parts of the Python language we currently support](LIMITATIONS.md).
+
+For example, this Python function:
+
+```
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+would be converted to this:
+
+```
+def graph_mode_f(x):
+ with tf.name_scope('f'):
+
+ def if_true():
+ with tf.name_scope('if_true'):
+ x_1, = x,
+ x_1 = tf.negative(x_1)
+ return x_1,
+
+ def if_false():
+ with tf.name_scope('if_false'):
+ x_1, = x,
+ return x_1,
+    x = ag__.utils.run_cond(tf.less(x, 0), if_true, if_false)
+ return x
+```
+
+so you can use it like an op:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1.0)
+
+ converted_f = autograph.to_graph(f)
+ y = converted_f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+# Getting started
+
+Use AutoGraph in one of the following ways, described below:
+
+ 1. Annotations (simpler)
+ 2. Functional API (more flexible)
+
+To get started, install the latest nightly TensorFlow build:
+
+```shell
+pip install -U tf-nightly
+```
+
+Then import the `autograph` module from `tf.contrib`:
+
+```
+from tensorflow.contrib import autograph as ag
+```
+
+### Related links
+
+Articles:
+
+ * [TensorFlow blog post](https://medium.com/tensorflow/autograph-converts-python-into-tensorflow-graphs-b2a871f87ec7)
+
+Interactive notebooks:
+
+ * [Quick guide](https://colab.research.google.com/github/tensorflow/models/blob/master/samples/core/guide/autograph.ipynb)
+ * [RNN trained using Keras and Estimators](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/rnn_keras_estimator.ipynb)
+ * [Demo from the TF Dev Summit 2018](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb)
+ * [Basic control flow speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_collatz_speed_test.ipynb)
+ * [MNIST training speed test](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/ag_vs_eager_mnist_speed_test.ipynb)
+ * [Basic algorithm samples](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/algorithms.ipynb)
+ * [Introductory workshop support notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/workshop.ipynb)
+
+## Using with annotations
+
+Annotating a function or class with `@convert` converts it in place:
+
+```
+@ag.convert()
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+... so that it always outputs TensorFlow code:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1)
+
+ y = f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+## Using the functional API
+
+The functional API allows you to convert an existing function, class or object after it was defined:
+
+```
+converted_f = ag.to_graph(f)
+
+print(converted_f(tf.constant(-1)))
+# Output: Tensor
+
+print(f(-1))
+# Output: 1
+```
+
+You can use the functional API to inspect the generated code as well:
+
+```
+print(ag.to_code(f))
+# Output: <Python and TensorFlow code>
+```
+
+## Filing bugs and feature requests
+
+### Reporting a bug
+
+ - If AutoGraph-generated code is compiling and running, but producing an incorrect result, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is compiling, but not running, send us a minimal reproduction case that includes the original Eager code, the inputs and if possible, the outputs or the error message.
+ - If AutoGraph-generated code is not compiling, send us two minimal pieces of code. First, the Eager code that you would like to write, and second, the Graph code that you would like AutoGraph to have generated for you.
+
+### Requesting a feature
+
+If you’d like AutoGraph to convert a feature of Python or TF that we currently don’t handle, please let us know by filing a bug. We’ll make it as easy as possible to interact with us through there.
diff --git a/tensorflow/python/autograph/STYLE_GUIDE.md b/tensorflow/python/autograph/STYLE_GUIDE.md
new file mode 100644
index 0000000000..7e6b0cc27d
--- /dev/null
+++ b/tensorflow/python/autograph/STYLE_GUIDE.md
@@ -0,0 +1,85 @@
+# AutoGraph Style Guide
+
+This page contains style decisions that developers should follow when
+contributing code to AutoGraph.
+
+## TensorFlow Style
+
+Follow the [TensorFlow style
+guide](https://www.tensorflow.org/community/style_guide), the [documentation
+guide](https://www.tensorflow.org/community/documentation) and the
+[Google Python style guide](https://google.github.io/styleguide/pyguide.html).
+
+Naming conventions:
+
+1. The name is TensorFlow, not Tensorflow.
+2. The name is AutoGraph, not Autograph.
+
+## AutoGraph Style
+
+Below are AutoGraph-specific conventions. In the event of conflict,
+it supersedes all previous conventions.
+
+1. __Types in docstrings.__ Use [PEP 484](https://www.python.org/dev/peps/pep-0484/)
+ notation to describe the type for args, return values and attributes.
+
+ Example:
+
+ ```
+ Args:
+ foo: Dict[str, List[int]], a dictionary of sorts
+ ```
+
+2. __Citations in Docstrings.__ Write a `#### References` subsection at the
+ bottom of any docstring with citations. Use ICLR’s bibliography style to
+ write references; for example, order entries by the first author's last
+ name. Add a link to the paper if the publication is open source (ideally,
+ arXiv).
+
+ Write in-paragraph citations in general, e.g., [(Tran and Blei, 2018)][1].
+ Write in-text citations when the citation is a noun, e.g., [Tran and Blei
+ (2018)][1]. Write citations with more than two authors using et al., e.g.,
+ [(Tran et al., 2018)][1]. Separate multiple citations with semicolon, e.g.,
+ ([Tran and Blei, 2018][1]; [Gelman and Rubin, 1992][2]).
+
+ Examples:
+
+ ```none
+ #### References
+
+ # technical report
+ [1]: Tony Finch. Incremental calculation of weighted mean and variance.
+ _Technical Report_, 2009.
+ http://people.ds.cam.ac.uk/fanf2/hermes/doc/antiforgery/stats.pdf
+
+ # journal
+ [2]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation
+ Using Multiple Sequences. _Statistical Science_, 7(4):457-472, 1992.
+
+ # arXiv preprint
+ # use "et al." for papers with too many authors to maintain
+ [3]: Aaron van den Oord et al. Parallel WaveNet: Fast High-Fidelity Speech
+ Synthesis. _arXiv preprint arXiv:1711.10433_, 2017.
+ https://arxiv.org/abs/1711.10433
+
+ # conference
+ [4]: Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, and Roger Grosse.
+ Flipout: Efficient Pseudo-Independent Weight Perturbations on
+ Mini-Batches. In _International Conference on Learning
+ Representations_, 2018.
+ https://arxiv.org/abs/1803.04386
+ ```
+
+3. Avoid LaTeX in docstrings.
+
+ * It is not rendered in many (if not most) editors and can be hard to read
+ for both LaTeX experts and non-experts.
+
+4. Write docstring and comment math using ASCII friendly notation; python using
+ operators. E.g., `x**2` better than `x^2`, `x[i, j]` better than `x_{i,j}`,
+ `sum{ f(x[i]) : i=1...n }` better than `\sum_{i=1}^n f(x_i)` `int{sin(x) dx:
+ x in [0, 2 pi]}` better than `\int_0^{2\pi} sin(x) dx`.
+
+ * The more we stick to python style, the more someone can
+ copy/paste/execute.
+ * Python style is usually easier to read as ASCII.
diff --git a/tensorflow/python/autograph/__init__.py b/tensorflow/python/autograph/__init__.py
new file mode 100644
index 0000000000..c3448e6e58
--- /dev/null
+++ b/tensorflow/python/autograph/__init__.py
@@ -0,0 +1,68 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Autograph compiles Python code into equivalent TensorFlow code.
+
+Equivalent here means that they have the same effect when executed.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Bring only the relevant symbols to the top level.
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core.errors import GraphConstructionError
+from tensorflow.python.autograph.core.errors import TfRuntimeError
+from tensorflow.python.autograph.core.errors import improved_errors
+from tensorflow.python.autograph.impl.api import RunMode
+from tensorflow.python.autograph.impl.api import convert
+from tensorflow.python.autograph.impl.api import converted_call
+from tensorflow.python.autograph.impl.api import do_not_convert
+from tensorflow.python.autograph.impl.api import to_code
+from tensorflow.python.autograph.impl.api import to_graph
+from tensorflow.python.autograph.lang.directives import set_element_type
+from tensorflow.python.autograph.lang.directives import set_loop_options
+from tensorflow.python.autograph.lang.special_functions import stack
+from tensorflow.python.autograph.lang.special_functions import tensor_list
+from tensorflow.python.autograph.pyct.transformer import AutographParseError
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ # Main API
+ 'RunMode',
+ 'convert',
+ 'converted_call',
+ 'do_not_convert',
+ 'to_code',
+ 'to_graph',
+ # Overloaded operators
+ 'operators',
+ # Errors
+ 'improved_errors',
+ 'GraphConstructionError',
+ 'TfRuntimeError',
+ # Python language "extensions"
+ 'set_element_type',
+ 'set_loop_options',
+ 'stack',
+ 'tensor_list',
+ # Exceptions
+ 'AutographParseError',
+ # Utilities: to be removed
+ 'utils',
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
new file mode 100644
index 0000000000..7b029de8ed
--- /dev/null
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -0,0 +1,249 @@
licenses(["notice"]) # Apache 2.0

load("//tensorflow:tensorflow.bzl", "py_test")

# Exposes every file in the package (minus repo metadata) so other TF
# packages can reference the sources directly.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
+
# Library bundling all individual converter modules; each module specializes
# on a single Python idiom (see converters/__init__.py for conventions).
py_library(
    name = "converters",
    srcs = [
        "asserts.py",
        "break_statements.py",
        "builtin_functions.py",
        "call_trees.py",
        "conditional_expressions.py",
        "continue_statements.py",
        "control_flow.py",
        "decorators.py",
        "directives.py",
        "error_handlers.py",
        "list_comprehensions.py",
        "lists.py",
        "logical_expressions.py",
        "name_scopes.py",
        "return_statements.py",
        "side_effect_guards.py",
        "slices.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        "//tensorflow/python:util",
        "//tensorflow/python/autograph/core",
        "//tensorflow/python/autograph/lang",
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/autograph/pyct/static_analysis",
        "@gast_archive//:gast",
    ],
)
+
# Per-converter unit tests. All tests declare srcs_version = "PY2AND3" for
# consistency (name_scopes_test was previously missing it).
py_test(
    name = "asserts_test",
    srcs = ["asserts_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_windows"],
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "break_statements_test",
    srcs = ["break_statements_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "builtin_functions_test",
    srcs = ["builtin_functions_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_windows"],
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "call_trees_test",
    size = "large",
    srcs = ["call_trees_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_windows"],
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/autograph/impl",
    ],
)

py_test(
    name = "conditional_expressions_test",
    srcs = ["conditional_expressions_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "continue_statements_test",
    srcs = ["continue_statements_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "control_flow_test",
    srcs = ["control_flow_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "decorators_test",
    srcs = ["decorators_test.py"],
    srcs_version = "PY2AND3",
    tags = [
        "no_pip",
        "no_windows",
    ],
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "directives_test",
    srcs = ["directives_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/autograph/lang",
    ],
)

py_test(
    name = "name_scopes_test",
    srcs = ["name_scopes_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/autograph/pyct",
    ],
)

py_test(
    name = "list_comprehensions_test",
    srcs = ["list_comprehensions_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "lists_test",
    srcs = ["lists_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "logical_expressions_test",
    srcs = ["logical_expressions_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "side_effect_guards_test",
    srcs = ["side_effect_guards_test.py"],
    srcs_version = "PY2AND3",
    tags = ["notsan"],
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
    ],
)

py_test(
    name = "return_statements_test",
    srcs = ["return_statements_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/autograph/pyct",
    ],
)

py_test(
    name = "error_handlers_test",
    srcs = ["error_handlers_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/autograph/pyct",
    ],
)

py_test(
    name = "slices_test",
    srcs = ["slices_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":converters",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/core:test_lib",
        "//tensorflow/python/autograph/pyct",
    ],
)
diff --git a/tensorflow/python/autograph/converters/__init__.py b/tensorflow/python/autograph/converters/__init__.py
new file mode 100644
index 0000000000..6325ac78dc
--- /dev/null
+++ b/tensorflow/python/autograph/converters/__init__.py
@@ -0,0 +1,32 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Code converters used by Autograph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Naming conventions:
+# * each converter should specialize on a single idiom; be consistent with
+# the Python reference for naming
+# * all converters inherit core.converter.Base
+# * module names describe the idiom that the converter covers, plural
+# * the converter class is named consistent with the module, singular and
+# includes the word Transformer
+#
+# Example:
+#
+# lists.py
+# class ListTransformer(converter.Base)
diff --git a/tensorflow/python/autograph/converters/asserts.py b/tensorflow/python/autograph/converters/asserts.py
new file mode 100644
index 0000000000..56a97534c4
--- /dev/null
+++ b/tensorflow/python/autograph/converters/asserts.py
@@ -0,0 +1,49 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts assert statements to their corresponding TF calls."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
+
+
class AssertTransformer(converter.Base):
  """Transforms Assert nodes to Call so they can be handled as functions."""

  def visit_Assert(self, node):
    self.generic_visit(node)

    # Guard clause: only string (or absent) messages are supported.
    if node.msg is not None and not isinstance(node.msg, gast.Str):
      raise NotImplementedError('can only convert string messages for now.')

    # Note: The lone tf.Assert call will be wrapped with control_dependencies
    # by side_effect_guards.
    template = """
      tf.Assert(test, (msg,))
    """

    # A missing message gets a generic default.
    message = node.msg if node.msg is not None else gast.Str('Assertion error')
    return templates.replace(template, test=node.test, msg=message)
+
+
def transform(node, ctx):
  """Converts assert statements in `node` to tf.Assert call expressions.

  Args:
    node: AST
    ctx: EntityContext
  Returns:
    The transformed AST.
  """
  return AssertTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
new file mode 100644
index 0000000000..01282f9e62
--- /dev/null
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -0,0 +1,42 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for asserts module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.platform import test
+
+
class AssertsTest(converter_testing.TestCase):
  """Tests for the asserts converter."""

  def test_transform(self):

    def test_fn(a):
      assert a > 0

    node, ctx = self.prepare(test_fn, {})
    node = asserts.transform(node, ctx)

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(node.body[0].value, gast.Call)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py
new file mode 100644
index 0000000000..bd6b0b248c
--- /dev/null
+++ b/tensorflow/python/autograph/converters/break_statements.py
@@ -0,0 +1,146 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lowers break statements to conditionals."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+class _Break(object):
+
+ def __init__(self):
+ self.used = False
+ self.control_var_name = None
+
+ def __repr__(self):
+ return 'used: %s, var: %s' % (self.used, self.control_var_name)
+
+
class BreakTransformer(converter.Base):
  """Canonicalizes break statements into additional conditionals."""

  def visit_Break(self, node):
    # Record that the enclosing loop uses break, then lower the break into
    # "control_var = True; continue"; the loop's test is rewritten (in
    # visit_While/visit_For) to also check the control variable.
    self.state[_Break].used = True
    var_name = self.state[_Break].control_var_name
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = tf.constant(True)
      continue
    """
    return templates.replace(template, var_name=var_name)

  def _guard_if_present(self, block, var_name):
    """Prevents the block from executing if var_name is set."""
    if not block:
      return block

    template = """
        if not var_name:
          block
      """
    node = templates.replace(
        template,
        var_name=var_name,
        block=block)
    return node

  def _process_body(self, nodes, break_var):
    # Visits a loop body under a fresh _Break state bound to break_var;
    # returns the transformed body and whether any break was found.
    self.state[_Break].enter()
    self.state[_Break].control_var_name = break_var
    nodes = self.visit_block(nodes)
    break_used = self.state[_Break].used
    self.state[_Break].exit()
    return nodes, break_used

  def visit_While(self, node):
    """Rewrites `while` loops containing break into guarded equivalents."""
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, break_var)

      template = """
        var_name = tf.constant(False)
        while test and not var_name:
          body
        else:
          orelse
      """
      node = templates.replace(
          template,
          var_name=break_var,
          test=node.test,
          body=node.body,
          orelse=guarded_orelse)

    return node

  def visit_For(self, node):
    """Rewrites `for` loops containing break into guarded equivalents."""
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, break_var)
      extra_test = templates.replace_as_expression(
          'not var_name', var_name=break_var)

      # The extra test is hidden in the AST, which will confuse the static
      # analysis. To mitigate that, we insert a no-op statement that ensures
      # the control variable is marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      template = """
        var_name = tf.constant(False)
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
      # templates.replace returns a list of statements; node[1] is the
      # rewritten For node, which carries the extra_test annotation.
      node = templates.replace(
          template,
          var_name=break_var,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)

      anno.setanno(node[1], 'extra_test', extra_test)

    return node
+
+
def transform(node, ctx):
  """Lowers break statements in `node` to conditionals with control variables.

  Args:
    node: AST
    ctx: EntityContext
  Returns:
    The transformed AST.
  """
  return BreakTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
new file mode 100644
index 0000000000..39406a969d
--- /dev/null
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -0,0 +1,137 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for break_statements module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.eager import context as tfe_ctx
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
class BreakCanonicalizationTest(converter_testing.TestCase):
  """Tests for the break_statements converter."""

  def assertTransformedEquivalent(self, test_fn, *inputs):
    """Asserts the converted test_fn matches the original on *inputs."""
    with self.converted(test_fn, break_statements, {},
                        constant_op.constant) as result:
      self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))

  def test_while_loop(self):

    def test_fn(x):
      v = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    with tfe_ctx.eager_mode():
      self.assertTransformedEquivalent(test_fn, 0)
      self.assertTransformedEquivalent(test_fn, 1)
      self.assertTransformedEquivalent(test_fn, 4)

  def test_for_loop(self):

    def test_fn(a):
      v = []
      for x in a:
        x -= 1
        if x % 2 == 0:
          break
        v.append(x)
      return v

    with self.converted(test_fn, break_statements, {},
                        constant_op.constant) as result:
      # The break is incompletely canonicalized. The loop will not interrupt,
      # but the section following the break will be skipped.
      self.assertEqual([3], result.test_fn([5, 4]))

  def test_nested(self):

    def test_fn(x):
      v = []
      u = []
      w = []
      while x > 0:
        x -= 1
        if x % 2 == 0:
          if x % 3 != 0:
            u.append(x)
          else:
            w.append(x)
            break
        v.append(x)
      return v, u, w

    with tfe_ctx.eager_mode():
      self.assertTransformedEquivalent(test_fn, 0)
      self.assertTransformedEquivalent(test_fn, 3)
      self.assertTransformedEquivalent(test_fn, 11)

  def test_nested_loops(self):

    def test_fn(x):
      v = []
      u = []
      while x > 0:
        x -= 1
        y = x
        while y > 0:
          y -= 1
          if y % 2 == 0:
            break
          u.append(y)
        if x == 0:
          break
        v.append(x)
      return v, u

    with tfe_ctx.eager_mode():
      self.assertTransformedEquivalent(test_fn, 0)
      self.assertTransformedEquivalent(test_fn, 2)
      self.assertTransformedEquivalent(test_fn, 3)
      self.assertTransformedEquivalent(test_fn, 5)

  def test_loop_orelse(self):

    def test_fn(x):
      v = []
      u = []
      while x > 0:
        x -= 1
        y = x
        while y > 1:
          break
        else:
          u.append(y)
          break
        v.append(x)
      return v, u

    with tfe_ctx.eager_mode():
      self.assertTransformedEquivalent(test_fn, 0)
      self.assertTransformedEquivalent(test_fn, 2)
      self.assertTransformedEquivalent(test_fn, 3)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
new file mode 100644
index 0000000000..b8b268d8ce
--- /dev/null
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -0,0 +1,65 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles builtins and other special functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+
+
class BuiltinFunctionTransformer(converter.Base):
  """Handles builtin functions.

  This transformer only covers functions that are translated into a
  TF equivalent, like `len`.
  """

  def _convert_builtin(self, f, args, as_expression):
    """Replaces a builtin call with its ag__ overload counterpart."""
    template = """
      ag__.func(args)
    """
    # Select statement vs. expression replacement up front; both helpers
    # share the same signature.
    if as_expression:
      replace_fn = templates.replace_as_expression
    else:
      replace_fn = templates.replace
    overload_name = py_builtins.overload_of(f).__name__
    return replace_fn(template, func=overload_name, args=args)

  def visit_Call(self, node):
    node = self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      func_val = anno.getanno(node.func, 'live_val')
      if func_val in py_builtins.SUPPORTED_BUILTINS:
        node = self._convert_builtin(func_val, node.args, as_expression=True)
    return node

  def visit_Print(self, node):
    node = self.generic_visit(node)
    args = node.values
    # Following is the case when calling print(a, b)
    if len(args) == 1 and isinstance(args[0], gast.Tuple):
      args = args[0].elts
    return self._convert_builtin(print, args, as_expression=False)
+
+
def transform(node, ctx):
  """Replaces supported builtin calls in `node` with their TF overloads.

  Args:
    node: AST
    ctx: EntityContext
  Returns:
    The transformed AST.
  """
  return BuiltinFunctionTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
new file mode 100644
index 0000000000..c87c304cdb
--- /dev/null
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -0,0 +1,74 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for builtin_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
class BuiltinFunctionsTest(converter_testing.TestCase):
  """Tests for the builtin_functions converter."""

  def test_len(self):

    def test_fn(a):
      return len(a)

    with self.converted(test_fn, builtin_functions, {'len': len}) as result:
      with self.cached_session() as sess:
        p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
        ops = result.test_fn(p)
        self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)

  def test_print(self):

    # NOTE(review): silently returning skips the test on Python 2 (print is
    # a statement there); consider self.skipTest for visibility.
    if six.PY2:
      return

    def test_fn(a):
      return print(a)

    with self.converted(test_fn, builtin_functions, {'print': print}) as result:
      with self.cached_session() as sess:
        with self.assertPrints('a\n'):
          sess.run(result.test_fn('a'))

  def test_print_multiple_values(self):

    # NOTE(review): same silent Python 2 skip as in test_print.
    if six.PY2:
      return

    def test_fn(a, b, c):
      return print(a, b, c)

    with self.converted(test_fn, builtin_functions, {'print': print}) as result:
      with self.cached_session() as sess:
        with self.assertPrints('a 1 [2, 3]\n'):
          sess.run(
              result.test_fn(
                  constant_op.constant('a'), constant_op.constant(1), [2, 3]))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
new file mode 100644
index 0000000000..6a606c450d
--- /dev/null
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -0,0 +1,330 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles function calls, by generating compiled function names and calls.
+
+Note: this transformer does not rename the top level object being converted;
+that is the caller's responsibility.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.util import tf_inspect
+
+
class FunctionInfo(namedtuple('FunctionInfo', ('dtype',))):
  """Metadata about a known function; dtype is its TF return dtype string."""
  pass


# TODO(mdan): Move this to config.py.
# Functions keyed by fully-qualified name tuple that are wrapped in py_func
# (see _wrap_to_py_func_single_return) rather than converted.
KNOWN_NUMPY_FUNCTIONS = {
    ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'),
}
+
+
+# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer.
+
+
class FunctionNamer(object):
  """Describes the interface for CallTreeTransformer's namer.

  Implementations generate the names used for compiled (converted)
  counterparts of functions and classes.
  """

  def compiled_function_name(self,
                             original_fqn,
                             live_entity=None,
                             owner_type=None):
    """Generate the name corresponding to the compiled version of a function.

    Args:
      original_fqn: string or tuple(string)
      live_entity: Callable, the actual target function, if known.
      owner_type: Optional object. If present, it indicates that the function is
        a member of the given type.
    Returns:
      string, bool
    """
    raise NotImplementedError()

  def compiled_class_name(self, original_fqn, live_entity=None):
    """Generate the name corresponding to the compiled version of a class.

    Args:
      original_fqn: string or tuple(string)
      live_entity: The actual target class, if known.
    Returns:
      string
    """
    raise NotImplementedError()
+
+
+# TODO(mdan): Rename to CallsTransformer.
+
+
class CallTreeTransformer(converter.Base):
  """Transforms the call tree by renaming transformed symbols."""

  def _resolve_name(self, node):
    """Used to resolve decorator info."""
    if isinstance(node, gast.Call):
      return self._resolve_name(node.func)
    if isinstance(node, gast.Name):
      return self.ctx.namespace.get(node.id)
    if isinstance(node, gast.Attribute):
      parent = self._resolve_name(node.value)
      if parent is not None:
        return getattr(parent, node.attr)
      return None
    raise ValueError(node)

  def _try_resolve_target(self, node):
    """Works for methods of objects of known type."""
    if anno.hasanno(node, 'live_val'):
      return anno.getanno(node, 'live_val')
    if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
      owner_type = anno.getanno(node, 'type')
      if hasattr(owner_type, node.attr):
        return getattr(owner_type, node.attr)
      else:
        # Fixed grammar in the error message ("has not attribute").
        raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
                         (owner_type, node.attr))
    return None

  def _function_is_compilable(self, target_entity):
    """Determines whether an entity can be compiled at all."""
    # TODO(mdan): This is just a placeholder. Implement.
    return not inspect_utils.isbuiltin(target_entity)

  def _should_compile(self, node, fqn):
    """Determines whether an entity should be compiled in the context."""
    # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
    module_name = fqn[0]
    for mod in self.ctx.program.uncompiled_modules:
      if module_name.startswith(mod[0] + '.'):
        return False

    for i in range(1, len(fqn)):
      if fqn[:i] in self.ctx.program.uncompiled_modules:
        return False

    # Check for local decorations
    if anno.hasanno(node, 'graph_ready'):
      return False

    # The decorators themselves are not to be converted.
    # If present, the decorators should appear as static functions.
    target_entity = self._try_resolve_target(node.func)
    if target_entity is not None:
      # This attribute is set by the decorator itself.
      # TODO(mdan): This may not play nicely with other wrapping decorators.
      if hasattr(target_entity, '__pyct_is_compile_decorator'):
        return False

      if target_entity in self.ctx.program.autograph_decorators:
        return False

      # Inspect the target function decorators. If any include a @convert
      # or @graph_ready annotation, then they must be called as they are.
      # TODO(mdan): This may be quite heavy.
      # To parse and re-analyze each function for every call site could be quite
      # wasteful. Maybe we could cache the parsed AST?
      try:
        target_node, _ = parser.parse_entity(target_entity)
        target_node = target_node.body[0]
      except TypeError:
        # Functions whose source we cannot access are compilable (e.g. wrapped
        # to py_func).
        return True

      for dec in target_node.decorator_list:
        decorator_fn = self._resolve_name(dec)
        if (decorator_fn is not None and
            decorator_fn in self.ctx.program.autograph_decorators):
          return False

    return True

  def _rename_compilable_function(self, node):
    """Renames the call target to the name of its compiled counterpart."""
    assert anno.hasanno(node.func, 'live_val')
    assert anno.hasanno(node.func, 'fqn')
    target_entity = anno.getanno(node.func, 'live_val')
    target_fqn = anno.getanno(node.func, 'fqn')

    if not self._should_compile(node, target_fqn):
      return node

    if anno.hasanno(node, 'is_constructor'):
      new_name = self.ctx.namer.compiled_class_name(
          target_fqn, live_entity=target_entity)
      do_rename = True
    else:
      if anno.hasanno(node.func, 'parent_type'):
        owner_type = anno.getanno(node.func, 'parent_type')
      else:
        # Fallback - not reliable.
        owner_type = inspect_utils.getmethodclass(target_entity)
      new_name, do_rename = self.ctx.namer.compiled_function_name(
          target_fqn, live_entity=target_entity, owner_type=owner_type)

    if do_rename:
      if target_entity is not None:
        if tf_inspect.ismethod(target_entity):
          # The renaming process will transform it into a regular function.
          # TODO(mdan): Is this complete? How does it work with nested members?
          node.args = [node.func.value] + node.args
      node.func = templates.replace('func_name', func_name=new_name)[0]
    return node

  def _wrap_to_py_func_no_return(self, node):
    """Wraps a no-return call into ag__.utils.wrap_py_func."""
    # TODO(mdan): Properly handle varargs, etc.
    template = """
      ag__.utils.wrap_py_func(func, None, (args,), kwargs, True)
    """
    return templates.replace(
        template,
        func=node.func,
        args=node.args,
        kwargs=ast_util.keywords_to_dict(node.keywords))

  def _wrap_to_py_func_single_return(self, node, dtype):
    """Wraps a single-return call into ag__.utils.wrap_py_func with dtype."""
    # TODO(mdan): Properly handle varargs, etc.
    template = """
      ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
    """
    return templates.replace_as_expression(
        template,
        func=node.func,
        dtype=parser.parse_expression(dtype),
        args=node.args,
        kwargs=ast_util.keywords_to_dict(node.keywords))

  def _insert_dynamic_conversion(self, node):
    """Inlines a dynamic conversion for a dynamic function."""
    # TODO(mdan): Pass information on the statically compiled functions.
    # Having access to the statically compiled functions can help avoid
    # unnecessary compilation.
    # For example, this would lead to function `a` being compiled twice:
    #
    #   def a():
    #     v = b
    #     b()
    #   def b():
    #     a()
    #
    # This is really a problem with recursive calls, which currently can
    # only be gated by a static condition, and should be rare.
    # TODO(mdan): It probably makes sense to use dynamic conversion every time.
    # Before we could convert all the time though, we'd need a reasonable
    # caching mechanism.
    template = """
      ag__.converted_call(func, True, False, False, {}, args)
    """
    call_expr = templates.replace(template, func=node.func, args=node.args)
    new_call = call_expr[0].value
    # TODO(mdan): Improve the template mechanism to better support this.
    new_call.keywords = node.keywords
    return new_call

  def visit_Expr(self, node):
    """Handles expression statements; py_func with no return is special."""
    if isinstance(node.value, gast.Call):
      if anno.hasanno(node.value.func, 'live_val'):
        target_entity = anno.getanno(node.value.func, 'live_val')
        if not self._function_is_compilable(target_entity):
          if anno.hasanno(node.value.func, 'fqn'):
            target_fqn = anno.getanno(node.value.func, 'fqn')
            if not self._should_compile(node.value, target_fqn):
              return node
            node = self._wrap_to_py_func_no_return(node.value)
            return node
      # Only the case of py_func with no return value is special.
      # Everything else is processed by visit_Call.
      # NOTE(review): the return value of this visit is discarded; if
      # visit_Call returns a *new* node (e.g. dynamic conversion) it may be
      # lost here — confirm intended.
      self.visit(node.value)
    else:
      self.generic_visit(node)
    return node

  def visit_Call(self, node):
    """Renames, wraps or dynamically converts a call depending on its target."""
    # If the function call is wrapped by one of the marker decorators,
    # consider it graph ready.
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if target_entity in self.ctx.program.autograph_decorators:
        if not node.args:
          raise ValueError(
              'Found call to decorator function "%s", but it had no arguments. '
              'A decorator needs at least one positional argument.' %
              target_entity)
        anno.setanno(node.args[0], 'graph_ready', True)

    self.generic_visit(node)
    if anno.hasanno(node.func, 'live_val'):
      target_entity = anno.getanno(node.func, 'live_val')
      if anno.hasanno(node.func, 'fqn'):
        target_fqn = anno.getanno(node.func, 'fqn')
      else:
        target_fqn = None
      if self._function_is_compilable(target_entity):
        node = self._rename_compilable_function(node)
      elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
        # TODO(mdan): Should we replace these with equivalent TF ops instead?
        node = self._wrap_to_py_func_single_return(
            node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
      else:
        raise NotImplementedError(
            'py_func with return values (unknown function)')
    else:
      if anno.hasanno(node.func, anno.Basic.QN):
        # Special-case a few builtins that otherwise go undetected. This
        # normally doesn't pose a problem, but the dict built-in doesn't
        # work with inspect.getargspec which is required for dynamic functions.
        # Note: expecting this is resilient to aliasing (e.g.
        # dict = an_evil_dict), because in those cases the regular mechanisms
        # process a simple user function.
        qn = anno.getanno(node.func, anno.Basic.QN)
        # Add items to this list as needed.
        if str(qn) in ('dict',):
          return node

      if ast_util.matches(node, 'super(_)'):
        # super() calls are preserved. The class conversion mechanism will
        # ensure that they return the correct value.
        return node

      if self.ctx.program.recursive:
        node = self._insert_dynamic_conversion(node)
    return node
+
+
def transform(node, ctx):
  """Transform function call to the compiled counterparts.

  Args:
    node: AST
    ctx: EntityContext
  Returns:
    The transformed AST. (The previous docstring claimed a tuple
    (node, new_names); only the node is returned.)
  """
  return CallTreeTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
new file mode 100644
index 0000000000..0e50f42c6a
--- /dev/null
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -0,0 +1,138 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for call_trees module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class CallTreesTest(converter_testing.TestCase):
+
+ def test_basic(self):
+
+ def test_fn_1(_):
+ raise ValueError('This should not be called in the compiled version.')
+
+ def other_test_fn_1(a):
+ return a + 1
+
+ def test_fn_2(a):
+ return test_fn_1(a) + 1
+
+ ns = {'test_fn_1': test_fn_1}
+ node, ctx = self.prepare(test_fn_2, ns)
+ node = call_trees.transform(node, ctx)
+
+ with self.compiled(node, ns) as result:
+ new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',))
+ setattr(result, new_name, other_test_fn_1)
+ self.assertEquals(result.test_fn_2(1), 3)
+
+ def test_dynamic_function(self):
+
+ def test_fn_1():
+ raise ValueError('This should be masked by the mock in self.compiled.')
+
+ def test_fn_2(f):
+ return f() + 3
+
+ with self.converted(test_fn_2, call_trees, {}) as result:
+ # 10 = 7 (from the mock) + 3 (from test_fn_2)
+ self.assertEquals(10, result.test_fn_2(test_fn_1))
+
+ def test_basic_method(self):
+
+ class TestClass(object):
+
+ def test_fn_1(self, a):
+ return a + 1
+
+ def test_fn_2(self, a):
+ return self.test_fn_1(a) + 1
+
+ ns = {'TestClass': TestClass}
+ node, ctx = self.prepare(
+ TestClass.test_fn_2,
+ ns,
+ namer=converter_testing.FakeNoRenameNamer(),
+ arg_types={'self': (TestClass.__name__, TestClass)})
+ node = call_trees.transform(node, ctx)
+
+ with self.compiled(node, ns) as result:
+ tc = TestClass()
+ self.assertEquals(3, result.test_fn_2(tc, 1))
+
+ def test_py_func_no_retval(self):
+
+ def test_fn(a):
+ setattr(a, 'foo', 'bar')
+
+ with self.converted(test_fn, call_trees, {'setattr': setattr}) as result:
+ with self.cached_session() as sess:
+
+ class Dummy(object):
+ pass
+
+ a = Dummy()
+ result.test_fn(a)
+ py_func_op, = sess.graph.get_operations()
+ self.assertFalse(hasattr(a, 'foo'))
+ sess.run(py_func_op)
+ self.assertEquals('bar', a.foo)
+
+ def test_py_func_known_function(self):
+
+ def test_fn():
+ return np.random.binomial(2, 0.5)
+
+ with self.converted(test_fn, call_trees, {'np': np},
+ dtypes.int64) as result:
+ with self.cached_session() as sess:
+ self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
+ self.assertIn(sess.run(result.test_fn()), (0, 1, 2))
+
+ def test_uncompiled_modules(self):
+
+ def test_fn(a):
+ a = math_ops.multiply(a, constant_op.constant(2))
+ a = math_ops.add(a, constant_op.constant(1))
+ return a
+
+ ns = {'math_ops': math_ops, 'constant_op': constant_op}
+ node, ctx = self.prepare(
+ test_fn,
+ ns,
+ arg_types=set(((math_ops.__name__,), (constant_op.__name__,))))
+ node = call_trees.transform(node, ctx)
+
+ with self.compiled(node, ns) as result:
+ with self.cached_session() as sess:
+ result_tensor = result.test_fn(constant_op.constant(1))
+ self.assertEquals(sess.run(result_tensor), 3)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/conditional_expressions.py b/tensorflow/python/autograph/converters/conditional_expressions.py
new file mode 100644
index 0000000000..40728f555d
--- /dev/null
+++ b/tensorflow/python/autograph/converters/conditional_expressions.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts the ternary conditional operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+class _FunctionDefs(object):
+  """Transformer state: branch function definitions pending insertion.
+
+  `nodes` holds lists of AST nodes (one list per extracted branch function,
+  as returned by templates.replace); they are spliced in before the
+  statement that uses them by _postprocess_statement.
+  """
+
+  def __init__(self):
+    self.nodes = []
+
+
+class _Statement(object):
+  """Transformer state: scope of the statement block being processed."""
+
+  def __init__(self):
+    # Set by _process_block; consulted when generating unique symbol names.
+    self.scope = None
+
+
+class ConditionalExpressionTransformer(converter.Base):
+  """Converts conditional expressions to functional form."""
+
+  def _postprocess_statement(self, node):
+    """Inserts any separate functions that node may use."""
+    replacements = []
+    # Each entry in nodes is itself a list of AST nodes; flatten them all
+    # ahead of the statement that references the generated functions.
+    for def_node in self.state[_FunctionDefs].nodes:
+      replacements.extend(def_node)
+    replacements.append(node)
+    node = replacements
+    # The corresponding enter is called by self.visit_block (see _process_block)
+    self.state[_FunctionDefs].exit()
+    return node, None
+
+  def _create_branch(self, expr, name_stem):
+    """Wraps expr in a new zero-argument function; returns its name."""
+    scope = self.state[_Statement].scope
+    name = self.ctx.namer.new_symbol(name_stem, scope.referenced)
+    template = """
+      def name():
+        return expr,
+    """
+    node = templates.replace(template, name=name, expr=expr)
+    self.state[_FunctionDefs].nodes.append(node)
+    return name
+
+  def visit_IfExp(self, node):
+    """Replaces `a if test else b` with a run_cond call on two functions."""
+    # Use the test's qualified name, if any, to produce readable branch names.
+    if anno.hasanno(node.test, anno.Basic.QN):
+      name_root = anno.getanno(node.test, anno.Basic.QN).ssf()
+    else:
+      name_root = 'ifexp'
+
+    true_fn_name = self._create_branch(node.body, '%s_true' % name_root)
+    false_fn_name = self._create_branch(node.orelse, '%s_false' % name_root)
+
+    return templates.replace_as_expression(
+        'ag__.utils.run_cond(test, true_fn_name, false_fn_name)',
+        test=node.test,
+        true_fn_name=true_fn_name,
+        false_fn_name=false_fn_name)
+
+  def _process_block(self, scope, block):
+    """Visits a statement block, splicing in extracted branch functions."""
+    self.state[_Statement].enter()
+    self.state[_Statement].scope = scope
+    block = self.visit_block(
+        block,
+        before_visit=self.state[_FunctionDefs].enter,
+        after_visit=self._postprocess_statement)
+    self.state[_Statement].exit()
+    return block
+
+  # The visit_* overrides below exist to route every statement block through
+  # _process_block with the scope annotation appropriate to that block.
+
+  def visit_FunctionDef(self, node):
+    node.args = self.generic_visit(node.args)
+    node.decorator_list = self.visit_block(node.decorator_list)
+    node.body = self._process_block(
+        anno.getanno(node, anno.Static.SCOPE), node.body)
+    return node
+
+  def visit_For(self, node):
+    node.target = self.visit(node.target)
+    node.body = self._process_block(
+        anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
+    node.orelse = self._process_block(
+        anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse)
+    return node
+
+  def visit_While(self, node):
+    node.test = self.visit(node.test)
+    node.body = self._process_block(
+        anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
+    node.orelse = self._process_block(
+        anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse)
+    return node
+
+  def visit_If(self, node):
+    node.test = self.visit(node.test)
+    node.body = self._process_block(
+        anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
+    node.orelse = self._process_block(
+        anno.getanno(node, NodeAnno.ORELSE_SCOPE), node.orelse)
+    return node
+
+  def visit_With(self, node):
+    node.items = self.visit_block(node.items)
+    node.body = self._process_block(
+        anno.getanno(node, NodeAnno.BODY_SCOPE), node.body)
+    return node
+
+
+def transform(node, ctx):
+  """Converts ternary conditional expressions (IfExp nodes) in node.
+
+  Args:
+    node: AST
+    ctx: EntityContext
+  Returns:
+    AST, the transformed node.
+  """
+  node = ConditionalExpressionTransformer(ctx).visit(node)
+  return node
diff --git a/tensorflow/python/autograph/converters/conditional_expressions_test.py b/tensorflow/python/autograph/converters/conditional_expressions_test.py
new file mode 100644
index 0000000000..dd1f8d485c
--- /dev/null
+++ b/tensorflow/python/autograph/converters/conditional_expressions_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for conditional_expressions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.platform import test
+
+
+class ConditionalExpressionsTest(converter_testing.TestCase):
+  """Tests for the conditional_expressions converter."""
+
+  def assertTransformedEquivalent(self, test_fn, *inputs):
+    """Asserts the converted function matches the original on inputs."""
+    ns = {}
+    with self.converted(test_fn, conditional_expressions, ns) as result:
+      self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
+
+  def test_basic(self):
+
+    def test_fn(x):
+      return 1 if x else 0
+
+    self.assertTransformedEquivalent(test_fn, 0)
+    self.assertTransformedEquivalent(test_fn, 3)
+
+  def test_nested_orelse(self):
+
+    # Exercises a conditional expression nested in the orelse position.
+    def test_fn(x):
+      y = x * x if x > 0 else x if x else 1
+      return y
+
+    self.assertTransformedEquivalent(test_fn, -2)
+    self.assertTransformedEquivalent(test_fn, 0)
+    self.assertTransformedEquivalent(test_fn, 2)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/continue_statements.py b/tensorflow/python/autograph/converters/continue_statements.py
new file mode 100644
index 0000000000..584cdc1efd
--- /dev/null
+++ b/tensorflow/python/autograph/converters/continue_statements.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Canonicalizes continue statements by de-sugaring into a control boolean."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+# Tags for local state used by the guard-creation state machine
+# (see ContinueCanonicalizationTransformer._postprocess_statement).
+CONTROL_VAR_NAME = 'control_var_name'  # Name of the generated control boolean.
+CONTINUE_USED = 'continue_used'  # A continue was seen in the current scope.
+GUARD_CREATED = 'guard_created'  # The guarding `if` was already emitted.
+CREATE_GUARD_NEXT = 'create_guard_next'  # Emit the guard at the next statement.
+
+
+class ContinueCanonicalizationTransformer(converter.Base):
+  """Canonicalizes continue statements into additional conditionals."""
+
+  def visit_Continue(self, node):
+    """Replaces `continue` with `control_var = tf.constant(True)`."""
+    self.set_local(CONTINUE_USED, True)
+    template = """
+      var_name = tf.constant(True)
+    """
+    return templates.replace(
+        template, var_name=self.get_local(CONTROL_VAR_NAME))
+
+  def _postprocess_statement(self, node):
+    """Wraps statements that follow a continue under a `if not flag` guard."""
+    # Example of how the state machine below works:
+    #
+    #   1| stmt  # State: CONTINUE_USED = False
+    #    |       # Action: none
+    #   2| if cond:
+    #   3|   continue  # State: CONTINUE_USED = True,
+    #    |             #        GUARD_CREATED = False,
+    #    |             #        CREATE_GUARD_NEXT = False
+    #    |             # Action: set CREATE_GUARD_NEXT = True
+    #   4| stmt  # State: CONTINUE_USED = True,
+    #    |       #        GUARD_CREATED = False,
+    #    |       #        CREATE_GUARD_NEXT = True
+    #    |       # Action: create `if not continue_used`,
+    #    |       #         set GUARD_CREATED = True
+    #   5| stmt  # State: CONTINUE_USED = True, GUARD_CREATED = True
+    #    |       # Action: none (will be wrapped under previously
+    #    |       #         created if node)
+
+    if self.get_local(CONTINUE_USED, False):
+      if self.get_local(GUARD_CREATED, False):
+        return node, None
+
+      elif not self.get_local(CREATE_GUARD_NEXT, False):
+        self.set_local(CREATE_GUARD_NEXT, True)
+        return node, None
+
+      else:
+        self.set_local(GUARD_CREATED, True)
+        template = """
+          if not var_name:
+            original_node
+        """
+        cond, = templates.replace(
+            template,
+            var_name=self.get_local(CONTROL_VAR_NAME),
+            original_node=node)
+        # Returning (cond, cond.body) tells visit_block to append subsequent
+        # statements into the guard's body.
+        return cond, cond.body
+    return node, None
+
+  def _visit_loop_body(self, node, nodes):
+    """Processes a loop body, initializing the control variable if needed."""
+    self.enter_local_scope()
+    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+    continue_var = self.ctx.namer.new_symbol('continue_', scope.referenced)
+    self.set_local(CONTROL_VAR_NAME, continue_var)
+
+    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
+
+    if self.get_local(CONTINUE_USED, False):
+      # Reset the control variable to False at the top of each iteration.
+      template = """
+        var_name = tf.constant(False)
+      """
+      control_var_init = templates.replace(template, var_name=continue_var)
+      nodes = control_var_init + nodes
+
+    self.exit_local_scope()
+    return nodes
+
+  def _visit_non_loop_body(self, nodes):
+    """Processes a non-loop block; a continue there targets the outer loop."""
+    self.enter_local_scope(inherit=(CONTROL_VAR_NAME,))
+    nodes = self.visit_block(nodes, after_visit=self._postprocess_statement)
+    continue_used = self.get_local(CONTINUE_USED, False)
+    self.exit_local_scope(keep=(CONTINUE_USED,))
+    return nodes, continue_used
+
+  def visit_While(self, node):
+    node.test = self.visit(node.test)
+    node.body = self._visit_loop_body(node, node.body)
+    # A continue in the else clause applies to the containing scope.
+    node.orelse, _ = self._visit_non_loop_body(node.orelse)
+    return node
+
+  def visit_For(self, node):
+    # NOTE(review): generic_visit visits the node's *children* rather than the
+    # node itself; self.visit would be more conventional here. Harmless in
+    # practice since this transformer only rewrites statements — confirm.
+    node.target = self.generic_visit(node.target)
+    node.iter = self.generic_visit(node.iter)
+    node.body = self._visit_loop_body(node, node.body)
+    # A continue in the else clause applies to the containing scope.
+    node.orelse, _ = self._visit_non_loop_body(node.orelse)
+    return node
+
+  def visit_If(self, node):
+    node.test = self.generic_visit(node.test)
+    node.body, continue_used_body = self._visit_non_loop_body(node.body)
+    node.orelse, continue_used_orelse = self._visit_non_loop_body(node.orelse)
+    # A continue in either branch makes the whole `if` a continue user.
+    self.set_local(CONTINUE_USED, continue_used_body or continue_used_orelse)
+    return node
+
+  def visit_With(self, node):
+    node.items = self.visit_block(node.items)
+    node.body, _ = self._visit_non_loop_body(node.body)
+    return node
+
+
+def transform(node, ctx):
+  """De-sugars continue statements in node into control booleans.
+
+  Args:
+    node: AST
+    ctx: EntityContext
+  Returns:
+    AST, the transformed node.
+  """
+  return ContinueCanonicalizationTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/continue_statements_test.py b/tensorflow/python/autograph/converters/continue_statements_test.py
new file mode 100644
index 0000000000..d6aaa50443
--- /dev/null
+++ b/tensorflow/python/autograph/converters/continue_statements_test.py
@@ -0,0 +1,94 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for continue_statements module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.eager import context as tfe_ctx
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
+class ContinueCanonicalizationTest(converter_testing.TestCase):
+  """Tests for the continue_statements converter."""
+
+  def assertTransformedEquivalent(self, test_fn, *inputs):
+    """Asserts the converted function matches the original on inputs."""
+    with self.converted(test_fn, continue_statements, {},
+                        constant_op.constant) as result:
+      self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
+
+  def test_basic(self):
+
+    def test_fn(x):
+      v = []
+      while x > 0:
+        x -= 1
+        if x % 2 == 0:
+          continue
+        v.append(x)
+      return v
+
+    with tfe_ctx.eager_mode():
+      self.assertTransformedEquivalent(test_fn, 0)
+      self.assertTransformedEquivalent(test_fn, 1)
+      self.assertTransformedEquivalent(test_fn, 3)
+      self.assertTransformedEquivalent(test_fn, 4)
+
+  def test_for_loop(self):
+
+    def test_fn(a):
+      v = []
+      for x in a:
+        x -= 1
+        if x % 2 == 0:
+          continue
+        v.append(x)
+      return v
+
+    with tfe_ctx.eager_mode():
+      self.assertTransformedEquivalent(test_fn, [])
+      self.assertTransformedEquivalent(test_fn, [1])
+      self.assertTransformedEquivalent(test_fn, [2])
+      self.assertTransformedEquivalent(test_fn, [1, 2, 3])
+
+  def test_nested(self):
+
+    # Exercises a continue nested two conditionals deep.
+    def test_fn(x):
+      v = []
+      u = []
+      w = []
+      while x > 0:
+        x -= 1
+        if x % 2 == 0:
+          if x % 3 != 0:
+            u.append(x)
+          else:
+            w.append(x)
+            continue
+        v.append(x)
+      return v, u, w
+
+    with tfe_ctx.eager_mode():
+      self.assertTransformedEquivalent(test_fn, 0)
+      self.assertTransformedEquivalent(test_fn, 1)
+      self.assertTransformedEquivalent(test_fn, 3)
+      self.assertTransformedEquivalent(test_fn, 4)
+
+
+# Standard TensorFlow test entry point.
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
new file mode 100644
index 0000000000..416a60d2ee
--- /dev/null
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -0,0 +1,339 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles control flow statements: while, for, if."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis import annos
+
+
+class SymbolNamer(object):
+  """Describes the interface for ControlFlowTransformer's namer."""
+
+  # Documentation-only interface stub: concrete namers must override this.
+  def new_symbol(self, name_root, reserved_locals):
+    """Generate a new unique symbol.
+
+    Args:
+      name_root: String, used as stem in the new name.
+      reserved_locals: Set(string), additional local symbols that are reserved
+        and which should not be used.
+    Returns:
+      String.
+    """
+    raise NotImplementedError()
+
+
+class ControlFlowTransformer(converter.Base):
+  """Transforms control flow structures like loops and conditionals."""
+
+  def _create_cond_branch(self, body_name, aliased_orig_names,
+                          aliased_new_names, body, returns):
+    """Builds a zero-arg branch function that aliases closed-over symbols."""
+    if aliased_orig_names:
+      template = """
+        def body_name():
+          aliased_new_names, = aliased_orig_names,
+          body
+          return (returns,)
+      """
+      return templates.replace(
+          template,
+          body_name=body_name,
+          body=body,
+          aliased_orig_names=aliased_orig_names,
+          aliased_new_names=aliased_new_names,
+          returns=returns)
+    else:
+      template = """
+        def body_name():
+          body
+          return (returns,)
+      """
+      return templates.replace(
+          template, body_name=body_name, body=body, returns=returns)
+
+  def _create_cond_expr(self, results, test, body_name, orelse_name):
+    """Builds the run_cond call; assigns its results if there are any."""
+    if results is not None:
+      template = """
+        results = ag__.utils.run_cond(test, body_name, orelse_name)
+      """
+      return templates.replace(
+          template,
+          test=test,
+          results=results,
+          body_name=body_name,
+          orelse_name=orelse_name)
+    else:
+      template = """
+        ag__.utils.run_cond(test, body_name, orelse_name)
+      """
+      return templates.replace(
+          template, test=test, body_name=body_name, orelse_name=orelse_name)
+
+  def _fmt_symbol_list(self, symbol_set):
+    """Renders a symbol set for error messages."""
+    if not symbol_set:
+      return 'no variables'
+    return ', '.join(map(str, symbol_set))
+
+  def _validate_no_live_vars_created(self, node):
+    """Rejects loops that create variables which are live after the loop."""
+    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+    live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+    live_vars_created_in_body = live_vars_out & body_scope.created
+    if live_vars_created_in_body:
+      raise ValueError(
+          'The following variables are created inside the loop and used later:'
+          '\n%s\n'
+          'Variables must be declared outside loops because loops may not'
+          ' necessarily execute.' % self._fmt_symbol_list(
+              live_vars_created_in_body))
+
+  def visit_If(self, node):
+    """Converts an if statement into two branch functions plus a run_cond."""
+    node = self.generic_visit(node)
+
+    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
+    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
+    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+
+    # Symbols modified in either branch and live afterwards become outputs
+    # of the cond.
+    modified_in_cond = body_scope.modified | orelse_scope.modified
+    returned_from_cond = set()
+    for s in modified_in_cond:
+      if s in live_out:
+        returned_from_cond.add(s)
+      elif s.is_composite():
+        # Special treatment for compound objects: if any of their owner entities
+        # are live, then they are outputs as well.
+        if any(owner in live_out for owner in s.owner_set):
+          returned_from_cond.add(s)
+
+    need_alias_in_body = body_scope.modified & defined_in
+    need_alias_in_orelse = orelse_scope.modified & defined_in
+
+    created_in_body = body_scope.modified & returned_from_cond - defined_in
+    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in
+
+    if created_in_body != created_in_orelse:
+      raise ValueError(
+          'if statement may not initialize all variables: the true branch'
+          ' creates %s, while the false branch creates %s. Make sure all'
+          ' these variables are initialized either in both'
+          ' branches or before the if statement.' %
+          (self._fmt_symbol_list(created_in_body),
+           self._fmt_symbol_list(created_in_orelse)))
+
+    # Alias the closure variables inside the conditional functions, to allow
+    # the functions access to the respective variables.
+    # We will alias variables independently for body and orelse scope,
+    # because different branches might write different variables.
+    aliased_body_orig_names = tuple(need_alias_in_body)
+    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
+    aliased_body_new_names = tuple(
+        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
+        for s in aliased_body_orig_names)
+    aliased_orelse_new_names = tuple(
+        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
+        for s in aliased_orelse_orig_names)
+
+    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
+    alias_orelse_map = dict(
+        zip(aliased_orelse_orig_names, aliased_orelse_new_names))
+
+    node_body = ast_util.rename_symbols(node.body, alias_body_map)
+    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
+
+    returned_from_cond = tuple(returned_from_cond)
+    if returned_from_cond:
+      if len(returned_from_cond) == 1:
+        # TODO(mdan): Move this quirk into the operator implementation.
+        cond_results = returned_from_cond[0]
+      else:
+        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
+
+      returned_from_body = tuple(
+          alias_body_map[s] if s in need_alias_in_body else s
+          for s in returned_from_cond)
+      returned_from_orelse = tuple(
+          alias_orelse_map[s] if s in need_alias_in_orelse else s
+          for s in returned_from_cond)
+
+    else:
+      # When the cond would return no value, we leave the cond called without
+      # results. That in turn should trigger the side effect guards. The
+      # branch functions will return a dummy value that ensures cond
+      # actually has some return value as well.
+      cond_results = None
+      # TODO(mdan): This doesn't belong here; it's specific to the operator.
+      returned_from_body = templates.replace_as_expression('tf.constant(1)')
+      returned_from_orelse = templates.replace_as_expression('tf.constant(1)')
+
+    body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
+    orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
+
+    body_def = self._create_cond_branch(
+        body_name,
+        aliased_orig_names=aliased_body_orig_names,
+        aliased_new_names=aliased_body_new_names,
+        body=node_body,
+        returns=returned_from_body)
+    orelse_def = self._create_cond_branch(
+        orelse_name,
+        aliased_orig_names=aliased_orelse_orig_names,
+        aliased_new_names=aliased_orelse_new_names,
+        body=node_orelse,
+        returns=returned_from_orelse)
+    cond_expr = self._create_cond_expr(cond_results, node.test, body_name,
+                                       orelse_name)
+
+    # The if statement is replaced by: both branch defs, then the cond call.
+    return body_def + orelse_def + cond_expr
+
+  def visit_While(self, node):
+    """Converts a while loop into an ag__.while_stmt call."""
+    self.generic_visit(node)
+
+    self._validate_no_live_vars_created(node)
+
+    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+    body_closure = body_scope.modified - body_scope.created
+    all_referenced = body_scope.referenced
+
+    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
+    cond_closure = set()
+    for s in cond_scope.used:
+      for root in s.support_set:
+        if root not in body_scope.created:
+          cond_closure.add(root)
+
+    state = list(body_closure)
+    if not state:
+      # TODO(mdan): Implement this properly.
+      # To complete this statement, we need to check whether any variable
+      # created inside the body scope is used before being modified outside the
+      # scope. This should be done during activity analysis, and in general
+      # should cover the case where variables may not be initialized.
+      raise ValueError('cannot convert while loop: no outputs')
+
+    state_ssf = [
+        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
+    ]
+    ssf_map = {
+        name: ssf
+        for name, ssf in zip(state, state_ssf)
+        if str(name) != ssf
+    }
+
+    if len(state) == 1:
+      state = state[0]
+      state_ssf = state_ssf[0]
+      state_ast_tuple = state
+    else:
+      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+    node_body = ast_util.rename_symbols(node.body, ssf_map)
+    test = ast_util.rename_symbols(node.test, ssf_map)
+
+    # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
+    template = """
+      def test_name(state_ssf):
+        return test
+      def body_name(state_ssf):
+        body
+        return state_ssf,
+      state_ast_tuple = ag__.while_stmt(
+          test_name, body_name, (state,), (extra_deps,))
+    """
+    node = templates.replace(
+        template,
+        state=state,
+        state_ssf=state_ssf,
+        state_ast_tuple=state_ast_tuple,
+        test_name=self.ctx.namer.new_symbol('loop_test', body_scope.referenced),
+        test=test,
+        body_name=self.ctx.namer.new_symbol('loop_body', body_scope.referenced),
+        body=node_body,
+        extra_deps=tuple(s.ast() for s in cond_closure),
+    )
+
+    return node
+
+  def visit_For(self, node):
+    """Converts a for loop into an ag__.for_stmt call."""
+    self.generic_visit(node)
+
+    self._validate_no_live_vars_created(node)
+
+    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+    body_closure = body_scope.modified - body_scope.created
+    all_referenced = body_scope.referenced
+
+    state = list(body_closure)
+
+    state_ssf = [
+        self.ctx.namer.new_symbol(s.ssf(), all_referenced) for s in state
+    ]
+    ssf_map = {
+        name: ssf
+        for name, ssf in zip(state, state_ssf)
+        if str(name) != ssf
+    }
+
+    if len(state) == 1:
+      state = state[0]
+      state_ssf = state_ssf[0]
+      state_ast_tuple = state
+    else:
+      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+    node_body = ast_util.rename_symbols(node.body, ssf_map)
+    if anno.hasanno(node, 'extra_test'):
+      extra_test = anno.getanno(node, 'extra_test')
+      extra_test = ast_util.rename_symbols(extra_test, ssf_map)
+    else:
+      extra_test = parser.parse_expression('True')
+
+    template = """
+      def extra_test_name(state_ssf):
+        return extra_test_expr
+      def body_name(loop_vars, state_ssf):
+        # Workaround for PEP-3113
+        iterate = loop_vars
+        body
+        return state_ssf,
+      state_ast_tuple = ag__.for_stmt(
+          iter_, extra_test_name, body_name, (state,))
+    """
+    node = templates.replace(
+        template,
+        state=state,
+        state_ssf=state_ssf,
+        state_ast_tuple=state_ast_tuple,
+        iter_=node.iter,
+        iterate=node.target,
+        extra_test_name=self.ctx.namer.new_symbol('extra_test', all_referenced),
+        extra_test_expr=extra_test,
+        body_name=self.ctx.namer.new_symbol('loop_body', all_referenced),
+        body=node_body)
+
+    return node
+
+
+def transform(node, ctx):
+  """Converts while, for and if statements to their functional operator forms.
+
+  Args:
+    node: AST
+    ctx: EntityContext
+  Returns:
+    AST, the transformed node.
+  """
+  node = ControlFlowTransformer(ctx).visit(node)
+  return node
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
new file mode 100644
index 0000000000..cfa0ea920c
--- /dev/null
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -0,0 +1,247 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
class ControlFlowTest(converter_testing.TestCase):
  """Tests for the control_flow converter.

  Note: the inner `test_fn` functions are converted from their source, so
  their exact code is part of each test's fixture.
  """

  def assertTransformedResult(self, test_fn, inputs, expected):
    """Converts test_fn, runs it on `inputs` and compares to `expected`."""
    if not isinstance(inputs, tuple):
      inputs = (inputs,)
    # NOTE(review): constant_op.constant is forwarded to self.converted;
    # presumably it is made available to the converted code -- confirm
    # against converter_testing.TestCase.converted.
    with self.converted(test_fn, control_flow, {},
                        constant_op.constant) as result:
      with self.cached_session() as sess:
        self.assertEqual(sess.run(result.test_fn(*inputs)), expected)

  def test_while_basic(self):

    def test_fn(n):
      i = 0
      s = 0
      while i < n:
        s += i
        i += 1
      return s, i, n

    self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))

  def test_while_nested(self):

    def test_fn(n):
      i = 0
      j = 0
      s = 0
      while i < n:
        while j < i:
          j += 3
        u = i + j  # 'u' is not defined within the inner loop
        s += u
        i += 1
        j = 0
      return s, i, j, n

    self.assertTransformedResult(test_fn, constant_op.constant(5),
                                 (25, 5, 0, 5))

  def test_while_single_output(self):

    def test_fn(n):
      while n > 0:
        n -= 1
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(5), 0)

  def test_while_variable_defined_in_body(self):
    # 's' is created inside the loop body; the converter must reject it.
    def bad_while_loop(n):
      while n > 0:
        n -= 1
        s = n
      return s

    node, ctx = self.prepare(bad_while_loop, {})
    with self.assertRaises(transformer.AutographParseError):
      control_flow.transform(node, ctx)

  def test_if_basic(self):

    def test_fn(n):
      a = 0
      b = 0
      if n > 0:
        a = -n
      else:
        b = 2 * n
      return a, b

    self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
    self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))

  def test_if_complex_outputs(self):
    # Branch outputs that are object attributes rather than locals.

    class TestClass(object):

      def __init__(self, a, b):
        self.a = a
        self.b = b

    def test_fn(n, obj):
      obj.a = 0
      obj.b = 0
      if n > 0:
        obj.a = -n
      else:
        obj.b = 2 * n
      return obj

    with self.converted(test_fn, control_flow, {}) as result:
      with self.cached_session() as sess:
        res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
        self.assertEqual(sess.run((res_obj.a, res_obj.b)), (-1, 0))
        res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
        self.assertEqual(sess.run((res_obj.a, res_obj.b)), (0, -2))

  def test_if_single_output(self):

    def test_fn(n):
      if n > 0:
        n = -n
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(1), -1)

  def test_if_semi(self):
    # An `if` with no `else` branch.

    def test_fn(n):
      if n > 0:
        n = 3
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
    self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)

  def test_if_local_var(self):
    # 'b' is local to the branch; only 'n' is a branch output.

    def test_fn(n):
      if n > 0:
        b = 4
        n = b + 1
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
    self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)

  def test_if_no_outputs(self):

    def test_fn(n):
      if n > 0:
        b = 4  # pylint:disable=unused-variable
      return n

    # Without side effect guards, the if statement will stage a cond,
    # but that will be pruned at execution.
    self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
    self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)

  def test_if_imbalanced_outputs(self):
    # 'b' is only defined on one branch; the converter must reject it.

    def test_fn(n):
      if n > 0:
        b = 4
      return b

    node, ctx = self.prepare(test_fn, {})
    with self.assertRaises(transformer.AutographParseError):
      control_flow.transform(node, ctx)

  def test_simple_for(self):

    def test_fn(l):
      s1 = 0
      s2 = 0
      for e in l:
        s1 += e
        s2 += e * e
      return s1, s2

    self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10))
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(test_fn, empty_vector, (0, 0))

  def test_for_single_output(self):

    def test_fn(l):
      s = 0
      for e in l:
        s += e
      return s

    self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4)
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(test_fn, empty_vector, 0)

  def test_for_iterated_expression(self):
    # The iterated expression must be evaluated exactly once.

    eval_count = [0]

    def count_evals(x):
      eval_count[0] += 1
      return x

    def test_fn(n):
      s = 0
      for e in count_evals(range(n)):
        s += e
      return s

    ns = {'count_evals': count_evals}
    node, ctx = self.prepare(test_fn, ns)
    node = control_flow.transform(node, ctx)

    with self.compiled(node, ns) as result:
      self.assertEqual(result.test_fn(5), 10)
      self.assertEqual(eval_count[0], 1)

  def test_for_variable_defined_in_body(self):
    # 's' is created inside the loop body; the converter must reject it.
    def bad_for_loop(n):
      for i in range(n):
        s = i
      return s

    node, ctx = self.prepare(bad_for_loop, {})
    with self.assertRaises(transformer.AutographParseError):
      control_flow.transform(node, ctx)

  def test_for_tuple_unpacking(self):
    # NOTE(review): 'tf' is not imported in this module; presumably it is
    # resolved in the converted code's namespace -- confirm.
    def test_fn(x_list):
      z = tf.constant(0)  # pylint:disable=undefined-variable
      for i, x in enumerate(x_list):
        z = z + x + i
      return z

    self.assertTransformedResult(test_fn, [3, 3], 7)
# Allows running this test module directly.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/converters/decorators.py b/tensorflow/python/autograph/converters/decorators.py
new file mode 100644
index 0000000000..724f0fe5ed
--- /dev/null
+++ b/tensorflow/python/autograph/converters/decorators.py
@@ -0,0 +1,105 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles decorators.
+
+Note: this module only deals with functions whose decorators are still recorded
+in the AST. This does not always happen. See the unit test for an example.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.util import tf_inspect
+
+
class DecoratorsTransformer(converter.Base):
  """Converts or removes decorators.

  AutoGraph's own decorators are dropped. All other decorators are kept,
  and the imports required to resolve them are recorded on the program
  context (additional_imports) so the generated module can re-apply them.
  """

  def visit_FunctionDef(self, node):
    self.generic_visit(node)
    # Tuples of (decorator AST, supporting entity, qualified name) for each
    # decorator that must be preserved in the output.
    kept_decorators = []
    for dec in node.decorator_list:
      if isinstance(dec, gast.Call):
        # Parameterized decorator, @dec(...): resolve the callee.
        dec_func = dec.func
      else:
        dec_func = dec

      # Special cases.
      # TODO(mdan): Is there any way we can treat these more generically?
      # We may want to forego using decorators altogether if we can't
      # properly support them.
      if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',):
        # Assumption: decorators are only visible in the AST when converting
        # a function inline (via another decorator).
        # In that case, the converted function is no longer part of the
        # original object that it was declared into.
        # This is currently verified by tests.
        continue

      original_dec = anno.getanno(dec_func, anno.Basic.QN)
      dec_value = anno.getanno(dec_func, 'live_val')

      if dec_value in self.ctx.program.autograph_decorators:
        # AutoGraph decorators do not need to be preserved.
        continue

      # When using foo.bar.baz, we only really need to grab foo and import
      # that.
      dec_support_node = dec_func
      while isinstance(dec_support_node, gast.Attribute):
        dec_support_node = dec_support_node.value

      if not anno.hasanno(dec_support_node, 'live_val'):
        raise ValueError(
            'could not resolve symbol "%s" when looking up decorator "%s"' %
            (anno.getanno(dec_support_node, anno.Basic.QN), original_dec))

      dec_support = anno.getanno(dec_support_node, 'live_val')
      # The tuple contains:
      # * the AST that represents the decorator
      # * the entity supporting the decorator (i.e., what we need to import)
      # * the name of the module that needs to be imported for this decorator
      # to properly resolve.
      # Examples:
      # for foo.bar, the tuple is (<ast>, <module foo>, 'foo')
      # for baz, the tuple is (<ast>, <module baz.__module__>, 'baz')
      kept_decorators.append((dec, dec_support,
                              anno.getanno(dec_support_node, anno.Basic.QN)))

    # Record the import each kept decorator needs. Decorators declared in
    # __main__ cannot be imported by the generated code, so they are errors.
    for _, dec_support, name in kept_decorators:
      if tf_inspect.ismodule(dec_support):
        self.ctx.program.additional_imports.add(
            'import %s as %s' % (dec_support.__name__, name))
      else:
        if dec_support.__module__ == '__main__':
          raise ValueError(
              'decorator "%s" was not allowed because it is declared '
              'in the module "%s". To fix this, declare it in a separate '
              'module that we can import it from.' % (dec_support,
                                                      dec_support.__module__))
        self.ctx.program.additional_imports.add(
            'from %s import %s' % (dec_support.__module__, name))

    node.decorator_list = [dec for dec, _, _ in kept_decorators]
    return node
+
+
def transform(node, ctx):
  """Strips AutoGraph decorators and preserves all other decorators."""
  transformed = DecoratorsTransformer(ctx).visit(node)
  return transformed
diff --git a/tensorflow/python/autograph/converters/decorators_test.py b/tensorflow/python/autograph/converters/decorators_test.py
new file mode 100644
index 0000000000..fb31c8d583
--- /dev/null
+++ b/tensorflow/python/autograph/converters/decorators_test.py
@@ -0,0 +1,152 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for decorators module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import wraps
+import imp
+
+from tensorflow.python import autograph
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.platform import test
+
+
+# The Python parser only briefly captures decorators into the AST.
+# The interpreter desugars them on load, and the decorated function loses any
+# trace of the decorator (which is normally what you would expect, since
+# they are meant to be transparent).
+# However, decorators are still visible when you analyze the function
+# from inside a decorator, before it was applied - as is the case
+# with our conversion decorators.
+
+
def simple_decorator(f):
  """Wraps a single-argument callable so its result is incremented by one."""

  def wrapper(a):
    return f(a) + 1

  return wrapper
+
+
def self_transform_decorator(transform):
  """Returns a decorator that transforms the wrapped function at call time.

  The wrapped function is passed to `transform` together with
  `self_transform_decorator` registered as an AutoGraph decorator; the
  transformed function's result is incremented by one so tests can verify
  that the decorator was applied exactly once.
  """

  def decorator(f):

    @wraps(f)
    def wrapper(*args):
      # This removing wrapper is defined in the test below. This setup is so
      # intricate in order to simulate how we use the transformer in practice.
      converted_f = transform(f, (self_transform_decorator,))
      return converted_f(*args) + 1

    return wrapper

  return decorator
+
+
class DecoratorsTest(converter_testing.TestCase):
  """Tests for the decorators converter.

  The functions below are converted from source while still wrapped by
  their decorators, which is the only situation in which decorators remain
  visible in the AST.
  """

  def _transform(self, f, autograph_decorators):
    """Runs only the decorators converter on f and loads the result.

    Imports the converter records in additional_imports are prepended to the
    generated source so that kept decorators resolve.
    """
    namespace = {
        'self_transform_decorator': self_transform_decorator,
        'simple_decorator': simple_decorator,
        'converter_testing': converter_testing,
    }
    node, ctx = self.prepare(
        f,
        namespace,
        recursive=False,
        autograph_decorators=autograph_decorators)
    node = decorators.transform(node, ctx)
    import_line = '\n'.join(ctx.program.additional_imports)
    result, _ = compiler.ast_to_object(node, source_prefix=import_line)
    return getattr(result, f.__name__)

  def test_noop(self):

    def test_fn(a):
      return a

    with self.converted(test_fn, decorators, {}) as result:
      self.assertEqual(1, result.test_fn(1))

  def test_function(self):

    @self_transform_decorator(self._transform)
    def test_fn(a):
      return a

    # 2 = 1 (a) + 1 (decorator applied exactly once)
    self.assertEqual(2, test_fn(1))

  def test_method(self):

    class TestClass(object):

      @self_transform_decorator(self._transform)
      def test_fn(self, a):
        return a

    # 2 = 1 (a) + 1 (decorator applied exactly once)
    self.assertEqual(2, TestClass().test_fn(1))

  def test_multiple_decorators(self):

    class TestClass(object):

      # Note that reversing the order of these two doesn't work.
      @classmethod
      @self_transform_decorator(self._transform)
      def test_fn(cls, a):
        return a

    # 2 = 1 (a) + 1 (decorator applied exactly once)
    self.assertEqual(2, TestClass.test_fn(1))

  def test_nested_decorators_local(self):

    @self_transform_decorator(self._transform)
    def test_fn(a):
      @simple_decorator
      def inner_fn(b):
        return b + 11
      return inner_fn(a)

    # Expected to fail because simple_decorator could not be imported.
    with self.assertRaises(transformer.AutographParseError):
      test_fn(1)

  def test_nested_decorators_imported(self):

    @self_transform_decorator(self._transform)
    def test_fn(a):

      @converter_testing.imported_decorator
      def inner_fn(b):
        return b + 11

      return inner_fn(a)

    # Work around TensorFlow's symbol suppression mechanism that causes core to
    # be invisible in the generated code.
    core_mod = imp.new_module('core')
    core_mod.converter_testing = converter_testing
    autograph.core = core_mod

    # 14 = 1 (a) + 11 (inner_fn) + 1 (imported_decorator, presumably adds 1
    # like simple_decorator -- confirm) + 1 (outer decorator)
    self.assertEqual(14, test_fn(1))
+
+
# Allows running this test module directly.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/converters/directives.py b/tensorflow/python/autograph/converters/directives.py
new file mode 100644
index 0000000000..fc646348ef
--- /dev/null
+++ b/tensorflow/python/autograph/converters/directives.py
@@ -0,0 +1,128 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles directives.
+
+This converter removes the directive functions from the code and moves the
+information they specify into AST annotations. It is a specialized form of
+static analysis, one that is specific to AutoGraph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.util import tf_inspect
+
+ENCLOSING_LOOP = 'enclosing_loop'
+
+
def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node
  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node.
  Raises:
    ValueError: if the default arguments are not correctly set
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  unexpected_defaults = []
  for k in call_args:
    if (k not in kwds
        and call_args[k] not in args
        and call_args[k] is not directives.UNSPECIFIED):
      unexpected_defaults.append(k)
  if unexpected_defaults:
    # Materialize the name/value pairs eagerly: in Python 3, interpolating a
    # lazy `zip` object into the message would render as '<zip object ...>'
    # instead of the offending arguments.
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (list(zip(unexpected_defaults,
                                 [call_args[k] for k in unexpected_defaults])),
                        function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}
+
+
class DirectivesTransformer(converter.Base):
  """Parses compiler directives and converts them into AST annotations."""

  def _process_symbol_directive(self, call_node, directive):
    # Symbol-targeted directive (e.g. set_element_type): record its mapped
    # arguments on every reaching definition of the target symbol.
    if not call_node.args:
      raise ValueError('"%s" requires a positional first argument'
                       ' as the target' % directive.__name__)
    target = call_node.args[0]
    for definition in anno.getanno(target, anno.Static.ORIG_DEFINITIONS):
      definition.directives[directive] = _map_args(call_node, directive)
    return call_node

  def _process_statement_directive(self, call_node, directive):
    # Statement-level directive (e.g. set_loop_options): attach its mapped
    # arguments to the enclosing loop node.
    if self.local_scope_level < 1:
      raise ValueError(
          '"%s" must be used inside a statement' % directive.__name__)
    loop_node = self.get_local(ENCLOSING_LOOP)
    directive_map = anno.getanno(loop_node, converter.AgAnno.DIRECTIVES, {})
    directive_map[directive] = _map_args(call_node, directive)
    anno.setanno(loop_node, converter.AgAnno.DIRECTIVES, directive_map)
    return call_node

  def visit_Expr(self, node):
    call_node = node.value
    if not isinstance(call_node, gast.Call):
      return self.generic_visit(node)
    if not anno.hasanno(call_node.func, 'live_val'):
      return self.generic_visit(node)

    live_val = anno.getanno(call_node.func, 'live_val')
    if live_val is directives.set_element_type:
      self._process_symbol_directive(call_node, live_val)
    elif live_val is directives.set_loop_options:
      self._process_statement_directive(call_node, live_val)
    else:
      return self.generic_visit(node)

    # Directive calls are not output in the generated code.
    return None

  # TODO(mdan): This will be insufficient for other control flow.
  # That means that if we ever have a directive that affects things other than
  # loops, we'll need support for parallel scopes, or have multiple converters.
  def _track_and_visit_loop(self, node):
    # Remembers the loop node in a local scope so nested directive calls can
    # find their enclosing loop.
    self.enter_local_scope()
    self.set_local(ENCLOSING_LOOP, node)
    visited = self.generic_visit(node)
    self.exit_local_scope()
    return visited

  def visit_While(self, node):
    return self._track_and_visit_loop(node)

  def visit_For(self, node):
    return self._track_and_visit_loop(node)
+
+
def transform(node, ctx):
  """Moves directive calls in `node` into AST annotations."""
  converted = DirectivesTransformer(ctx).visit(node)
  return converted
diff --git a/tensorflow/python/autograph/converters/directives_test.py b/tensorflow/python/autograph/converters/directives_test.py
new file mode 100644
index 0000000000..570fb8e379
--- /dev/null
+++ b/tensorflow/python/autograph/converters/directives_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for directives module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import directives as directives_converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core.converter import AgAnno
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
class DirectivesTest(converter_testing.TestCase):
  """Tests for the directives converter."""

  def test_local_target(self):

    def test_fn():
      l = []
      string_var = 0
      directives.set_element_type(l, 'a', string_var)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # The directive lands on the definition of 'l'; its arguments are
    # recorded as AST nodes (a Str and a Name here).
    def_, = anno.getanno(node.body[0].targets[0],
                         anno.Static.DEFINITIONS)
    d = def_.directives[directives.set_element_type]
    self.assertEqual(d['dtype'].s, 'a')
    self.assertEqual(d['shape'].id, 'string_var')

  def test_argument_target(self):

    def test_fn(a):
      directives.set_element_type(a, 1, shape=2)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # Function arguments also carry definitions that directives attach to.
    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
    d = def_.directives[directives.set_element_type]
    self.assertEqual(d['dtype'].n, 1)
    self.assertEqual(d['shape'].n, 2)

  def test_loop_target(self):

    def test_fn():
      a = True
      while True:
        directives.set_loop_options(parallel_iterations=10, back_prop=a)

    node, ctx = self.prepare(test_fn, {'directives': directives})
    node = directives_converter.transform(node, ctx)

    # The options land on the enclosing While node; unspecified options are
    # omitted entirely rather than recorded with defaults.
    d = anno.getanno(node.body[1], AgAnno.DIRECTIVES)
    d = d[directives.set_loop_options]
    self.assertEqual(d['parallel_iterations'].n, 10)
    self.assertEqual(d['back_prop'].id, 'a')
    self.assertNotIn('swap_memory', d)

  def test_invalid_default(self):
    # Directive parameters may only default to directives.UNSPECIFIED.

    def invalid_directive(valid_arg, invalid_default=object()):
      del valid_arg
      del invalid_default
      return

    def call_invalid_directive():
      invalid_directive(1)

    node, _ = parser.parse_entity(call_invalid_directive)
    # Find the call to the invalid directive
    node = node.body[0].body[0].value
    with self.assertRaisesRegexp(ValueError, 'Unexpected keyword.*'):
      directives_converter._map_args(node, invalid_directive)
+
+
# Allows running this test module directly.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/converters/error_handlers.py b/tensorflow/python/autograph/converters/error_handlers.py
new file mode 100644
index 0000000000..de46c0c830
--- /dev/null
+++ b/tensorflow/python/autograph/converters/error_handlers.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wraps function bodies with a try/except to rewrite error tracebacks.
+
+Only adds try/except wrappers to functions that have the anno.Basic.ORIGIN
+annotation because these are the functions originally written by the user.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import templates
+
+
class ErrorRewritingTransformer(converter.Base):
  """Possibly wraps the body of a function in a try/except.

  Only wraps functions that were originally defined by the user, detected by
  checking for the anno.Basic.ORIGIN annotation.
  """

  def visit_FunctionDef(self, node):
    node = self.generic_visit(node)

    is_user_code = anno.hasanno(node, anno.Basic.ORIGIN)
    is_outermost = len(self.enclosing_entities) <= 1
    if is_user_code and is_outermost:
      template = """
        try:
          body
        except:
          ag__.rewrite_graph_construction_error(ag_source_map__)
      """
      node.body = templates.replace(template, body=node.body)
    return node
+
+
def transform(node, ctx):
  """Wraps user-defined functions in `node` with error-rewriting handlers."""
  rewriter = ErrorRewritingTransformer(ctx)
  return rewriter.visit(node)
diff --git a/tensorflow/python/autograph/converters/error_handlers_test.py b/tensorflow/python/autograph/converters/error_handlers_test.py
new file mode 100644
index 0000000000..676ff9e02b
--- /dev/null
+++ b/tensorflow/python/autograph/converters/error_handlers_test.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for error_handlers module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.platform import test
+
+
class ErrorHandlersTest(converter_testing.TestCase):
  """Tests for the error_handlers converter."""

  def test_basic(self):
    # A function carrying an ORIGIN annotation gets wrapped in a try/except
    # that rewrites exceptions into GraphConstructionError.

    def test_fn():
      raise ValueError()

    node, ctx = self.prepare(test_fn, {})
    anno.setanno(
        node, anno.Basic.ORIGIN,
        origin_info.OriginInfo(None, 'test_function_name', 'test_code',
                               'test_comment'))
    node = error_handlers.transform(node, ctx)
    with self.compiled(node, {}) as result:
      with self.assertRaises(errors.GraphConstructionError):
        # Here we just assert that the handler works. Its correctness is
        # verified by errors_test.py.
        result.test_fn()

  def test_no_origin_annotation(self):
    # Without the ORIGIN annotation no handler is added, so the original
    # ValueError propagates unchanged.

    def test_fn():
      raise ValueError()

    with self.converted(test_fn, error_handlers, {}) as result:
      with self.assertRaises(ValueError):
        result.test_fn()
+
+
# Allows running this test module directly.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/converters/list_comprehensions.py b/tensorflow/python/autograph/converters/list_comprehensions.py
new file mode 100644
index 0000000000..5be6cb9a98
--- /dev/null
+++ b/tensorflow/python/autograph/converters/list_comprehensions.py
@@ -0,0 +1,82 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Lowers list comprehensions into for and if statements.
+
+Example:
+
+ result = [x * x for x in xs]
+
+becomes
+
+ result = []
+ for x in xs:
+ elt = x * x
+ result.append(elt)
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
+
+
+# TODO(mdan): This should convert directly to operator calls.
+
+
class ListCompTransformer(converter.Base):
  """Lowers list comprehensions into standard control flow."""

  def visit_Assign(self, node):
    """Expands `target = [elt for ... in ... if ...]` into explicit loops."""
    if not isinstance(node.value, gast.ListComp):
      return self.generic_visit(node)
    if len(node.targets) > 1:
      raise NotImplementedError('multiple assignments')

    target, = node.targets
    list_comp_node = node.value

    # Start with an empty accumulator: target = []
    template = """
      target = []
    """
    initialization = templates.replace(template, target=target)

    # Innermost statement: append the element expression to the accumulator.
    template = """
      target.append(elt)
    """
    body = templates.replace(template, target=target, elt=list_comp_node.elt)

    # Wrap from the inside out so the last generator becomes the innermost
    # loop; each generator's `if` clauses wrap its loop body first.
    for gen in reversed(list_comp_node.generators):
      for gen_if in reversed(gen.ifs):
        template = """
          if test:
            body
        """
        body = templates.replace(template, test=gen_if, body=body)
      template = """
        for target in iter_:
          body
      """
      body = templates.replace(
          template, iter_=gen.iter, target=gen.target, body=body)

    # Returning a list of statements replaces the single Assign node.
    return initialization + body
+
+
def transform(node, ctx):
  """Lowers list comprehensions in `node` into loops and if statements."""
  lowered = ListCompTransformer(ctx).visit(node)
  return lowered
diff --git a/tensorflow/python/autograph/converters/list_comprehensions_test.py b/tensorflow/python/autograph/converters/list_comprehensions_test.py
new file mode 100644
index 0000000000..1e66139af6
--- /dev/null
+++ b/tensorflow/python/autograph/converters/list_comprehensions_test.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for list_comprehensions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import list_comprehensions
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.platform import test
+
+
+class ListCompTest(converter_testing.TestCase):
+  """Tests for the list_comprehensions converter."""
+
+  def assertTransformedEquivalent(self, test_fn, *inputs):
+    """Asserts that the converted test_fn matches the original on inputs."""
+    with self.converted(test_fn, list_comprehensions, {}) as result:
+      self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))
+
+  def test_basic(self):
+    # Single generator, no filters.
+    def test_fn(l):
+      s = [e * e for e in l]
+      return s
+
+    self.assertTransformedEquivalent(test_fn, [])
+    self.assertTransformedEquivalent(test_fn, [1, 2, 3])
+
+  def test_multiple_generators(self):
+    # Nested generators must be unrolled in the correct (outer-to-inner) order.
+    def test_fn(l):
+      s = [e * e for sublist in l for e in sublist]
+      return s
+
+    self.assertTransformedEquivalent(test_fn, [])
+    self.assertTransformedEquivalent(test_fn, [[1], [2], [3]])
+
+  def test_cond(self):
+    # Generator with an `if` filter.
+    def test_fn(l):
+      s = [e * e for e in l if e > 1]
+      return s
+
+    self.assertTransformedEquivalent(test_fn, [])
+    self.assertTransformedEquivalent(test_fn, [1, 2, 3])
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/lists.py b/tensorflow/python/autograph/converters/lists.py
new file mode 100644
index 0000000000..8180801753
--- /dev/null
+++ b/tensorflow/python/autograph/converters/lists.py
@@ -0,0 +1,239 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for list operations.
+
+This includes converting Python lists to TensorArray/TensorList.
+"""
+
+# TODO(mdan): Elaborate the logic here.
+# TODO(mdan): Does it even make sense to attempt to use TAs?
+# The current rule (always convert to TensorArray) is naive and insufficient.
+# In general, a better mechanism could look like:
+# * convert to TensorList by default
+# * leave as Python list if the user explicitly forbids it
+# * convert to TensorArray only when complete write once behavior can be
+# guaranteed (e.g. list comprehensions)
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
+# Tags for local state.
+POP_USES = 'pop_uses'
+
+
+class ListTransformer(converter.Base):
+  """Converts lists and related operations to their TF counterpart."""
+
+  def visit_List(self, node):
+    """Replaces list literals with an ag__.new_list call."""
+    node = self.generic_visit(node)
+    template = """
+      ag__.new_list(elements)
+    """
+    return templates.replace_as_expression(template, elements=node)
+
+  def _replace_append_call(self, node):
+    """Replaces `target.append(x)` with `target = ag__.list_append(target, x)`.
+
+    Note that this turns the call expression into an assignment statement
+    (templates.replace returns a list of statements).
+    """
+    assert len(node.args) == 1
+    assert isinstance(node.func, gast.Attribute)
+    template = """
+      target = ag__.list_append(target, element)
+    """
+    return templates.replace(
+        template,
+        target=node.func.value,
+        element=node.args[0])
+
+  def _replace_pop_call(self, node):
+    # Expressions that use pop() are converted to a statement + expression.
+    #
+    # For example:
+    #
+    #   print(target.pop())
+    #
+    # ... is converted to:
+    #
+    #   target, target_pop = ag__.list_pop(target)
+    #   print(target_pop)
+    #
+    # Here, we just generate the variable name and swap it in,
+    # and _generate_pop_operation will handle the rest.
+    #
+    # Multiple uses of pop() are allowed:
+    #
+    #   print(target.pop(), target.pop())
+    #   print(target.pop().pop())
+    #
+    assert isinstance(node.func, gast.Attribute)
+    scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
+    target_node = node.func.value
+
+    # Attempt to use a related name if one exists. Otherwise use something
+    # generic.
+    if anno.hasanno(target_node, anno.Basic.QN):
+      target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
+    else:
+      target_name = 'list_'
+    pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)
+
+    # Record the deferred pop so _postprocess_statement can insert the
+    # corresponding ag__.list_pop statement before the enclosing statement.
+    pop_uses = self.get_local(POP_USES, [])
+    pop_uses.append((node, pop_var_name))
+    self.set_local(POP_USES, pop_uses)
+
+    # The call expression itself is replaced by the fresh variable name.
+    return templates.replace_as_expression('var_name', var_name=pop_var_name)
+
+  def _replace_stack_call(self, node):
+    """Replaces `tf.stack(l)`-style calls with ag__.list_stack."""
+    assert len(node.args) == 1
+    # The element dtype comes from a set_element_type directive on the
+    # argument's definition, if any.
+    dtype = self.get_definition_directive(
+        node.args[0],
+        directives.set_element_type,
+        'dtype',
+        default=templates.replace_as_expression('None'))
+    template = """
+      ag__.list_stack(
+          target,
+          opts=ag__.ListStackOpts(
+              element_dtype=dtype,
+              original_call=orig_call))
+    """
+    # original_call lets the runtime fall back to the user's function when the
+    # target turns out not to be a TF list.
+    return templates.replace_as_expression(
+        template,
+        dtype=dtype,
+        target=node.args[0],
+        orig_call=node.func)
+
+  def visit_Call(self, node):
+    """Dispatches append/pop/stack attribute calls to their replacements."""
+    node = self.generic_visit(node)
+
+    # TODO(mdan): This is insufficient if target is a function argument.
+    # In the case of function arguments, we need to add the list to the
+    # function's return value, because it is being modified.
+    # TODO(mdan): Checking just the name is brittle, can it be improved?
+    if isinstance(node.func, gast.Attribute):
+      func_name = node.func.attr
+      if func_name == 'append' and (len(node.args) == 1):
+        node = self._replace_append_call(node)
+      elif func_name == 'pop' and (len(node.args) <= 1):
+        node = self._replace_pop_call(node)
+      elif (func_name == 'stack' and (len(node.args) == 1) and
+            (not node.keywords or node.keywords[0].arg == 'strict')):
+        # This avoids false positives with keyword args.
+        # TODO(mdan): handle kwargs properly.
+        node = self._replace_stack_call(node)
+
+    return node
+
+  def _generate_pop_operation(self, original_call_node, pop_var_name):
+    """Builds the `target, pop_var = ag__.list_pop(...)` statement."""
+    assert isinstance(original_call_node.func, gast.Attribute)
+
+    # pop() may carry an optional index argument.
+    if original_call_node.args:
+      pop_element = original_call_node.args[0]
+    else:
+      pop_element = parser.parse_expression('None')
+
+    # The call will be something like "target.pop()", and the dtype is hooked to
+    # target, hence the func.value.
+    # TODO(mdan): For lists of lists, this won't work.
+    # The reason why it won't work is because it's unclear how to annotate
+    # the list as a "list of lists with a certain element type" when using
+    # operations like `l.pop().pop()`.
+    dtype = self.get_definition_directive(
+        original_call_node.func.value,
+        directives.set_element_type,
+        'dtype',
+        default=templates.replace_as_expression('None'))
+    shape = self.get_definition_directive(
+        original_call_node.func.value,
+        directives.set_element_type,
+        'shape',
+        default=templates.replace_as_expression('None'))
+
+    template = """
+      target, pop_var_name = ag__.list_pop(
+          target, element,
+          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
+    """
+    return templates.replace(
+        template,
+        target=original_call_node.func.value,
+        pop_var_name=pop_var_name,
+        element=pop_element,
+        dtype=dtype,
+        shape=shape)
+
+  def _postprocess_statement(self, node):
+    """Inserts any separate pop() calls that node may use."""
+    pop_uses = self.get_local(POP_USES, None)
+    if pop_uses:
+      replacements = []
+      # The generated pop statements precede the statement that used them, in
+      # the order the pop() expressions were encountered.
+      for original_call_node, pop_var_name in pop_uses:
+        replacements.extend(
+            self._generate_pop_operation(original_call_node, pop_var_name))
+      replacements.append(node)
+      node = replacements
+    self.exit_local_scope()
+    return node, None
+
+  # TODO(mdan): Should we have a generic visit_block instead?
+  # Right now it feels that a visit_block would add too much magic that's
+  # hard to follow.
+
+  def _visit_and_process_block(self, block):
+    """Visits a statement list, tracking pop() uses per statement."""
+    return self.visit_block(
+        block,
+        before_visit=self.enter_local_scope,
+        after_visit=self._postprocess_statement)
+
+  def visit_FunctionDef(self, node):
+    node.args = self.generic_visit(node.args)
+    node.decorator_list = self.visit_block(node.decorator_list)
+    node.body = self._visit_and_process_block(node.body)
+    return node
+
+  def visit_For(self, node):
+    node.target = self.visit(node.target)
+    node.body = self._visit_and_process_block(node.body)
+    node.orelse = self._visit_and_process_block(node.orelse)
+    return node
+
+  def visit_While(self, node):
+    node.test = self.visit(node.test)
+    node.body = self._visit_and_process_block(node.body)
+    node.orelse = self._visit_and_process_block(node.orelse)
+    return node
+
+  def visit_If(self, node):
+    node.test = self.visit(node.test)
+    node.body = self._visit_and_process_block(node.body)
+    node.orelse = self._visit_and_process_block(node.orelse)
+    return node
+
+  def visit_With(self, node):
+    node.items = self.visit_block(node.items)
+    node.body = self._visit_and_process_block(node.body)
+    return node
+
+
+def transform(node, ctx):
+  """Converts list literals and list method calls in `node` to TF equivalents.
+
+  Args:
+    node: An AST to transform.
+    ctx: The converter context (see converter.Base).
+
+  Returns:
+    The transformed AST.
+  """
+  return ListTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/lists_test.py b/tensorflow/python/autograph/converters/lists_test.py
new file mode 100644
index 0000000000..f6da845fcc
--- /dev/null
+++ b/tensorflow/python/autograph/converters/lists_test.py
@@ -0,0 +1,132 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lists module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.lang import special_functions
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+tf = None # Will be replaced by a mock.
+
+
+class ListTest(converter_testing.TestCase):
+  """Tests for the lists converter."""
+
+  def test_empty_list(self):
+
+    def test_fn():
+      return []
+
+    with self.converted(test_fn, lists, {}) as result:
+      tl = result.test_fn()
+      # Empty tensor lists cannot be evaluated or stacked.
+      self.assertTrue(isinstance(tl, ops.Tensor))
+      self.assertEqual(tl.dtype, dtypes.variant)
+
+  def test_initialized_list(self):
+
+    def test_fn():
+      return [1, 2, 3]
+
+    with self.converted(test_fn, lists, {}) as result:
+      self.assertAllEqual(result.test_fn(), [1, 2, 3])
+
+  def test_list_append(self):
+
+    def test_fn():
+      l = special_functions.tensor_list([1])
+      l.append(2)
+      l.append(3)
+      return l
+
+    ns = {'special_functions': special_functions}
+    with self.converted(test_fn, lists, ns) as result:
+      with self.cached_session() as sess:
+        tl = result.test_fn()
+        r = list_ops.tensor_list_stack(tl, dtypes.int32)
+        self.assertAllEqual(sess.run(r), [1, 2, 3])
+
+  def test_list_pop(self):
+
+    def test_fn():
+      l = special_functions.tensor_list([1, 2, 3])
+      s = l.pop()
+      return s, l
+
+    ns = {'special_functions': special_functions}
+    node, ctx = self.prepare(test_fn, ns)
+    # Attach a set_element_type directive to l's definition, as the converter
+    # would normally pick it up from user annotations.
+    def_, = anno.getanno(node.body[0].targets[0],
+                         anno.Static.ORIG_DEFINITIONS)
+    def_.directives[directives.set_element_type] = {
+        'dtype': parser.parse_expression('tf.int32'),
+        'shape': parser.parse_expression('()'),
+    }
+    node = lists.transform(node, ctx)
+
+    with self.compiled(node, ns, dtypes.int32) as result:
+      with self.cached_session() as sess:
+        ts, tl = result.test_fn()
+        r = list_ops.tensor_list_stack(tl, dtypes.int32)
+        self.assertAllEqual(sess.run(r), [1, 2])
+        self.assertAllEqual(sess.run(ts), 3)
+
+  def test_double_list_pop(self):
+
+    def test_fn(l):
+      s = l.pop().pop()
+      return s
+
+    with self.converted(test_fn, lists, {}) as result:
+      test_input = [1, 2, [1, 2, 3]]
+      # TODO(mdan): Pass a list of lists of tensor when we fully support that.
+      # For now, we just pass a regular Python list of lists just to verify that
+      # the two pop calls are sequenced properly.
+      self.assertAllEqual(result.test_fn(test_input), 3)
+
+  def test_list_stack(self):
+
+    def test_fn():
+      l = [1, 2, 3]
+      return tf.stack(l)
+
+    node, ctx = self.prepare(test_fn, {})
+    # As in test_list_pop, simulate a user-provided element type directive.
+    def_, = anno.getanno(node.body[0].targets[0],
+                         anno.Static.ORIG_DEFINITIONS)
+    def_.directives[directives.set_element_type] = {
+        'dtype': parser.parse_expression('tf.int32')
+    }
+    node = lists.transform(node, ctx)
+
+    with self.compiled(node, {}, array_ops.stack, dtypes.int32) as result:
+      with self.cached_session() as sess:
+        self.assertAllEqual(sess.run(result.test_fn()), [1, 2, 3])
+
+  # TODO(mdan): Add a test with tf.stack with axis kwarg.
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
new file mode 100644
index 0000000000..ac42ee2c33
--- /dev/null
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -0,0 +1,132 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for logical expressions.
+
+e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+
+
+# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
+# Note that this isn't completely safe either, because tensors may have control
+# dependencies.
+# Note that for loops that should be done after the loop was converted to
+# tf.while_loop so that the expanded conditionals are properly scoped.
+
+# Used to signal that an operand is safe for non-lazy evaluation.
+SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'
+
+
+class LogicalExpressionTransformer(converter.Base):
+  """Converts logical expressions to corresponding TF calls."""
+
+  def __init__(self, ctx):
+    super(LogicalExpressionTransformer, self).__init__(ctx)
+    # TODO(mdan): Look into replacing with bitwise operators instead.
+    # TODO(mdan): Skip replacing if the function is trivial.
+    # Maps gast operator node types to the source of the replacement function.
+    self.op_mapping = {
+        gast.And: 'tf.logical_and',
+        gast.Eq: 'tf.equal',
+        gast.Gt: 'tf.greater',
+        gast.GtE: 'tf.greater_equal',
+        gast.Lt: 'tf.less',
+        gast.LtE: 'tf.less_equal',
+        gast.Not: 'tf.logical_not',
+        gast.NotEq: 'tf.not_equal',
+        gast.Or: 'tf.logical_or',
+        gast.USub: 'tf.negative',
+        gast.Is: 'ag__.utils.dynamic_is',
+        gast.IsNot: 'ag__.utils.dynamic_is_not'
+    }
+
+  def _expect_simple_symbol(self, operand):
+    """Raises unless operand is a plain name or an already-converted call."""
+    if isinstance(operand, gast.Name):
+      return
+    if anno.hasanno(operand, SAFE_BOOLEAN_OPERAND):
+      return
+    raise NotImplementedError(
+        'only simple local variables are supported in logical and compound '
+        'comparison expressions; for example, we support "a or b" but not '
+        '"a.x or b"; for a workaround, assign the expression to a local '
+        'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+
+  def _matching_func(self, operator):
+    """Returns the replacement function source for an operator node."""
+    op_type = type(operator)
+    mapped_op = self.op_mapping.get(op_type)
+    if not mapped_op:
+      raise NotImplementedError('operator %s is not yet supported' % op_type)
+    return mapped_op
+
+  def _as_function(self, func_name, args):
+    """Builds a `func_name(args)` call, marked safe for non-lazy evaluation."""
+    template = """
+      func_name(args)
+    """
+    replacement = templates.replace_as_expression(
+        template, func_name=parser.parse_expression(func_name), args=args)
+    anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
+    return replacement
+
+  def visit_Compare(self, node):
+    node = self.generic_visit(node)
+    ops_and_comps = list(zip(node.ops, node.comparators))
+    left = node.left
+    op_tree = None
+
+    # Repeated comparisons are converted to conjunctions:
+    #    a < b < c -> a < b and b < c
+    while ops_and_comps:
+      op, right = ops_and_comps.pop(0)
+      binary_comparison = self._as_function(
+          self._matching_func(op), (left, right))
+      if isinstance(left, gast.Name) and isinstance(right, gast.Name):
+        anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
+      if op_tree:
+        # Subsequent comparisons must use simple operands, since they lose
+        # Python's lazy evaluation once converted to function calls.
+        self._expect_simple_symbol(right)
+        op_tree = self._as_function('tf.logical_and',
+                                    (binary_comparison, op_tree))
+      else:
+        op_tree = binary_comparison
+      left = right
+    assert op_tree is not None
+    return op_tree
+
+  def visit_UnaryOp(self, node):
+    node = self.generic_visit(node)
+    return self._as_function(self._matching_func(node.op), node.operand)
+
+  def visit_BoolOp(self, node):
+    node = self.generic_visit(node)
+    # NOTE: node_values aliases node.values, so the pop() calls below consume
+    # the same underlying list; the fold proceeds right-to-left.
+    node_values = node.values
+    right = node.values.pop()
+    self._expect_simple_symbol(right)
+    while node_values:
+      left = node_values.pop()
+      self._expect_simple_symbol(left)
+      right = self._as_function(self._matching_func(node.op), (left, right))
+    return right
+
+
+def transform(node, ctx):
+  """Converts logical and comparison expressions in `node` to TF calls.
+
+  Args:
+    node: An AST to transform.
+    ctx: The converter context (see converter.Base).
+
+  Returns:
+    The transformed AST.
+  """
+  return LogicalExpressionTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py
new file mode 100644
index 0000000000..5fb3fb992f
--- /dev/null
+++ b/tensorflow/python/autograph/converters/logical_expressions_test.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for logical_expressions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+# NOTE(review): the class name looks copy-pasted from another test file; these
+# are tests for logical_expressions, not gradients — consider renaming.
+class GradientsFunctionTest(converter_testing.TestCase):
+  """Tests for the logical_expressions converter."""
+
+  def test_equals(self):
+
+    def test_fn(a, b):
+      return a == b
+
+    with self.converted(test_fn, logical_expressions, {},
+                        math_ops.equal) as result:
+      with self.cached_session() as sess:
+        self.assertTrue(sess.run(result.test_fn(1, 1)))
+        self.assertFalse(sess.run(result.test_fn(1, 2)))
+
+  def test_bool_ops(self):
+
+    def test_fn(a, b, c):
+      return (a or b) and (a or b or c)
+
+    with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or,
+                        math_ops.logical_and) as result:
+      with self.cached_session() as sess:
+        self.assertTrue(sess.run(result.test_fn(True, False, True)))
+
+  def test_ag_utils_lookup(self):
+    # `is` / `is not` map to the ag__.utils dynamic_is* helpers.
+    def test_fn(a, b):
+      return a is b or a is not b
+
+    with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
+                       ) as result:
+      with self.cached_session() as sess:
+        self.assertTrue(sess.run(result.test_fn(True, False)))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/name_scopes.py
new file mode 100644
index 0000000000..a9c55ccff0
--- /dev/null
+++ b/tensorflow/python/autograph/converters/name_scopes.py
@@ -0,0 +1,74 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wraps a function body with a `name_scope` of the function name."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import templates
+
+
+class FunctionNameScopeTransformer(converter.Base):
+  """Wrap a function body with a `name_scope` of the function name."""
+
+  def _name_for_current_scope(self):
+    """Returns the sanitized scope name for the innermost enclosing entity."""
+    innermost = self.enclosing_entities[-1]
+    if len(self.enclosing_entities) > 1:
+      parent = self.enclosing_entities[-2]
+      if isinstance(parent, gast.ClassDef):
+        # Methods also take the name of their class.
+        name = '%s/%s' % (parent.name, innermost.name)
+      else:
+        name = innermost.name
+    else:
+      name = innermost.name
+
+    # Sanitize the name.
+    # See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope
+    # TensorFlow doesn't like leading underscores at the top level.
+    # NOTE(review): a name consisting solely of underscores (e.g. `__`) would
+    # raise IndexError here — presumably such names don't occur; confirm.
+    while name[0] == '_':
+      name = name[1:]
+    return name
+
+  def visit_FunctionDef(self, node):
+    """Wraps the function body (minus docstring) in a tf.name_scope."""
+    node = self.generic_visit(node)
+
+    unscoped_body = []
+    scoped_body = node.body
+    if scoped_body:
+      first = scoped_body[0]
+      if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str):
+        # Skip any docstring.
+        unscoped_body = scoped_body[:1]
+        scoped_body = scoped_body[1:]
+
+    template = """
+      with tf.name_scope(scope_name):
+        body
+    """
+    scoped_body = templates.replace(
+        template,
+        scope_name=gast.Str(self._name_for_current_scope()),
+        body=scoped_body)
+    # The docstring (if any) stays outside the name_scope block.
+    node.body = unscoped_body + scoped_body
+    return node
+
+
+def transform(node, ctx):
+  """Wraps function bodies in `node` with a tf.name_scope of their name.
+
+  Args:
+    node: An AST to transform.
+    ctx: The converter context (see converter.Base).
+
+  Returns:
+    The transformed AST.
+  """
+  return FunctionNameScopeTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/name_scopes_test.py
new file mode 100644
index 0000000000..73933c1c4f
--- /dev/null
+++ b/tensorflow/python/autograph/converters/name_scopes_test.py
@@ -0,0 +1,101 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for name_scopes module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+# NOTE(review): this TestCase shares its name with the converter class it
+# tests; a `...Test` suffix would be clearer — consider renaming.
+class FunctionNameScopeTransformer(converter_testing.TestCase):
+  """Tests for the name_scopes converter."""
+
+  def test_basic(self):
+
+    def test_fn(l):
+      """This should stay here."""
+      a = 1
+      l += a
+      return l
+
+    with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+      result_op = result.test_fn(constant_op.constant(1))
+      self.assertIn('test_fn/', result_op.op.name)
+      # The docstring must survive the transformation.
+      self.assertEqual('This should stay here.', result.test_fn.__doc__)
+
+  def test_long_docstring(self):
+
+    def test_fn(l):
+      """Multi-line docstring.
+
+      Args:
+        l: A thing.
+      Returns:
+        l
+      """
+      return l + 1
+
+    with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+      result_op = result.test_fn(constant_op.constant(1))
+      self.assertIn('test_fn/', result_op.op.name)
+      self.assertIn('Multi-line docstring.', result.test_fn.__doc__)
+      self.assertIn('Returns:', result.test_fn.__doc__)
+
+  def test_nested_functions(self):
+
+    def test_fn(l):
+
+      def inner_fn(i):
+        return i + 1
+
+      l += 1
+      return l, inner_fn(l)
+
+    with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+      first, second = result.test_fn(constant_op.constant(1))
+      # Each function gets its own scope; inner ops are nested under it.
+      self.assertIn('test_fn/', first.op.name)
+      self.assertNotIn('inner_fn', first.op.name)
+      self.assertIn('test_fn/inner_fn/', second.op.name)
+
+  def test_method(self):
+
+    class TestClass(object):
+
+      def test_fn(self, l):
+
+        def inner_fn(i):
+          return i + 1
+
+        l += 1
+        return l, inner_fn(l)
+
+    ns = {'TestClass': TestClass}
+    node, ctx = self.prepare(TestClass, ns, owner_type=TestClass)
+    node = name_scopes.transform(node, ctx)
+
+    with self.compiled(node, {}, ops.name_scope) as result:
+      first, second = result.TestClass().test_fn(constant_op.constant(1))
+      # Methods are scoped under their class name as well.
+      self.assertIn('TestClass/test_fn/', first.op.name)
+      self.assertNotIn('inner_fn', first.op.name)
+      self.assertIn('TestClass/test_fn/inner_fn/', second.op.name)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
new file mode 100644
index 0000000000..62da045d6a
--- /dev/null
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -0,0 +1,317 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Canonicalizes functions with multiple returns to use just one."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
# TODO(mdan): Move this logic into transformer_base.
class BodyVisitor(converter.Base):
  """Walks breadth- or depth-first the list-of-nodes bodies of AST nodes."""

  def __init__(self, ctx, depth_first=False):
    """Creates a body visitor.

    Args:
      ctx: converter.EntityContext
      depth_first: Boolean. If True, a node's children are visited before its
          body lists are processed; otherwise the node itself is visited
          after its body lists.
    """
    super(BodyVisitor, self).__init__(ctx)
    self.depth_first = depth_first
    self.changes_made = False

  def visit_nodelist(self, nodelist):
    # Visits every statement of a body list; nested lists recurse. The value
    # returned by generic_visit is intentionally not written back into the
    # list - nodes are expected to be mutated in place.
    for node in nodelist:
      if isinstance(node, list):
        node = self.visit_nodelist(node)
      else:
        node = self.generic_visit(node)
    return nodelist

  # Each visitor below processes all the body lists of the respective node,
  # either before or after visiting the node itself, depending on depth_first.

  def visit_If(self, node):
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    node.orelse = self.visit_nodelist(node.orelse)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_For(self, node):
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    node.orelse = self.visit_nodelist(node.orelse)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_While(self, node):
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    node.orelse = self.visit_nodelist(node.orelse)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_Try(self, node):
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    node.orelse = self.visit_nodelist(node.orelse)
    node.finalbody = self.visit_nodelist(node.finalbody)
    # Exception handler bodies are plain statement lists as well.
    for i in range(len(node.handlers)):
      node.handlers[i].body = self.visit_nodelist(node.handlers[i].body)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_With(self, node):
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node

  def visit_FunctionDef(self, node):
    if self.depth_first:
      node = self.generic_visit(node)
    node.body = self.visit_nodelist(node.body)
    # NOTE(review): generic_visit is called unconditionally here and again
    # below when depth_first is False, so children may be visited twice in
    # that mode - confirm this is intended.
    self.generic_visit(node)
    if not self.depth_first:
      node = self.generic_visit(node)
    return node
+
+
class FoldElse(BodyVisitor):
  """Folds statements that follow a returning conditional into a branch.

  If exactly one branch of a conditional ends in a return, every statement
  after the conditional can only execute when the other branch was taken, so
  those statements are moved into that branch. This rewriting lets LiftReturn
  make further progress on subsequent passes.
  """

  def visit_nodelist(self, nodelist):
    for i in range(len(nodelist)):
      node = nodelist[i]
      if isinstance(node, gast.If):
        true_branch_returns = isinstance(node.body[-1], gast.Return)
        false_branch_returns = len(node.orelse) and isinstance(
            node.orelse[-1], gast.Return)
        # If the last node in the if body is a return,
        # then every line after this if statement effectively
        # belongs in the else.
        if true_branch_returns and not false_branch_returns:
          for j in range(i + 1, len(nodelist)):
            nodelist[i].orelse.append(ast_util.copy_clean(nodelist[j]))
          if nodelist[i + 1:]:
            self.changes_made = True
          # The trailing statements were copied into orelse; drop them here.
          return nodelist[:i + 1]
        elif not true_branch_returns and false_branch_returns:
          for j in range(i + 1, len(nodelist)):
            nodelist[i].body.append(ast_util.copy_clean(nodelist[j]))
          if nodelist[i + 1:]:
            self.changes_made = True
          return nodelist[:i + 1]
        elif true_branch_returns and false_branch_returns:
          # Both branches return: anything after the conditional is dead code.
          if nodelist[i + 1:]:
            raise ValueError(
                'Unreachable code after conditional where both branches return.'
            )
          return nodelist
      elif isinstance(node, gast.Return) and nodelist[i + 1:]:
        raise ValueError(
            'Cannot have statements after a return in the same basic block')
    return nodelist
+
+
def contains_return(node):
  """Returns True if the subtree rooted at `node` has any Return node."""
  return any(isinstance(n, gast.Return) for n in gast.walk(node))
+
+
class LiftReturn(converter.Base):
  """Move return statements out of If and With blocks."""

  def __init__(self, ctx):
    super(LiftReturn, self).__init__(ctx)
    # Whether the last traversal modified the AST; read by transform() to
    # detect the fixed point.
    self.changes_made = False
    # Name of the variable holding the lifted return value for the function
    # currently being visited.
    self.common_return_name = None

  def visit_If(self, node):
    # Depth-first traversal of if statements
    node = self.generic_visit(node)

    # We check if both branches return, and if so, lift the return out of the
    # conditional. We don't enforce that the true and false branches either
    # both return or both do not, because FoldElse might move a return
    # into a branch after this transform completes. FoldElse and LiftReturn
    # are alternately run until the code reaches a fixed point.
    true_branch_returns = isinstance(node.body[-1], gast.Return)
    false_branch_returns = len(node.orelse) and isinstance(
        node.orelse[-1], gast.Return)
    if true_branch_returns and false_branch_returns:
      # Replace `return <expr>` in each branch with an assignment to the
      # shared return variable, then emit a single return after the if.
      node.body[-1] = templates.replace(
          'a = b', a=self.common_return_name, b=node.body[-1].value)[0]
      node.orelse[-1] = templates.replace(
          'a = b', a=self.common_return_name, b=node.orelse[-1].value)[0]
      return_node = templates.replace('return a', a=self.common_return_name)[0]
      self.changes_made = True
      return [node, return_node]
    else:
      return node

  def visit_With(self, node):
    # Depth-first traversal of syntax
    node = self.generic_visit(node)

    # If the with statement returns, lift the return
    if isinstance(node.body[-1], gast.Return):
      node.body[-1] = templates.replace(
          'a = b', a=self.common_return_name, b=node.body[-1].value)[0]
      return_node = templates.replace('return a', a=self.common_return_name)[0]
      node = self.generic_visit(node)
      self.changes_made = True
      return [node, return_node]
    else:
      return node

  def visit_FunctionDef(self, node):
    # Ensure we're doing depth-first traversal
    # Save the enclosing function's return variable so that nested function
    # definitions each get their own.
    last_return_name = self.common_return_name
    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    referenced_names = body_scope.referenced
    self.common_return_name = self.ctx.namer.new_symbol('return_',
                                                        referenced_names)
    node = self.generic_visit(node)
    self.common_return_name = last_return_name
    return node
+
+
class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
  """Throws an error if code returns inside loops or try/except."""

  # First, throw an error if we detect a return statement in a loop.
  # TODO(alexbw): we need to learn to handle returns inside a loop,
  # but don't currently have the TF constructs to do so (need something
  # that looks vaguely like a goto).

  def __init__(self):
    # True while the visitor is inside a construct that disallows returns.
    self.cant_return = False
    super(DetectReturnInUnsupportedControlFlow, self).__init__()

  def _visit_disallowing_return(self, node):
    # Saves and restores the flag rather than resetting it to False, so that
    # a return appearing after a nested loop/try that is itself inside an
    # outer loop/try is still detected. (A plain `self.cant_return = False`
    # would clear the outer construct's restriction once the nested construct
    # was done being visited.)
    outer_cant_return = self.cant_return
    self.cant_return = True
    self.generic_visit(node)
    self.cant_return = outer_cant_return

  def visit_While(self, node):
    self._visit_disallowing_return(node)

  def visit_For(self, node):
    self._visit_disallowing_return(node)

  def visit_Try(self, node):
    self._visit_disallowing_return(node)

  def visit_Return(self, node):
    """Raises ValueError when a return appears in an unsupported location."""
    if self.cant_return:
      raise ValueError(
          '`return` statements are not supported in loops. '
          'Try assigning to a variable in the while loop, and returning '
          'outside of the loop')
+
+
class DetectReturnInConditional(gast.NodeVisitor):
  """Assert that no return statements are present in conditionals."""

  def __init__(self):
    # True while the visitor is inside a conditional.
    self.cant_return = False
    super(DetectReturnInConditional, self).__init__()

  def visit_If(self, node):
    # Save and restore the flag rather than reset it: a nested conditional
    # must not clear the restriction imposed by an enclosing conditional,
    # otherwise a return located after the nested if (but still inside the
    # outer one) would go undetected.
    outer_cant_return = self.cant_return
    self.cant_return = True
    self.generic_visit(node)
    self.cant_return = outer_cant_return

  def visit_Return(self, node):
    if self.cant_return:
      # Fixed the misplaced backtick in the message ("`return `statement").
      raise ValueError(
          'After transforms, a conditional contained a `return` statement, '
          'which is not allowed. This is a bug, and should not happen.')
+
+
class DetectReturnInFunctionDef(gast.NodeVisitor):
  """Verifies that every function definition contains at least one return."""

  def visit_FunctionDef(self, node):
    self.generic_visit(node)
    if contains_return(node):
      return
    raise ValueError(
        'Each function definition should contain at least one return.')
+
+
def transform(node, ctx):
  """Ensure a function has only a single return.

  This transforms an AST node with multiple returns successively into containing
  only a single return node.
  There are a few restrictions on what we can handle:
   - An AST being transformed must contain at least one return.
   - No returns allowed in loops. We have to know the type of the return value,
   and we currently don't have either a type inference system to discover it,
   nor do we have a mechanism for late type binding in TensorFlow.
   - After all transformations are finished, a Return node is not allowed inside
   control flow. If we were unable to move a return outside of control flow,
   this is an error.

  Args:
     node: ast.AST
     ctx: converter.EntityContext

  Returns:
     new_node: an AST with a single return value

  Raises:
    ValueError: if the AST is structured so that we can't perform the
   transform.
  """
  # Make sure that the function has at least one return statement
  # TODO(alexbw): turning off this assertion for now --
  # we need to not require this in e.g. class constructors.
  # DetectReturnInFunctionDef().visit(node)

  # Make sure there's no returns in unsupported locations (loops, try/except)
  DetectReturnInUnsupportedControlFlow().visit(node)

  # Fixed-point iteration: LiftReturn and FoldElse each enable the other to
  # make further progress, so they are alternated until neither changes the
  # AST.
  while True:

    # Try to lift all returns out of if statements and with blocks
    lr = LiftReturn(ctx)
    node = lr.visit(node)
    changes_made = lr.changes_made
    fe = FoldElse(ctx)
    node = fe.visit(node)
    changes_made = changes_made or fe.changes_made

    if not changes_made:
      break

  # Make sure we've scrubbed all returns from conditionals
  DetectReturnInConditional().visit(node)

  return node
diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
new file mode 100644
index 0000000000..01dd03da0b
--- /dev/null
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -0,0 +1,167 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for return_statements module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
class SingleReturnTest(converter_testing.TestCase):
  """Tests for the return_statements converter."""

  def assertTransformedEquivalent(self, test_fn, *inputs):
    """Asserts the converted test_fn matches the original on `inputs`."""
    ns = {'ops': ops}
    with self.converted(test_fn, return_statements, ns) as result:
      self.assertEqual(test_fn(*inputs), result.test_fn(*inputs))

  def test_straightline(self):

    def test_fn(x):
      return x * x

    self.assertTransformedEquivalent(test_fn, 2)

  def test_conditional(self):

    def test_fn(x):
      if x > 0:
        return x
      else:
        return x * x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  def test_missing_orelse(self):

    def test_fn(x):
      if x > 0:
        return x

    node, ctx = self.prepare(test_fn, {})
    with self.assertRaises(ValueError):
      return_statements.transform(node, ctx)

  # Renamed from the misspelled 'test_missing_orelse_recovrable'.
  def test_missing_orelse_recoverable(self):

    def test_fn(x):
      if x > 0:
        return x
      return x * x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  def test_missing_branch_return_recoverable(self):

    def test_fn(x):
      if x < 0:
        x *= x
      else:
        return x
      return x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  def test_conditional_nested(self):

    def test_fn(x):
      if x > 0:
        if x < 5:
          return x
        else:
          return x * x
      else:
        return x * x * x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)
    self.assertTransformedEquivalent(test_fn, 5)

  def test_context_manager(self):

    def test_fn(x):
      with ops.name_scope(''):
        return x * x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  def test_context_manager_in_conditional(self):

    def test_fn(x):
      if x > 0:
        with ops.name_scope(''):
          return x * x
      else:
        return x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  # Renamed from 'text_conditional_in_context_manager': the 'text_' prefix
  # prevented the test runner from discovering this test, so it never ran.
  def test_conditional_in_context_manager(self):

    def test_fn(x):
      with ops.name_scope(''):
        if x > 0:
          return x * x
        else:
          return x

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  def test_no_return(self):

    def test_fn(x):
      x *= x

    self.assertTransformedEquivalent(test_fn, 2)

  def test_nested_functions(self):

    def test_fn(x):

      def inner_fn(y):
        if y > 0:
          return y * y
        else:
          return y

      return inner_fn(x)

    self.assertTransformedEquivalent(test_fn, 2)
    self.assertTransformedEquivalent(test_fn, -2)

  def test_loop(self):

    def test_fn(x):
      for _ in range(10):
        return x
      return x

    node, ctx = self.prepare(test_fn, {})
    with self.assertRaises(ValueError):
      return_statements.transform(node, ctx)
+
+
# Standard test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
new file mode 100644
index 0000000000..6e48e57bde
--- /dev/null
+++ b/tensorflow/python/autograph/converters/side_effect_guards.py
@@ -0,0 +1,183 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adds guards against function calls with side effects.
+
+Only standalone calls are guarded.
+
+WARNING: This mechanism is incomplete. Particularly, it only guards the
+arguments passed to functions, and does not account for indirectly modified
+state.
+
+Example:
+ y = tf.layers.dense(x) # Creates TF variable 'foo'
+ loss = loss(y)
+ opt.minimize(loss) # indirectly affects 'foo'
+ z = tf.get_variable('foo') # Indirectly affects `loss` and 'foo'
+ # Here, `loss` can be guarded. But `z` cannot.
+
+# TODO(mdan): We should probably define a safe mode where we guard everything.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+
class SymbolNamer(object):
  """Interface expected of the namer used by SideEffectGuardTransformer."""

  def new_symbol(self, name_root, reserved_locals):
    """Generates a fresh, unique symbol name.

    Args:
      name_root: String, the stem from which the new name is derived.
      reserved_locals: Set(string), local symbol names that must be avoided.
    Returns:
      String, the newly generated unique name.
    Raises:
      NotImplementedError: always; concrete namers must override this.
    """
    raise NotImplementedError()
+
+
class SideEffectGuardTransformer(converter.Base):
  """Adds control dependencies to functions with side effects."""

  def _visit_and_reindent(self, nodes):
    """Visits a statement list, nesting trailing statements when requested.

    When a visited statement carries the INDENT_BLOCK_REMAINDER annotation,
    all subsequent statements are redirected into the destination list that
    the annotation supplies (the guard's body), and the alias map it carries
    is applied to them via rename_symbols.
    """
    new_nodes = []
    current_dest = new_nodes
    alias_map = {}
    reindent_requested = False
    for n in nodes:
      n = self.visit(n)
      # NOTE: the order in which these statements execute is important; in
      # particular, watch out for ending up with cycles in the AST.
      if alias_map:
        n = ast_util.rename_symbols(n, alias_map)
      if isinstance(n, (list, tuple)):
        current_dest.extend(n)
      else:
        current_dest.append(n)
      if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
        reindent_requested = True
        new_dest, new_alias_map = anno.getanno(
            n, anno.Basic.INDENT_BLOCK_REMAINDER)
        anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
        new_alias_map.update(alias_map)
        alias_map = new_alias_map
        current_dest = new_dest
    if reindent_requested and not current_dest:
      # A guard was created but nothing follows it to be gated.
      # TODO(mdan): There may still be something that could be done.
      raise ValueError('Unable to insert statement into the computation flow: '
                       'it is not followed by any computation which '
                       'the statement could gate.')
    return new_nodes

  # The visitors below apply _visit_and_reindent to every statement list in
  # which a guarded call may appear.

  def visit_FunctionDef(self, node):
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_With(self, node):
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_If(self, node):
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_While(self, node):
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_Expr(self, node):
    """Wraps standalone call expressions in a control-dependency guard."""
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      # Patterns of single function calls, like:
      #   opt.minimize(loss)
      # or:
      #   tf.py_func(...)

      # First, attempt to gate future evaluation of args. If that's not
      # possible, gate all remaining statements (and that may fail too, see
      # _visit_and_reindent).
      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
      # NOTE: We can't guard object attributes because they may not be writable.
      # In addition, avoid renaming well-known names.
      # TODO(mdan): Move these names into config.
      unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
      guarded_args = tuple(s for s in args_scope.used
                           if not s.is_composite() and s not in unguarded_names)

      # TODO(mdan): Include all arguments which depended on guarded_args too.
      # For example, the following will still cause a race:
      #   tf.assign(a, a + 1)
      #   b = a + 1
      #   tf.assign(a, a + 1)  # Control deps here should include `b`
      #   c = b + 1
      # Or maybe we should just raise an "unsafe assign" error?

      if guarded_args:
        # The aliases may need new names to avoid incorrectly making them local.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        need_alias = tuple(
            s for s in guarded_args if s not in args_scope.parent.modified)
        aliased_new_names = tuple(
            qual_names.QN(
                self.ctx.namer.new_symbol(
                    s.ssf(), args_scope.parent.referenced)) for s in need_alias)
        alias_map = dict(zip(need_alias, aliased_new_names))
        if len(guarded_args) == 1:
          s, = guarded_args
          aliased_guarded_args = alias_map.get(s, s)
        else:
          aliased_guarded_args = gast.Tuple(
              [alias_map.get(s, s).ast() for s in guarded_args], None)

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            aliased_guarded_args = ag__.utils.alias_tensors(guarded_args)
        """
        control_deps_guard = templates.replace(
            template,
            call=node.value,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]
      else:
        alias_map = {}

        template = """
          with ag__.utils.control_dependency_on_returns(call):
            pass
        """
        control_deps_guard = templates.replace(template, call=node.value)[-1]
        control_deps_guard.body = []

      node = control_deps_guard
      # Request that _visit_and_reindent move the remaining statements into
      # the guard's body, applying the computed aliases.
      anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                   (node.body, alias_map))
    return node
+
+
def transform(node, ctx):
  """Adds control-dependency guards around standalone side-effecting calls.

  Args:
    node: ast.AST
    ctx: converter.EntityContext
  Returns:
    The transformed AST.
  """
  return SideEffectGuardTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/side_effect_guards_test.py b/tensorflow/python/autograph/converters/side_effect_guards_test.py
new file mode 100644
index 0000000000..cef3199169
--- /dev/null
+++ b/tensorflow/python/autograph/converters/side_effect_guards_test.py
@@ -0,0 +1,163 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for side_effect_guards module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+tf = None # Will be replaced by a mock.
+
+
class SideEffectGuardsTest(converter_testing.TestCase):
  """Tests for the side_effect_guards converter."""

  def test_side_effect_on_return_only_variable(self):

    def test_fn(a):
      tf.assign(a, a + 1)
      return a

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    # The guarded call and the return collapse into a single statement.
    self.assertEqual(len(node.body), 1)

    with self.compiled(node, {}, state_ops.assign) as result:
      with self.cached_session() as sess:
        v = variable_scope.get_variable('test', initializer=2)
        sess.run(v.initializer)
        sess.run(result.test_fn(v))
        # TODO(mdan): Add support for this use case.
        # Right now the variable `a` is not conditioned on the `assign` because
        # there's no way to add control dependencies to a variable object.
        self.assertEqual(2, sess.run(v))

  def test_side_effect_on_used_variable(self):

    def test_fn(a):
      tf.assign(a, a + 1)
      return a + 1

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    self.assertEqual(len(node.body), 1)

    with self.compiled(node, {}, state_ops.assign) as result:
      with self.cached_session() as sess:
        v = variable_scope.get_variable('test', initializer=2)
        sess.run(v.initializer)
        sess.run(result.test_fn(v))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        # Right now it's 3 or 4 based on whether the read is synchronized.
        self.assertEqual(3, sess.run(v))

  def test_side_effect_on_tensor(self):

    def test_fn(a):
      tf.Assert(a > 0, ['expected in throw'])
      return a

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    self.assertEqual(len(node.body), 1)

    with self.compiled(node, {}, control_flow_ops.Assert) as result:
      with self.cached_session() as sess:
        # The guard forces the Assert to run before the value is returned.
        with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
                                     'expected in throw'):
          sess.run(result.test_fn(constant_op.constant(-1)))

  def test_multiline_block(self):

    def test_fn(a):
      tf.assign_add(a, 1)
      b = a + 1
      tf.assign_add(a, 1)
      b += 1
      return b

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    self.assertEqual(len(node.body), 1)

    with self.compiled(node, {}, state_ops.assign_add) as result:
      with self.cached_session() as sess:
        v = variable_scope.get_variable('test', initializer=2)
        sess.run(v.initializer)
        sess.run(result.test_fn(v))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        self.assertEqual(4, sess.run(v))

  def test_multiline_nested_block(self):

    def test_fn(a):
      with tf.name_scope('foo'):
        tf.assign(a, a + 1)
        b = a + 1
      return b

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    # The guard is applied within the with block's body.
    self.assertEqual(len(node.body[0].body), 1)

    with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
      with self.cached_session() as sess:
        v = variable_scope.get_variable('test', initializer=2)
        sess.run(v.initializer)
        sess.run(result.test_fn(v))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        self.assertEqual(3, sess.run(v))

  def test_multiline_block_unsafe(self):

    def test_fn(a):
      tf.assign(a, a + 1)
      b = a + 1
      tf.assign_add(a, 1)
      c = b + 1
      return c

    node, ctx = self.prepare(test_fn, {})
    node = side_effect_guards.transform(node, ctx)

    self.assertEqual(len(node.body), 1)

    with self.compiled(node, {}, state_ops.assign,
                       state_ops.assign_add) as result:
      with self.cached_session() as sess:
        v = variable_scope.get_variable('test', initializer=2)
        sess.run(v.initializer)
        sess.run(result.test_fn(v))
        # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
        self.assertEqual(4, sess.run(v))
+
+
# Standard test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/converters/slices.py b/tensorflow/python/autograph/converters/slices.py
new file mode 100644
index 0000000000..11cea6de5b
--- /dev/null
+++ b/tensorflow/python/autograph/converters/slices.py
@@ -0,0 +1,85 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for slice operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import templates
+
+
class SliceTransformer(converter.Base):
  """Converts slicing operations to their TF counterpart.

  Currently, relying on the default slice operator that Tensor uses is
  insufficient, because TensorArray and tensor lists use dedicated index read
  and write functions.
  """

  def _process_single_assignment(self, target, value):
    # Rewrites index assignments of the form `target[key] = value` into
    # ag__.set_item calls; returns None for any other assignment target,
    # leaving it untouched.
    if not isinstance(target, gast.Subscript):
      return None
    if not isinstance(target.slice, gast.Index):
      return None

    template = """
      target = ag__.set_item(target, key, item)
    """
    return templates.replace(
        template, target=target.value, key=target.slice.value, item=value)

  def visit_Assign(self, node):
    node = self.generic_visit(node)
    # TODO(mdan): Support unpackings and multiple assignments.
    if len(node.targets) != 1:
      raise NotImplementedError('multiple assignment')
    replacement = self._process_single_assignment(node.targets[0], node.value)
    if replacement is not None:
      return replacement
    return node

  def visit_Subscript(self, node):
    node = self.generic_visit(node)
    # Only plain index reads (e.g. l[i]) are rewritten; range slices pass
    # through unchanged.
    if not isinstance(node.slice, gast.Index):
      return node

    if not isinstance(node.ctx, gast.Load):
      # Index writes are handled at a higher level, one at which the rvalue is
      # also available.
      return node

    # The element dtype comes from a set_element_type directive attached to
    # the target's definition, when one was given.
    dtype = self.get_definition_directive(
        node.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))

    template = """
      ag__.get_item(
          target,
          key,
          opts=ag__.GetItemOpts(element_dtype=dtype))
    """
    return templates.replace_as_expression(
        template, target=node.value, key=node.slice.value, dtype=dtype)
+
+
def transform(node, ctx):
  """Converts index reads/writes into ag__.get_item/set_item calls.

  Args:
    node: ast.AST
    ctx: converter.EntityContext
  Returns:
    The transformed AST.
  """
  return SliceTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/slices_test.py b/tensorflow/python/autograph/converters/slices_test.py
new file mode 100644
index 0000000000..e190a7cfe8
--- /dev/null
+++ b/tensorflow/python/autograph/converters/slices_test.py
@@ -0,0 +1,76 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.lang import directives
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
class SliceTest(converter_testing.TestCase):
  """Tests for the slices converter."""

  def test_index_access(self):

    def test_fn(l):
      return l[1]

    node, ctx = self.prepare(test_fn, {})
    # Attach an element dtype directive to the parameter's definition; the
    # converter reads it to build the GetItemOpts.
    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32')
    }
    node = slices.transform(node, ctx)

    with self.compiled(node, {}, dtypes.int32) as result:
      with self.cached_session() as sess:
        tl = list_ops.tensor_list_from_tensor(
            [1, 2], element_shape=constant_op.constant([], dtype=dtypes.int32))
        y = result.test_fn(tl)
        self.assertEqual(2, sess.run(y))

  def test_index_access_multiple_definitions(self):

    def test_fn(l):
      if l:
        l = []
      return l[1]

    node, ctx = self.prepare(test_fn, {})
    def_, = anno.getanno(node.args.args[0], anno.Static.DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.int32')
    }
    # A second, conflicting dtype directive on the reassignment should make
    # the transform fail.
    def_, = anno.getanno(node.body[0].body[0].targets[0],
                         anno.Static.DEFINITIONS)
    def_.directives[directives.set_element_type] = {
        'dtype': parser.parse_expression('tf.float32')
    }
    with self.assertRaises(transformer.AutographParseError):
      slices.transform(node, ctx)
+
+
# Standard test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
new file mode 100644
index 0000000000..85fecf084d
--- /dev/null
+++ b/tensorflow/python/autograph/core/BUILD
@@ -0,0 +1,75 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "core",
+ srcs = [
+ "config.py",
+ "converter.py",
+ "errors.py",
+ "naming.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
+ ],
+)
+
+py_test(
+ name = "errors_test",
+ srcs = ["errors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ ],
+)
+
+py_test(
+ name = "naming_test",
+ srcs = ["naming_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_library(
+ name = "test_lib",
+ srcs = [
+ "converter_testing.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/python/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
new file mode 100644
index 0000000000..4fa8489af5
--- /dev/null
+++ b/tensorflow/python/autograph/core/config.py
@@ -0,0 +1,49 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Global configuration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph import utils
+
+
+PYTHON_LITERALS = {
+ 'None': None,
+ 'False': False,
+ 'True': True,
+ 'float': float,
+}
+
+DEFAULT_UNCOMPILED_MODULES = set((
+ ('tensorflow',),
+ (utils.__name__,),
+
+ # All of tensorflow's subpackages. Unlike the root tf module, they don't
+ # have well-known names. Not referring to the module directly to avoid
+ # circular imports.
+ (
+ utils.__name__[:-len('.python.autograph.utils')],),
+))
+
+NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
+
+# TODO(mdan): Also allow controlling the generated names.
+# TODO(mdan): Consolidate all internal imports into a single __ag module.
+COMPILED_IMPORT_STATEMENTS = (
+ 'from __future__ import print_function',
+ 'import tensorflow as tf',
+)
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
new file mode 100644
index 0000000000..7b3905fdee
--- /dev/null
+++ b/tensorflow/python/autograph/core/converter.py
@@ -0,0 +1,330 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter construction support.
+
+This module contains a base class for all converters, as well as supporting
+structures. These structures are referred to as contexts.
+
+The class hierarchy is as follows:
+
+  <your converter>
+    [extends] converter.Base
+      [extends] transformer.Base
+        [extends] gast.NodeTransformer
+      [uses] transformer.SourceInfo
+    [uses] converter.EntityContext
+      [uses] converter.ProgramContext
+        [uses] transformer.SourceInfo
+
+converter.Base is a specialization of transformer.Base for AutoGraph. It's a
+very lightweight subclass that adds a `ctx` attribute holding the corresponding
+EntityContext object (see below). Note that converters are not reusable, and
+`visit` will raise an error if called more than once.
+
+converter.EntityContext contains mutable state associated with an entity that
+the converter processes.
+
+converter.ProgramContext contains mutable state across related entities. For
+example, when converting several functions that call one another, the
+ProgramContext should be shared across these entities.
+
+Below is the overall flow at conversion:
+
+ program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
+ while <program_ctx has more entities to convert>:
+ entity, source_info = <get next entity from program_ctx>
+ entity_ctx = EntityContext(program_ctx, source_info)
+ for <each ConverterClass>:
+ converter = ConverterClass(entity_ctx)
+
+ # May update entity_ctx and program_ctx
+ entity = converter.visit(entity)
+
+ <add entity's dependencies to program_ctx>
+
+Note that pyct contains a small number of transformers used for static analysis.
+These implement transformer.Base, rather than converter.Base, to avoid a
+dependency on AutoGraph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from enum import Enum
+
+
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import naming
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import liveness
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
+
+# TODO(mdan): These contexts can be refactored into first class objects.
+# For example, we could define Program and Entity abstractions that hold on
+# to the actual entity and have conversion methods.
+
+# TODO(mdan): Add a test specific to this converter.
+
+
+class ProgramContext(object):
+ """ProgramContext keeps track of converting function hierarchies.
+
+ This object is mutable, and is updated during conversion. Not thread safe.
+
+ Attributes:
+ recursive: bool, whether to recursively convert any functions that the
+ decorator function may call.
+ autograph_decorators: Tuple[Callable, ...], decorator functions that belong
+ to AutoGraph. These require special treatment.
+ dependency_cache: Dict[Any, ast.AST], the original entities mapped to their
+ converted AST
+ additional_imports: Set[Any], additional entities which for any reason
+ cannot be attached after loading and need to be explicitly imported
+ in the generated code
+ name_map: Dict[str, str], map of original entity name to the name of
+ their converted counterparts
+ autograph_module: Module, a reference to the autograph module. This
+ needs to be specified by the caller to avoid circular dependencies.
+ uncompiled_modules: Set[Tuple[str, ...]], with each tuple representing the
+ fully qualified name of a package containing functions that will not be
+ compiled.
+ required_imports: str, containing an import statement on each line. These
+ are all the imports necessary for the compiled code to run, in addition
+ to the closures of each entity, which are attached dynamically.
+ """
+
+ def __init__(
+ self,
+ recursive,
+ autograph_decorators,
+ partial_types,
+ autograph_module,
+ uncompiled_modules,
+ ):
+ self.recursive = recursive
+ self.autograph_decorators = autograph_decorators
+ self.partial_types = partial_types if partial_types else ()
+ self.autograph_module = autograph_module
+ self.uncompiled_modules = uncompiled_modules
+
+ # Required to output dependencies in discovery order, which should match
+ # the reverse dependency order.
+ self.dependency_cache = collections.OrderedDict()
+ self.additional_imports = set()
+ self.name_map = {}
+
+ @property
+ def required_imports(self):
+ """Returns a block containing all imports required by the converted code."""
+ # TODO(mdan): Check that these don't clobber one another.
+ return '\n'.join(config.COMPILED_IMPORT_STATEMENTS +
+ tuple(self.additional_imports))
+
+ def new_namer(self, namespace):
+ return naming.Namer(namespace, self.recursive, self.name_map,
+ self.partial_types)
+
+ def update_name_map(self, namer):
+ """Updates renamed_calls based on the recent activity from the namer.
+
+ Whenever we convert a new entity, any references to other entities are being
+ renamed to match their soon-to-be-converted counterparts. The namer keeps
+ track of these renames. When conversion is complete, we copy those renames
+ so that when those referenced entities are being converted, their new name
+ matches.
+
+ Args:
+ namer: naming.Namer
+
+ Raises:
+ ValueError: when an entity was renamed twice and to different names.
+ """
+ # TODO(mdan): Have call_trees do this directly.
+ # This is done so indirectly, via the namer, for historic reasons. But
+ # now we can have the converter that does the rename record the new name
+ # as well and skip this step altogether.
+ for o, name in namer.renamed_calls.items():
+ if o in self.name_map:
+ if self.name_map[o] != name:
+ raise ValueError(
+ 'Calls to %s were converted using multiple names (%s). This is '
+ 'possible when an entity with one of these names already '
+ 'existed. To fix, avoid using any of these names.' %
+ (o, (name, self.name_map[o])))
+ else:
+ self.name_map[o] = name
+
+ def add_to_cache(self, original_entity, converted_ast):
+ self.dependency_cache[original_entity] = converted_ast
+
+
+class EntityContext(object):
+ """Tracks the conversion of a single entity.
+
+ This object is mutable, and is updated during conversion. Not thread safe.
+
+ Attributes:
+ namer: Namer
+ info: transformer.EntityInfo
+ program: ProgramContext
+ """
+
+ def __init__(self, namer, entity_info, program_ctx):
+ self.namer = namer
+ self.info = entity_info
+ self.program = program_ctx
+
+
+class Base(transformer.Base):
+ """All converters should inherit from this class.
+
+ Attributes:
+ ctx: EntityContext
+ """
+
+ def __init__(self, ctx):
+ super(Base, self).__init__(ctx.info)
+ self.ctx = ctx # Keeping this short because it's used frequently.
+
+ self._used = False
+ self._ast_depth = 0
+
+ def get_definition_directive(self, node, directive, arg, default):
+ """Returns the unique directive for a symbol, or a default if none exist.
+
+ See lang/directives.py for details on directives.
+
+ Args:
+ node: ast.AST
+ directive: Callable[..., Any]
+ arg: str
+ default: Any
+
+ Raises:
+ ValueError: if conflicting annotations have been found
+ """
+ defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
+ if not defs:
+ return default
+
+ # TODO(mdan): Simplify this.
+ arg_values = []
+ for def_ in defs:
+ if (directive not in def_.directives or
+ arg not in def_.directives[directive]):
+ continue
+ arg_value = def_.directives[directive][arg]
+ for prev_value in arg_values:
+ if not ast_util.matches(arg_value, prev_value):
+ qn = anno.getanno(node, anno.Basic.QN)
+ raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
+ (qn, directive.__name__, arg,
+ compiler.ast_to_source(arg_value).strip(),
+ compiler.ast_to_source(prev_value).strip()))
+ arg_values.append(arg_value)
+
+ if not arg_values:
+ return default
+
+ arg_value, = arg_values
+ return arg_value
+
+ def visit(self, node):
+ if not self._ast_depth:
+ if self._used:
+ raise ValueError('converter objects cannot be reused')
+ self._used = True
+
+ self._ast_depth += 1
+ try:
+ return super(Base, self).visit(node)
+ finally:
+ self._ast_depth -= 1
+
+
+class AnnotatedDef(reaching_definitions.Definition):
+
+ def __init__(self):
+ super(AnnotatedDef, self).__init__()
+ self.directives = {}
+
+
+class AgAnno(Enum):
+ """Annotation labels specific to AutoGraph. See anno.py."""
+
+ DIRECTIVES = 'User directives associated with the annotated statement.'
+
+ def __repr__(self):
+ return self.name
+
+
+def standard_analysis(node, context, is_initial=False):
+ """Performs a complete static analysis of the given code.
+
+ Args:
+ node: ast.AST
+ context: converter.EntityContext
+ is_initial: bool, whether this is the initial analysis done on the input
+ source code
+
+ Returns:
+ ast.AST, same as node, with the static analysis annotations added
+ """
+ # TODO(mdan): Clear static analysis here.
+ # TODO(mdan): Consider not running all analyses every time.
+ # TODO(mdan): Don't return a node because it's modified by reference.
+ graphs = cfg.build(node)
+ node = qual_names.resolve(node)
+ node = activity.resolve(node, context.info, None)
+ node = reaching_definitions.resolve(node, context.info, graphs, AnnotatedDef)
+ node = liveness.resolve(node, context.info, graphs)
+ node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
+ node = type_info.resolve(node, context.info)
+ # This second call allows resolving first-order class attributes.
+ node = live_values.resolve(node, context.info, config.PYTHON_LITERALS)
+ if is_initial:
+ anno.dup(
+ node,
+ {
+ anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS,
+ },
+ )
+ return node
+
+
+def apply_(node, context, converter_module):
+ """Applies a converter to an AST.
+
+ Args:
+ node: ast.AST
+ context: converter.EntityContext
+ converter_module: converter.Base
+
+ Returns:
+ ast.AST, the result of applying converter to node
+ """
+ node = standard_analysis(node, context)
+ node = converter_module.transform(node, context)
+ return node
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
new file mode 100644
index 0000000000..0a0c6f9002
--- /dev/null
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -0,0 +1,166 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for tests in this module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import imp
+import sys
+
+import six
+
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.platform import test
+
+
+def imported_decorator(f):
+ return lambda a: f(a) + 1
+
+
+# TODO(mdan): We might be able to use the real namer here.
+class FakeNamer(object):
+ """A fake namer that uses a global counter to generate unique names."""
+
+ def __init__(self):
+ self.i = 0
+
+ def new_symbol(self, name_root, used):
+ while True:
+ self.i += 1
+ name = '%s%d' % (name_root, self.i)
+ if name not in used:
+ return name
+
+ def compiled_function_name(self,
+ original_fqn,
+ live_entity=None,
+ owner_type=None):
+ del live_entity
+ if owner_type is not None:
+ return None, False
+ return ('renamed_%s' % '_'.join(original_fqn)), True
+
+
+class FakeNoRenameNamer(FakeNamer):
+
+ def compiled_function_name(self, original_fqn, **_):
+ return str(original_fqn), False
+
+
+class TestCase(test.TestCase):
+ """Base class for unit tests in this module. Contains relevant utilities."""
+
+ @contextlib.contextmanager
+ def assertPrints(self, expected_result):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ yield
+ self.assertEqual(out_capturer.getvalue(), expected_result)
+ finally:
+ sys.stdout = sys.__stdout__
+
+ @contextlib.contextmanager
+ def compiled(self, node, namespace, *symbols):
+ source = None
+
+ self.dynamic_calls = []
+ def converted_call(*args):
+ """Mock version of api.converted_call."""
+ self.dynamic_calls.append(args)
+ return 7
+
+ try:
+ result, source = compiler.ast_to_object(node, include_source_map=True)
+
+ result.tf = self.make_fake_mod('fake_tf', *symbols)
+ fake_ag = self.make_fake_mod('fake_ag', converted_call)
+ fake_ag.__dict__.update(operators.__dict__)
+ fake_ag.__dict__['utils'] = utils
+ fake_ag.__dict__['rewrite_graph_construction_error'] = (
+ errors.rewrite_graph_construction_error)
+ result.__dict__['ag__'] = fake_ag
+ for k, v in namespace.items():
+ result.__dict__[k] = v
+ yield result
+ except Exception: # pylint:disable=broad-except
+ if source is None:
+ print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False))
+ else:
+ print('Offending compiled code:\n%s' % source)
+ raise
+
+ @contextlib.contextmanager
+ def converted(self, entity, converter_module, namespace, *tf_symbols):
+ node, ctx = self.prepare(entity, namespace)
+ node = converter_module.transform(node, ctx)
+ with self.compiled(node, namespace, *tf_symbols) as result:
+ yield result
+
+ def make_fake_mod(self, name, *symbols):
+ fake_mod = imp.new_module(name)
+ for s in symbols:
+ if hasattr(s, '__name__'):
+ setattr(fake_mod, s.__name__, s)
+ elif hasattr(s, 'name'):
+ # This is a bit of a hack, but works for things like tf.int32
+ setattr(fake_mod, s.name, s)
+ else:
+ raise ValueError('can not attach %s - what should be its name?' % s)
+ return fake_mod
+
+ def attach_namespace(self, module, **ns):
+ for k, v in ns.items():
+ setattr(module, k, v)
+
+ def prepare(self,
+ test_fn,
+ namespace,
+ namer=None,
+ arg_types=None,
+ owner_type=None,
+ recursive=True,
+ autograph_decorators=()):
+ node, source = parser.parse_entity(test_fn)
+ node = node.body[0]
+ if namer is None:
+ namer = FakeNamer()
+ program_ctx = converter.ProgramContext(
+ recursive=recursive,
+ autograph_decorators=autograph_decorators,
+ partial_types=None,
+ autograph_module=None,
+ uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file='<fragment>',
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ owner_type=owner_type)
+ ctx = converter.EntityContext(namer, entity_info, program_ctx)
+ node = converter.standard_analysis(node, ctx, is_initial=True)
+ return node, ctx
diff --git a/tensorflow/python/autograph/core/errors.py b/tensorflow/python/autograph/core/errors.py
new file mode 100644
index 0000000000..0750353423
--- /dev/null
+++ b/tensorflow/python/autograph/core/errors.py
@@ -0,0 +1,258 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Error rewriting logic.
+
+Contains the functions responsible for rewriting tracebacks of errors raised
+in AutoGraph (AG) code to refer to user written code, so that errors only refer
+to the original user code.
+
+When 'user code' is used in comments it refers to the original source code that
+the user wrote and is converting using AutoGraph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import logging
+import sys
+import traceback
+
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.framework import errors_impl
+
+# TODO(mdan): Add a superclass common to all errors.
+
+
+class GraphConstructionError(Exception):
+ """Error for graph construction errors from AutoGraph generated code."""
+
+ def __init__(self, original_error, custom_traceback):
+ self.original_error = original_error
+ self.custom_traceback = custom_traceback
+ super(GraphConstructionError, self).__init__()
+
+ def __str__(self):
+ traceback_str = ''.join(traceback.format_list(self.custom_traceback))
+ return ('Traceback (most recent call last):\n' + traceback_str + '\n' + str(
+ self.original_error) + '\n')
+
+
+class TfRuntimeError(Exception):
+ """Error wrapper for runtime errors raised by AutoGraph generated code."""
+
+ def __init__(self, op_name, op_message, custom_traceback):
+ self.op_name = op_name
+ self.op_message = op_message
+ self.custom_traceback = custom_traceback
+ super(TfRuntimeError, self).__init__()
+
+ def __str__(self):
+ message = '%s\n\nCaused by op %r, defined at:\n' % (self.op_message,
+ self.op_name)
+ return message + ''.join(traceback.format_list(self.custom_traceback))
+
+
+def _rewrite_tb(source_map, tb):
+ """Rewrites code references in a traceback.
+
+ Args:
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
+ locations to their origin
+ tb: List[Tuple[Text, Text, Text, Text]], consistent with
+ traceback.extract_tb.
+ Returns:
+ List[Tuple[Text, Text, Text, Text]], the rewritten traceback
+ """
+ new_tb = []
+ for frame in tb:
+ filename, lineno, _, _ = frame
+ loc = origin_info.LineLocation(filename, lineno)
+ origin = source_map.get(loc)
+ if origin is not None:
+ new_tb.append(origin.as_frame())
+ else:
+ new_tb.append(frame)
+ return new_tb
+
+
+# TODO(mdan): rename to raise_*
+def rewrite_graph_construction_error(source_map):
+ """Rewrites errors raised by non-AG APIs inside AG generated code.
+
+ This is called from the except handler inside an AutoGraph generated function
+ (that is, during exception handling). Only rewrites the frames corresponding
+ to the function that this is called from, so each function is responsible
+ to call this to have its own frames rewritten.
+
+ This function always raises an error.
+
+ Args:
+ source_map: Dict[origin_info.Location, origin_info.OriginInfo], the source
+ map belonging to the calling function
+
+ Raises:
+ GraphConstructionError: The rewritten underlying error.
+ Exception: The underlying error, if it could not be rewritten.
+ """
+ error_info = sys.exc_info()
+ _, original_error, e_traceback = error_info
+ assert original_error is not None
+ try:
+ current_traceback = _cut_traceback_loops(source_map,
+ traceback.extract_tb(e_traceback))
+ if isinstance(original_error, GraphConstructionError):
+ # TODO(mdan): This is incomplete.
+ # The error might have bubbled through a non-converted function.
+ previous_traceback = original_error.custom_traceback
+ cleaned_traceback = [current_traceback[0]] + previous_traceback
+ else:
+ cleaned_traceback = current_traceback
+
+ cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
+
+ if isinstance(original_error, GraphConstructionError):
+ original_error.custom_traceback = cleaned_traceback
+ new_error = original_error
+ else:
+ new_error = GraphConstructionError(original_error, cleaned_traceback)
+ except Exception:
+ logging.exception('Error while rewriting AutoGraph error:')
+ # TODO(mdan): Should reraise here, removing the top frame as well.
+ raise original_error
+ else:
+ raise new_error
+ finally:
+ # Addresses warning https://docs.python.org/2/library/sys.html#sys.exc_info.
+ del e_traceback
+
+
+def _cut_traceback_loops(source_map, original_traceback):
+ """Check for cases where we leave a user method and re-enter it.
+
+  This is done by looking at the function names of frames whose filenames
+  belong to user code files. If we find a case where we return to a user method
+ after leaving it then we cut out the frames in between because we assume this
+ means these in between frames are from internal AutoGraph code that shouldn't
+ be included.
+
+ An example of this is:
+
+ File "file1.py", line 57, in my_func
+ ...
+ File "control_flow_ops.py", line 231, in cond
+ ...
+ File "control_flow_ops.py", line 1039, in inner_cond
+ ...
+ File "file1.py", line 68, in my_func
+ ...
+
+ Where we would remove the control_flow_ops.py frames because we re-enter
+ my_func in file1.py.
+
+ The source map keys are (file_path, line_number) so get the set of all user
+ file_paths.
+
+ Args:
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo], mapping
+ locations to their origin
+ original_traceback: List[Tuple[Text, Text, Text, Text]], consistent with
+ traceback.extract_tb.
+
+ Returns:
+ List[Tuple[Text, Text, Text, Text]], the traceback with any loops removed.
+ """
+ all_user_files = set(loc.filename for loc in source_map)
+ cleaned_traceback = []
+ last_user_frame_index = None
+ last_user_user_file_path = None
+ # TODO(mdan): Simplify this logic.
+ for fi, frame in enumerate(original_traceback):
+ frame_file_path, lineno, _, _ = frame
+ src_map_key = origin_info.LineLocation(frame_file_path, lineno)
+ if frame_file_path in all_user_files:
+ if src_map_key in source_map:
+ if (last_user_frame_index is not None and
+ last_user_user_file_path == frame_file_path):
+ cleaned_traceback = cleaned_traceback[:last_user_frame_index]
+ last_user_frame_index = fi
+ last_user_user_file_path = frame_file_path
+ cleaned_traceback.append(frame)
+ return cleaned_traceback
+
+
+# TODO(mdan): This should be consistent with rewrite_graph_construction_error
+# Both should either raise or return.
+def rewrite_tf_runtime_error(error, source_map):
+ """Rewrites TensorFlow runtime errors raised by ops created in AG code.
+
+ Args:
+ error: tf.OpError
+ source_map: Dict[origin_info.LineLocation, origin_info.OriginInfo]
+
+ Returns:
+ TfRuntimeError, the rewritten underlying error.
+ """
+ try:
+ cleaned_traceback = _cut_traceback_loops(source_map, error.op.traceback)
+ # cleaned_traceback = error.op.traceback
+ cleaned_traceback = _rewrite_tb(source_map, cleaned_traceback)
+
+ op_name = error.op.name
+ op_message = error.message
+ rewritten_error = TfRuntimeError(op_name, op_message, cleaned_traceback)
+ return rewritten_error
+ except Exception: # pylint: disable=broad-except
+ logging.exception('Error while rewriting AutoGraph error:')
+ return error
+
+
+# TODO(znado): Add arg to enable different levels of error rewriting.
+@contextlib.contextmanager
+def improved_errors(converted_function):
+ """Context manager that rewrites runtime errors.
+
+ This context manager will rewrite runtime errors so that their traceback
+ is relative to the original code before conversion.
+
+ Use with the output of to_graph, and wrap the execution of respective ops.
+ Example:
+
+ converted_my_func = ag.to_graph(my_func)
+ ops = converted_my_func(...)
+
+ with ag.improved_errors(converted_my_func):
+ sess.run(ops)
+
+ Args:
+ converted_function: Callable[..., Any], the output of a to_graph call
+
+ Yields:
+ None
+
+ Raises:
+ TfRuntimeError: if any OpError originates in the converted code, it will
+ be wrapped into a TfRuntimeError
+ ValueError: If converted_function is not generated by AutoGraph
+ """
+ if (getattr(converted_function, 'ag_source_map', None) is None or
+ not isinstance(converted_function.ag_source_map, dict)):
+ raise ValueError(
+ 'converted_function must be the result of an autograph.to_graph call')
+ try:
+ yield
+ except errors_impl.OpError as e:
+ raise rewrite_tf_runtime_error(e, converted_function.ag_source_map)
diff --git a/tensorflow/python/autograph/core/errors_test.py b/tensorflow/python/autograph/core/errors_test.py
new file mode 100644
index 0000000000..0444ed7eab
--- /dev/null
+++ b/tensorflow/python/autograph/core/errors_test.py
@@ -0,0 +1,105 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for errors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors as tf_errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
+
+
+def zero_div():
+ x = array_ops.constant(10, dtype=dtypes.int32)
+ return x // 0
+
+
+def zero_div_caller():
+ return zero_div()
+
+
+class RuntimeErrorsTest(test.TestCase):
+
+ def fake_origin(self, function, line_offset):
+ _, lineno = tf_inspect.getsourcelines(function)
+ filename = tf_inspect.getsourcefile(function)
+ lineno += line_offset
+ loc = origin_info.LineLocation(filename, lineno)
+ origin = origin_info.OriginInfo(loc, 'test_function_name', 'test_code',
+ 'test_comment')
+ return loc, origin
+
+ def test_improved_errors_basic(self):
+ loc, origin = self.fake_origin(zero_div, 2)
+ zero_div_caller.ag_source_map = {loc: origin}
+
+ ops = zero_div_caller()
+ with self.assertRaises(errors.TfRuntimeError) as cm:
+ with errors.improved_errors(zero_div_caller):
+ with self.test_session() as sess:
+ sess.run(ops)
+
+ for frame in cm.exception.custom_traceback:
+ _, _, function_name, _ = frame
+ self.assertNotEqual('zero_div', function_name)
+ self.assertIn(origin.as_frame(), set(cm.exception.custom_traceback))
+
+ def test_improved_errors_no_matching_lineno(self):
+ loc, origin = self.fake_origin(zero_div, -1)
+ zero_div_caller.ag_source_map = {loc: origin}
+
+ ops = zero_div_caller()
+ with self.assertRaises(errors.TfRuntimeError) as cm:
+ with errors.improved_errors(zero_div_caller):
+ with self.test_session() as sess:
+ sess.run(ops)
+
+ all_function_names = set()
+ for frame in cm.exception.custom_traceback:
+ _, _, function_name, _ = frame
+ all_function_names.add(function_name)
+ self.assertNotEqual('test_function_name', function_name)
+ self.assertIn('zero_div', all_function_names)
+
+ def test_improved_errors_failures(self):
+ loc, _ = self.fake_origin(zero_div, 2)
+ zero_div_caller.ag_source_map = {loc: 'bogus object'}
+
+ ops = zero_div_caller()
+ with self.assertRaises(tf_errors.InvalidArgumentError):
+ with errors.improved_errors(zero_div_caller):
+ with self.test_session() as sess:
+ sess.run(ops)
+
+ def test_improved_errors_validation(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'converted_function must be the result of an autograph.to_graph call'):
+ errors.improved_errors(zero_div).__enter__()
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'converted_function must be the result of an autograph.to_graph call'):
+ zero_div_caller.ag_source_map = 'not a dict'
+ errors.improved_errors(zero_div_caller).__enter__()
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/core/naming.py b/tensorflow/python/autograph/core/naming.py
new file mode 100644
index 0000000000..aecc9e33ca
--- /dev/null
+++ b/tensorflow/python/autograph/core/naming.py
@@ -0,0 +1,130 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Symbol naming utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import qual_names
+
+
+class Namer(object):
+  """Implementation of the namer interfaces required by various converters.
+
+  This implementation performs additional tasks like keeping track of the
+  function calls that have been encountered and replaced with calls to their
+  corresponding compiled counterparts.
+
+  Interfaces currently implemented:
+    * call_trees.FunctionNamer
+    * control_flow.SymbolNamer
+    * side_effect_guards.SymbolNamer
+  """
+
+  def __init__(self, global_namespace, recursive, name_map, partial_types):
+    # Names visible to the generated code; consulted to avoid collisions.
+    self.global_namespace = global_namespace
+    # Whether called functions are converted (and hence renamed) as well.
+    self.recursive = recursive
+    # Types whose members are converted individually rather than as a class.
+    self.partial_types = partial_types
+
+    # Maps live entities (functions/classes) to their compiled names so that
+    # repeated requests for the same entity return a consistent name.
+    self.renamed_calls = {}
+    if name_map is not None:
+      self.renamed_calls.update(name_map)
+
+    # All names this namer has handed out; avoided by new_symbol.
+    self.generated_names = set()
+
+  def compiled_class_name(self, original_fqn, live_entity=None):
+    """See call_trees.FunctionNamer.compiled_class_name."""
+    if live_entity is not None and live_entity in self.renamed_calls:
+      return self.renamed_calls[live_entity]
+
+    # Flatten dotted names, e.g. ('a', 'B') -> 'a__B'.
+    if isinstance(original_fqn, tuple):
+      original_name = '__'.join(original_fqn)
+    else:
+      original_name = original_fqn
+
+    new_name_root = 'Tf%s' % original_name
+    new_name = new_name_root
+    n = 0
+    # Disambiguate against the global namespace with a numeric suffix.
+    while new_name in self.global_namespace:
+      n += 1
+      new_name = '%s_%d' % (new_name_root, n)
+
+    self.generated_names.add(new_name)
+    if live_entity is not None:
+      self.renamed_calls[live_entity] = new_name
+    return new_name
+
+  def compiled_function_name(self,
+                             original_fqn,
+                             live_entity=None,
+                             owner_type=None):
+    """See call_trees.FunctionNamer.compiled_function_name.
+
+    Returns a (name, should_rename) pair; name is None when the call site
+    should not be renamed at all.
+    """
+
+    if not self.recursive:
+      return None, False
+
+    if owner_type is not None and owner_type not in self.partial_types:
+      # Members are not renamed when part of an entire converted class.
+      return None, False
+
+    # Flatten dotted names, e.g. ('a', 'b') -> 'a__b'.
+    if isinstance(original_fqn, tuple):
+      original_name = '__'.join(original_fqn)
+    else:
+      original_name = original_fqn
+
+    if live_entity is not None and live_entity in self.renamed_calls:
+      return self.renamed_calls[live_entity], True
+
+    new_name_root = 'tf__%s' % original_name
+    new_name = new_name_root
+    n = 0
+    while new_name in self.global_namespace:
+      n += 1
+      new_name = '%s_%d' % (new_name_root, n)
+
+    if live_entity is not None:
+      self.renamed_calls[live_entity] = new_name
+    self.generated_names.add(new_name)
+
+    return new_name, True
+
+  def new_symbol(self, name_root, reserved_locals):
+    """See control_flow.SymbolNamer.new_symbol."""
+    # reserved_locals may contain QNs.
+    all_reserved_locals = set()
+    for s in reserved_locals:
+      if isinstance(s, qual_names.QN):
+        all_reserved_locals.update(s.qn)
+      elif isinstance(s, str):
+        all_reserved_locals.add(s)
+      else:
+        raise ValueError('Unexpected symbol type "%s"' % type(s))
+
+    # If name_root already ends in a numeric suffix (e.g. 'temp_2'), continue
+    # counting from it instead of appending a second suffix.
+    pieces = name_root.split('_')
+    if pieces[-1].isdigit():
+      name_root = '_'.join(pieces[:-1])
+      n = int(pieces[-1])
+    else:
+      n = 0
+    new_name = name_root
+
+    # Avoid globals, the caller's reserved locals, and anything previously
+    # generated by this namer.
+    while (new_name in self.global_namespace or
+           new_name in all_reserved_locals or new_name in self.generated_names):
+      n += 1
+      new_name = '%s_%d' % (name_root, n)
+
+    self.generated_names.add(new_name)
+    return new_name
diff --git a/tensorflow/python/autograph/core/naming_test.py b/tensorflow/python/autograph/core/naming_test.py
new file mode 100644
index 0000000000..2db98836d1
--- /dev/null
+++ b/tensorflow/python/autograph/core/naming_test.py
@@ -0,0 +1,77 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for naming module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import naming
+from tensorflow.python.platform import test
+
+
+class NamerTest(test.TestCase):
+  """Unit tests for naming.Namer."""
+
+  def test_compiled_function_name_tracks_names(self):
+    def bar():
+      pass
+
+    namer = naming.Namer({}, True, None, ())
+    self.assertEqual(('tf__foo', True), namer.compiled_function_name('foo'))
+    self.assertEqual(('tf__bar', True), namer.compiled_function_name(
+        'bar', bar))
+    # Only calls made with a live entity are recorded in renamed_calls.
+    self.assertEqual({bar: 'tf__bar'}, namer.renamed_calls)
+    self.assertItemsEqual(('tf__bar', 'tf__foo'), namer.generated_names)
+
+  def test_compiled_function_name_consistent(self):
+    def foo():
+      pass
+
+    # Repeated requests for the same live entity return the same name.
+    namer = naming.Namer({}, True, None, ())
+    self.assertEqual(('tf__foo', True), namer.compiled_function_name(
+        'foo', foo))
+    self.assertEqual(('tf__foo', True), namer.compiled_function_name(
+        'foo', foo))
+
+  def test_compiled_function_name_avoids_global_conflicts(self):
+    def foo():
+      pass
+
+    # 'tf__foo' is already taken in the global namespace, so a numeric
+    # suffix is appended.
+    namer = naming.Namer({'tf__foo': 1}, True, None, ())
+    self.assertEqual(('tf__foo_1', True),
+                     namer.compiled_function_name('foo', foo))
+
+  def test_new_symbol_tracks_names(self):
+    namer = naming.Namer({}, True, None, ())
+    self.assertEqual('temp', namer.new_symbol('temp', set()))
+    self.assertItemsEqual(('temp',), namer.generated_names)
+
+  def test_new_symbol_avoids_duplicates(self):
+    namer = naming.Namer({}, True, None, ())
+    self.assertEqual('temp', namer.new_symbol('temp', set()))
+    self.assertEqual('temp_1', namer.new_symbol('temp', set()))
+    self.assertItemsEqual(('temp', 'temp_1'), namer.generated_names)
+
+  def test_new_symbol_avoids_conflicts(self):
+    namer = naming.Namer({'temp': 1}, True, None, ())
+    # temp is reserved in the global namespace
+    self.assertEqual('temp_1', namer.new_symbol('temp', set()))
+    # temp_2 is reserved in the local namespace
+    self.assertEqual('temp_3', namer.new_symbol('temp', set(('temp_2',))))
+    self.assertItemsEqual(('temp_1', 'temp_3'), namer.generated_names)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/docs/pyfunc_dtypes.md b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
new file mode 100644
index 0000000000..c2427f5f4f
--- /dev/null
+++ b/tensorflow/python/autograph/docs/pyfunc_dtypes.md
@@ -0,0 +1,33 @@
+# Specifying return data type for `py_func` calls
+
+The `py_func` op requires specifying a
+[data type](https://www.tensorflow.org/guide/tensors#data_types).
+
+When wrapping a function with `py_func`, for instance using
+`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
+options to specify the returned data type:
+
+ * explicitly, with a specified `tf.DType` value
+ * by matching the data type of an input argument, which is then assumed to be
+ a `Tensor`
+
+Examples:
+
+Specify an explicit data type:
+
+```
+  def foo(a):
+    return a + 1
+
+  autograph.utils.wrap_py_func(foo, return_dtypes=[tf.float32])
+```
+
+Match the data type of the first argument:
+
+```
+  def foo(a):
+    return a + 1
+
+  autograph.utils.wrap_py_func(
+      foo, return_dtypes=[autograph.utils.py_func.MatchDType(0)])
+```
diff --git a/tensorflow/python/autograph/impl/BUILD b/tensorflow/python/autograph/impl/BUILD
new file mode 100644
index 0000000000..bef62a6403
--- /dev/null
+++ b/tensorflow/python/autograph/impl/BUILD
@@ -0,0 +1,62 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "impl",
+ srcs = [
+ "api.py",
+ "conversion.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/autograph/converters",
+ "//tensorflow/python/autograph/core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "api_test",
+ srcs = ["api_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":impl",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/autograph/utils",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "conversion_test",
+ srcs = ["conversion_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":impl",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
new file mode 100644
index 0000000000..669d36bd28
--- /dev/null
+++ b/tensorflow/python/autograph/impl/api.py
@@ -0,0 +1,328 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module contains the user-facing API for AutoGraph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import wraps
+
+from enum import Enum
+
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import conversion
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+# TODO(mdan): Properly document the type hints.
+# TODO(mdan): Reduce the type hint information to (module, type).
+# (currently we require (module + class name, type))
+
+
+# TODO(mdan): This should behave like to_graph (e.g. convert statically).
+def convert(recursive=False, verbose=False):
+  """Decorator that compiles a function to use TensorFlow ops.
+
+  The decorator is dynamic - it recompiles the target whenever the decorated
+  function is called. This means the parameter values are known at conversion.
+  It also means that repeated calls with different types of parameters will be
+  correctly processed.
+
+  Args:
+    recursive: bool, whether to recursively convert any functions or classes
+      that the converted function may use.
+    verbose: bool, whether to output the compiled code in the logs.
+
+  Returns:
+    Callable, a decorator that converts the given function into an equivalent
+    function that uses TensorFlow ops.
+  """
+  def decorator(f):
+    """Decorator implementation."""
+
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+      # force_conversion=True: the decorated function is always converted,
+      # even if it would otherwise be whitelisted.
+      return converted_call(f, recursive, verbose, True, {}, *args, **kwargs)
+
+    wrapper = tf_decorator.make_decorator(f, wrapper)
+
+    # Sometimes the decorator is just desugared, making it impossible to detect.
+    # This attribute makes detection easier.
+    setattr(wrapper, '__pyct_is_compile_decorator', True)
+    return wrapper
+
+  return decorator
+
+
+class RunMode(Enum):
+  """Specifies the way a converted function or method should be executed in TF.
+
+  The enum values have the following semantics:
+
+   * GRAPH: Call this function directly, as-is. This is suitable for functions
+       that were already designed for TF graphs and contain ops.
+   * PY_FUNC: Wrap this function into a py_func op. This is suitable for code
+       that will only run correctly in Python, for example code that renders
+       to the display, reads keyboard input, etc.
+  """
+  GRAPH = 1  # Execute as-is; the body already contains graph-compatible ops.
+  PY_FUNC = 2  # Execute via a py_func op; the body is plain Python.
+
+
+def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
+  """Decorator that suppresses the conversion of a function.
+
+  See also: docs/pyfunc_dtypes.md
+
+  Args:
+    run_as: RunMode, specifies how to use the function in TensorFlow.
+    return_dtypes: Optional[Iterable[
+        Union[tf.DType, utils.py_func.MatchDType]]], the return data types of
+        the converted function, if run_as is RunMode.PY_FUNC. Ignored
+        otherwise. May be set to None if the function has no return values.
+
+  Returns:
+    Callable, a decorator that wraps the original function.
+  """
+
+  def decorator(f):
+    """Decorator implementation."""
+
+    @wraps(f)
+    def graph_wrapper(*args, **kwargs):
+      # Pass-through: the function is assumed to contain graph-compatible
+      # code already.
+      return f(*args, **kwargs)
+
+    @wraps(f)
+    def py_func_wrapper(*args, **kwargs):
+      if kwargs:
+        raise NotImplementedError('RunMode.PY_FUNC does not yet support kwargs')
+      # TODO(mdan): Add support for kwargs.
+      # use_dummy_return works around py_func requiring at least one output.
+      return py_func.wrap_py_func(
+          f, return_dtypes, args, kwargs, use_dummy_return=not return_dtypes)
+
+    if run_as == RunMode.GRAPH:
+      wrapper = graph_wrapper
+    elif run_as == RunMode.PY_FUNC:
+      wrapper = py_func_wrapper
+    else:
+      raise ValueError('unknown value for run_as: %s' % run_as)
+
+    # Sometimes the decorator is just desugared, making it impossible to detect.
+    # This attribute makes detection easier.
+    setattr(wrapper, '__pyct_is_compile_decorator', True)
+    return wrapper
+
+  return decorator
+
+
+# TODO(mdan): Move to a private, undocumented module.
+def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
+ **kwargs):
+ """Compiles a function call inline. For internal use only."""
+ # TODO(mdan): This needs cleanup.
+ # In particular, we may want to avoid renaming functions altogether.
+ if not force_conversion and conversion.is_whitelisted_for_graph(f):
+ return f(*args, **kwargs)
+
+ unknown_arg_value = object() # Sentinel for arguments of unknown value
+
+ if inspect_utils.isbuiltin(f):
+ return py_builtins.overload_of(f)(*args, **kwargs)
+
+ if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
+ # Regular functions
+ target_entity = f
+ arg_map_target = f
+ effective_args = args
+ f_class = inspect_utils.getmethodclass(f)
+
+ if f_class is not None:
+ partial_types = (f_class,)
+ else:
+ partial_types = ()
+
+ elif tf_inspect.isclass(f):
+ # Constructors
+ target_entity = f
+ arg_map_target = f.__init__
+ effective_args = args
+ partial_types = ()
+
+ elif hasattr(f, '__call__') and hasattr(f, '__class__'):
+ # Callable objects
+ target_entity = f.__call__
+ arg_map_target = f.__call__
+ effective_args = (f,) + args
+ partial_types = (f.__class__,)
+
+ else:
+ NotImplementedError('unknown callable type "%s"' % type(f))
+
+ arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
+ for name, arg in arg_values.items():
+ if arg is unknown_arg_value:
+ continue
+ arg_class = arg.__class__
+ # If arg_value_hints specifies any name, use that instead.
+ if name not in arg_types:
+ arg_types[name] = (arg_class.__name__, arg_class)
+
+ # When called from within a decorator, this is the only indication that
+ # the function is a method - it appears that the decorator is applied
+ # before the method is bound.
+ if not partial_types:
+ if 'self' in arg_values:
+ if tf_inspect.isclass(arg_values['self'].__class__):
+ partial_types = (arg_values['self'].__class__,)
+ elif 'cls' in arg_values:
+ if tf_inspect.isclass(arg_values['cls']):
+ partial_types = (arg_values['cls'],)
+
+ converted_f = to_graph(
+ target_entity,
+ recursive=recursive,
+ verbose=verbose,
+ arg_values=arg_values,
+ arg_types=arg_types,
+ partial_types=partial_types)
+ return converted_f(*effective_args, **kwargs)
+
+
+# TODO(mdan): Rename: to_ops?
+# TODO(mdan): Look into overloading as function and decorator, like tfe.defun.
+# TODO(mdan): Remove partial_types.
+def to_graph(e,
+ recursive=True,
+ verbose=False,
+ arg_values=None,
+ arg_types=None,
+ partial_types=None):
+ """Converts a Python entity into equivalent code that uses TensorFlow ops.
+
+ Supported Python entities include:
+ * functions
+ * classes
+
+ Classes are converted by converting all their methods into a new class.
+
+ Args:
+ e: Union[Callable, Type], the Python entity to convert.
+ recursive: bool, whether to recursively convert any functions that the
+ converted function may call.
+ verbose: bool, whether to output the compiled code in the logs.
+ arg_values: Optional[Dict[Text, Any]], value hints for symbols including
+ function arguments.
+ arg_types: Optional[Dict[Text, Type]], type hints for symbols including
+ function arguments.
+ partial_types: Set[Type], reserved for internal use.
+
+ Returns:
+ Union[Callable, Type], the converted entity, which is the same kind as e
+ (that is, a function is e is a function, a class if e is a class, etc.) but
+ its code has been converted to use TF ops.
+
+ Raises:
+ ValueError: If the entity could not be converted.
+ """
+ program_ctx = converter.ProgramContext(
+ recursive=recursive,
+ autograph_decorators=(convert, do_not_convert, converted_call),
+ partial_types=partial_types,
+ autograph_module=tf_inspect.getmodule(to_graph),
+ uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
+ _, name, namespace = conversion.entity_to_graph(e, program_ctx, arg_values,
+ arg_types)
+
+ nodes = []
+ for dep in reversed(tuple(program_ctx.dependency_cache.values())):
+ nodes.extend(dep)
+ compiled_module, compiled_src = compiler.ast_to_object(
+ nodes,
+ source_prefix=program_ctx.required_imports,
+ include_source_map=True)
+
+ # The compiled code should see everything the entry entity saw.
+ # TODO(mdan): This might not work well if the call tree spans modules?
+ for key, val in namespace.items():
+ # Avoid overwriting entities that have been transformed.
+ if key not in compiled_module.__dict__:
+ compiled_module.__dict__[key] = val
+ compiled = getattr(compiled_module, name)
+
+ # Need this so the source_mapping attribute is available for the context
+ # manager to access for runtime errors.
+ #
+ # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
+ # symbol to the compiled module.
+ # TODO(mdan): Record this statically in the generated code.
+ # TODO(mdan): Rename this attribute to 'autograph_info__'
+ source_map_attribute_name = 'ag_source_map'
+ if getattr(compiled, source_map_attribute_name, None) is not None:
+ raise ValueError('cannot convert %s because is has an attribute '
+ '"%s", which is reserved for AutoGraph.' %
+ (compiled, source_map_attribute_name))
+ setattr(compiled, source_map_attribute_name,
+ compiled_module.__dict__['ag_source_map__'])
+
+ if verbose:
+ logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)
+
+ return compiled
+
+
+def to_code(e,
+            recursive=True,
+            arg_values=None,
+            arg_types=None,
+            partial_types=None,
+            indentation='  '):
+  """Returns the equivalent code that uses TensorFlow ops.
+
+  Also see: `to_graph`, `convert`
+
+  Args:
+    e: Union[Callable, Type], the Python entity to convert.
+    recursive: bool, whether to recursively convert any functions that the
+      converted function may call.
+    arg_values: Optional[Dict[Text, Any]], value hints for symbols including
+      function arguments.
+    arg_types: Optional[Dict[Text, Type]], type hints for symbols including
+      function arguments.
+    partial_types: Set[Type], reserved for internal use.
+    indentation: Text, what to use for each level of indentation.
+
+  Returns:
+    Text, the converted code.
+  """
+  program_ctx = converter.ProgramContext(
+      recursive=recursive,
+      autograph_decorators=(convert, do_not_convert, converted_call),
+      partial_types=partial_types,
+      autograph_module=tf_inspect.getmodule(to_graph),
+      uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
+  conversion.entity_to_graph(e, program_ctx, arg_values, arg_types)
+
+  # Emit dependencies first so the entry entity comes last.
+  code = '\n'.join(
+      compiler.ast_to_source(dep, indentation)
+      for dep in reversed(tuple(program_ctx.dependency_cache.values())))
+
+  return program_ctx.required_imports + '\n\n' + code
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
new file mode 100644
index 0000000000..54e12f0223
--- /dev/null
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -0,0 +1,329 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for api module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
+
+
+tf = utils.fake_tf()
+
+
+class ApiTest(test.TestCase):
+  """End-to-end tests for the AutoGraph public API: convert, do_not_convert,
+  converted_call, to_graph and to_code."""
+
+  def setUp(self):
+    # Keep the compiled modules' imports minimal and deterministic.
+    config.COMPILED_IMPORT_STATEMENTS = (
+        'from __future__ import print_function',
+    )
+
+  def test_decorator_recurses(self):
+
+    class TestClass(object):
+
+      def called_member(self, a):
+        if a < 0:
+          a = -a
+        return a
+
+      @api.convert(recursive=True)
+      def test_method(self, x, s, a):
+        while tf.reduce_sum(x) > s:
+          x //= self.called_member(a)
+        return x
+
+    tc = TestClass()
+    with self.test_session() as sess:
+      x = tc.test_method(
+          constant_op.constant([2, 4]), constant_op.constant(1),
+          constant_op.constant(-2))
+      self.assertListEqual([0, 1], sess.run(x).tolist())
+
+  def test_decorator_does_not_recurse(self):
+
+    class TestClass(object):
+
+      def called_member(self, a):
+        return tf.negative(a)
+
+      @api.convert(recursive=False)
+      def test_method(self, x, s, a):
+        while tf.reduce_sum(x) > s:
+          x //= self.called_member(a)
+        return x
+
+    tc = TestClass()
+    with self.test_session() as sess:
+      x = tc.test_method(
+          constant_op.constant([2, 4]), constant_op.constant(1),
+          constant_op.constant(-2))
+      self.assertListEqual([0, 1], sess.run(x).tolist())
+
+  def test_decorator_calls_unconverted_graph(self):
+
+    class TestClass(object):
+
+      @api.do_not_convert(api.RunMode.GRAPH)
+      def called_member(self, a):
+        return tf.negative(a)
+
+      @api.convert(recursive=True)
+      def test_method(self, x, s, a):
+        while tf.reduce_sum(x) > s:
+          x //= self.called_member(a)
+        return x
+
+    tc = TestClass()
+    with self.test_session() as sess:
+      x = tc.test_method(
+          constant_op.constant([2, 4]), constant_op.constant(1),
+          constant_op.constant(-2))
+      self.assertListEqual([0, 1], sess.run(x).tolist())
+
+  def test_decorator_calls_unconverted_py_func(self):
+
+    class TestClass(object):
+
+      @api.do_not_convert(
+          api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1))
+      def called_member(self, a):
+        return np.negative(a)
+
+      @api.convert(recursive=True)
+      def test_method(self, x, s, a):
+        while tf.reduce_sum(x) > s:
+          y = self.called_member(a)
+          # set_shape works around while_loop's limitations.
+          # TODO(mdan): Allow specifying shapes (or ShapeLike) instead.
+          y.set_shape(a.shape)
+          x //= y
+        return x
+
+    tc = TestClass()
+    with self.test_session() as sess:
+      x = tc.test_method(
+          constant_op.constant([2, 4]), constant_op.constant(1),
+          constant_op.constant(-2))
+      self.assertListEqual([0, 1], sess.run(x).tolist())
+
+  def test_decorator_calls_decorated(self):
+
+    class TestClass(object):
+
+      @api.convert()
+      def called_member(self, a):
+        if a < 0:
+          a = -a
+        return a
+
+      @api.convert(recursive=True)
+      def test_method(self, x, s, a):
+        while tf.reduce_sum(x) > s:
+          x //= self.called_member(a)
+        return x
+
+    tc = TestClass()
+    with self.test_session() as sess:
+      x = tc.test_method(
+          constant_op.constant([2, 4]), constant_op.constant(1),
+          constant_op.constant(-2))
+      self.assertListEqual([0, 1], sess.run(x).tolist())
+
+  def test_decorator_preserves_argspec(self):
+
+    class TestClass(object):
+
+      def called_member(self, a):
+        if a < 0:
+          a = -a
+        return a
+
+      # Conversion must not alter the visible signature.
+      called_member_converted = api.convert()(called_member)
+
+    tc = TestClass()
+    self.assertListEqual(
+        list(tf_inspect.getfullargspec(tc.called_member)),
+        list(tf_inspect.getfullargspec(tc.called_member_converted)))
+
+  def test_convert_call_site_decorator(self):
+
+    class TestClass(object):
+
+      def called_member(self, a):
+        if a < 0:
+          a = -a
+        return a
+
+      @api.convert(recursive=True)
+      def test_method(self, x, s, a):
+        while tf.reduce_sum(x) > s:
+          x //= api.converted_call(self.called_member, False, False, False, {},
+                                   self, a)
+        return x
+
+    tc = TestClass()
+    with self.test_session() as sess:
+      x = tc.test_method(
+          constant_op.constant([2, 4]), constant_op.constant(1),
+          constant_op.constant(-2))
+      self.assertListEqual([0, 1], sess.run(x).tolist())
+
+  def test_converted_call_builtin(self):
+    x = api.converted_call(range, False, False, False, {}, 3)
+    self.assertEqual((0, 1, 2), tuple(x))
+
+  def test_converted_call_function(self):
+
+    def test_fn(x):
+      if x < 0:
+        return -x
+      return x
+
+    with self.test_session() as sess:
+      x = api.converted_call(test_fn, False, False, False, {},
+                             constant_op.constant(-1))
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_method(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def test_method(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = TestClass(constant_op.constant(-1))
+      x = api.converted_call(tc.test_method, False, False, False, {}, tc)
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_method_by_class(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def test_method(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = TestClass(constant_op.constant(-1))
+      x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_callable_object(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def __call__(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = TestClass(constant_op.constant(-1))
+      x = api.converted_call(tc, False, False, False, {})
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_constructor(self):
+
+    class TestClass(object):
+
+      def __init__(self, x):
+        self.x = x
+
+      def test_method(self):
+        if self.x < 0:
+          return -self.x
+        return self.x
+
+    with self.test_session() as sess:
+      tc = api.converted_call(TestClass, False, False, False, {},
+                              constant_op.constant(-1))
+      # tc is now a converted object.
+      x = tc.test_method()
+      self.assertEqual(1, sess.run(x))
+
+  def test_converted_call_already_converted(self):
+
+    def f(x):
+      return x == 0
+
+    with self.test_session() as sess:
+      x = api.converted_call(f, False, False, False, {},
+                             constant_op.constant(0))
+      self.assertTrue(sess.run(x))
+
+      # Converting an already-converted function should be a no-op.
+      converted_f = api.to_graph(f)
+      x = api.converted_call(converted_f, False, False, False, {},
+                             constant_op.constant(0))
+      self.assertTrue(sess.run(x))
+
+  def test_to_graph_basic(self):
+
+    def test_fn(x, s):
+      while tf.reduce_sum(x) > s:
+        x //= 2
+      return x
+
+    compiled_fn = api.to_graph(test_fn)
+
+    with self.test_session() as sess:
+      x = compiled_fn(constant_op.constant([4, 8]), 4)
+      self.assertListEqual([1, 2], sess.run(x).tolist())
+
+  def test_to_code_basic(self):
+
+    def test_fn(x, s):
+      while tf.reduce_sum(x) > s:
+        x /= 2
+      return x
+
+    compiled_code = api.to_code(test_fn)
+
+    # Just check that it is parseable Python code.
+    self.assertIsNotNone(parser.parse_str(compiled_code))
+
+  def test_source_map_attribute_present(self):
+
+    def test_fn(y):
+      return y**2
+
+    # to_graph must attach the ag_source_map attribute used by improved_errors.
+    self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
new file mode 100644
index 0000000000..928ff9e7ea
--- /dev/null
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -0,0 +1,351 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Core conversion logic, serves as main point of access."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+
+import gast
+
+from tensorflow.python.autograph import operators
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.converters import break_statements
+from tensorflow.python.autograph.converters import builtin_functions
+from tensorflow.python.autograph.converters import call_trees
+from tensorflow.python.autograph.converters import conditional_expressions
+from tensorflow.python.autograph.converters import continue_statements
+from tensorflow.python.autograph.converters import control_flow
+from tensorflow.python.autograph.converters import decorators
+from tensorflow.python.autograph.converters import directives
+from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import lists
+from tensorflow.python.autograph.converters import logical_expressions
+from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import return_statements
+from tensorflow.python.autograph.converters import side_effect_guards
+from tensorflow.python.autograph.converters import slices
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.util import tf_inspect
+
+
+# TODO(mdan): Might we not need any renaming at all?
+
+
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.

  Returns:
    Boolean, True if `o` may be used in graph code without conversion.
  """
  m = tf_inspect.getmodule(o)
  # getmodule may return None (e.g. for builtins or dynamically created
  # entities); in that case the module-prefix whitelist cannot apply, but the
  # conversion-marker check below still does.
  if m is not None:
    # Each entry in DEFAULT_UNCOMPILED_MODULES is a 1-tuple; the trailing
    # comma unpacks its single element.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__.startswith(prefix):
        return True
  # Entities already processed by AutoGraph carry this marker attribute and
  # must not be converted a second time.
  if hasattr(o, 'autograph_info__'):
    return True
  return False
+
+
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.

  Returns:
    A tuple (ast, new_name, namespace):
      * ast: An AST representing an entity with interface equivalent to `o`,
        but which when executed creates a TF graph.
      * new_name: The symbol name under which the new entity can be found.
      * namespace: A dict mapping all symbols visible to the converted entity,
        keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
  # Dispatch on the entity kind; each specialization returns the same
  # (node, name, namespace) triple.
  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o):
    # TODO(mdan): This is not a reliable mechanism.
    # The most reliable way is to check the source code, the AST will contain
    # a Lambda node instead of a FunctionDef
    if o.__name__ == '<lambda>':
      raise NotImplementedError(
          'lambda functions are not yet supported; declare the function'
          ' using def instead: %s' % o)
    else:
      node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  elif tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # Tag the generated entity so is_whitelisted_for_graph recognizes it as
  # already converted.
  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
  entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if program_ctx.recursive:
    # Worklist loop: keep converting referenced entities (recorded in
    # name_map by the converters) until every one is in dependency_cache.
    while True:
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj not in program_ctx.dependency_cache:
          candidate = obj
          break
      if candidate is None:
        break
      # NOTE(review): a candidate with im_class (a Py2 bound/unbound method)
      # is skipped but never added to dependency_cache, so if one ever lands
      # in name_map this loop would re-select it forever — TODO confirm such
      # candidates cannot enter name_map.
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns
+
+
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts every method directly defined by `c`, then assembles a new class
  definition that contains the converted methods.

  Args:
    c: A Python class.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (output_nodes, class_name, class_namespace); see
    `entity_to_graph` for the meaning of each element.

  Raises:
    ValueError: if the class has no member methods to convert.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c,
        rewrite_errors=False)
    # Accumulate the namespaces of all converted methods. (A previous version
    # special-cased class_namespace being None, but it is always initialized
    # to a dict above, so that branch was dead code and has been removed.)
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # NOTE(review): isinstance(object, base) is True only when base is object
    # (or type), so this detects the plain `object` base; `base is object`
    # would be the clearer spelling. Behavior preserved as-is.
    if isinstance(object, base):
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      # Whitelisted base: reference it via an aliased absolute import.
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
      base_names.append(alias)
      renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
+
+
+def _add_reserved_symbol(namespace, name, entity):
+ if name not in namespace:
+ namespace[name] = entity
+ elif namespace[name] != entity:
+ raise ValueError('The name "%s" is reserved and may not be used.' % name)
+
+
# Module-level cache of the synthetic `autograph` module that converted code
# accesses as `ag__`. Built lazily by _add_self_references and shared across
# all conversions in the process.
ag_internal = None


def _add_self_references(namespace, autograph_module):
  """Adds namespace references to the module that exposes the api itself.

  Args:
    namespace: Dict[str, Any], the converted entity's namespace; gains an
      `ag__` entry (modified in place).
    autograph_module: Module, the module exposing `converted_call`; passed in
      by the caller to avoid a circular import.
  """
  global ag_internal
  if ag_internal is None:
    # Craft a module that exposes parts of the external API as well as certain
    # internal modules.
    ag_internal = imp.new_module('autograph')
    ag_internal.converted_call = autograph_module.converted_call
    ag_internal.utils = utils
    ag_internal.rewrite_graph_construction_error = (
        errors.rewrite_graph_construction_error)
    # TODO(mdan): Add safeguards against name clashes.
    # We don't want to create a submodule because we want the operators to be
    # accessible as ag__.<operator>
    ag_internal.__dict__.update(operators.__dict__)

  _add_reserved_symbol(namespace, 'ag__', ag_internal)
+
+
def function_to_graph(f,
                      program_ctx,
                      arg_values,
                      arg_types,
                      owner_type=None,
                      rewrite_errors=True):
  """Specialization of `entity_to_graph` for callable functions.

  Args:
    f: Function, the callable to convert.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.
    owner_type: Optional type owning `f`, when `f` is a method.
    rewrite_errors: Boolean, whether to stage graph construction error
      rewriting around the converted body.

  Returns:
    A tuple ([node], new_name, namespace); see `entity_to_graph`.

  Raises:
    NotImplementedError: if the namer did not rename the function but the
      parsed node's name unexpectedly differs from `f.__name__`.
  """

  node, source = parser.parse_entity(f)
  # parse_entity returns a Module node; the function definition is its only
  # child.
  node = node.body[0]
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  # Make the AutoGraph runtime visible to the generated code as `ag__`.
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = program_ctx.new_namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      namespace=namespace,
      arg_values=arg_values,
      arg_types=arg_types,
      owner_type=owner_type)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  node = node_to_graph(node, context, rewrite_errors=rewrite_errors)

  # TODO(mdan): This somewhat duplicates the call rename logic in call_trees.py
  new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type)
  if not did_rename:
    new_name = f.__name__
    if node.name != f.__name__:
      raise NotImplementedError('Strange corner case. Send us offending code!')
  node.name = new_name

  program_ctx.update_name_map(namer)
  # TODO(mdan): Use this at compilation.

  return [node], new_name, namespace
+
+
def node_to_graph(node, context, rewrite_errors=True):
  """Convert Python code to equivalent TF graph mode code.

  Args:
    node: AST, the code to convert.
    context: converter.EntityContext
    rewrite_errors: Boolean, whether or not to rewrite the error traceback.

  Returns:
    A Python ast node, representing the converted code.
  """
  # TODO(mdan): Insert list_comprehensions somewhere.

  node = converter.standard_analysis(node, context, is_initial=True)
  # Past this point, line numbers are no longer accurate so we ignore the
  # source.
  # TODO(mdan): Is it feasible to reconstruct intermediate source code?
  context.info.source_code = None

  # NOTE(review): the passes below appear order-sensitive (later passes
  # consume the canonical forms produced by earlier ones) — do not reorder
  # without checking each converter's assumptions.
  node = converter.apply_(node, context, decorators)
  node = converter.apply_(node, context, directives)
  node = converter.apply_(node, context, break_statements)
  node = converter.apply_(node, context, asserts)
  # Note: sequencing continue canonicalization before for loop one avoids
  # dealing with the extra loop increment operation that the for
  # canonicalization creates.
  node = converter.apply_(node, context, continue_statements)
  # len is injected into the namespace here; presumably code generated by the
  # subsequent passes calls it — TODO confirm which converter relies on this.
  context.info.namespace['len'] = len
  node = converter.apply_(node, context, return_statements)
  node = converter.apply_(node, context, lists)
  node = converter.apply_(node, context, slices)
  node = converter.apply_(node, context, builtin_functions)
  node = converter.apply_(node, context, call_trees)
  node = converter.apply_(node, context, control_flow)
  node = converter.apply_(node, context, conditional_expressions)
  node = converter.apply_(node, context, logical_expressions)
  node = converter.apply_(node, context, side_effect_guards)
  node = converter.apply_(node, context, name_scopes)
  if rewrite_errors:
    node = converter.apply_(node, context, error_handlers)
  return node
diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py
new file mode 100644
index 0000000000..07d0f75129
--- /dev/null
+++ b/tensorflow/python/autograph/impl/conversion_test.py
@@ -0,0 +1,172 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for conversion module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.impl import conversion
+from tensorflow.python.framework import constant_op
+from tensorflow.python.keras.engine import training
+from tensorflow.python.platform import test
+
+
class ConversionTest(test.TestCase):
  """Tests for the conversion module's entity-to-graph pipeline."""

  def _simple_program_ctx(self):
    # Shared fixture: a recursive conversion context using the default
    # whitelist configuration.
    return converter.ProgramContext(
        recursive=True,
        autograph_decorators=(),
        partial_types=(),
        autograph_module=api,
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)

  def test_is_whitelisted_for_graph(self):
    """User functions are converted; tensorflow members are whitelisted."""

    def test_fn():
      return constant_op.constant(1)

    self.assertFalse(conversion.is_whitelisted_for_graph(test_fn))
    self.assertTrue(conversion.is_whitelisted_for_graph(utils))
    self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant))

  def test_entity_to_graph_unsupported_types(self):
    """Non-function, non-class entities are rejected."""
    with self.assertRaises(NotImplementedError):
      program_ctx = self._simple_program_ctx()
      conversion.entity_to_graph('dummy', program_ctx, None, None)

  def test_entity_to_graph_callable(self):
    """A plain function converts to a renamed FunctionDef with its closure."""
    b = 2
    def f(a):
      return a + b

    program_ctx = self._simple_program_ctx()
    nodes, name, ns = conversion.entity_to_graph(f, program_ctx, None, None)
    fn_node, _ = nodes
    self.assertIsInstance(fn_node, gast.FunctionDef)
    self.assertEqual('tf__f', name)
    # The closed-over variable must be visible in the returned namespace.
    self.assertIs(ns['b'], b)

  def test_entity_to_graph_call_tree(self):
    """Recursive conversion follows the call tree and rewrites call sites."""

    def g(a):
      return a

    def f(a):
      return g(a)

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(f, program_ctx, None, None)

    self.assertTrue(f in program_ctx.dependency_cache)
    self.assertTrue(g in program_ctx.dependency_cache)
    f_node = program_ctx.dependency_cache[f][0]
    g_node = program_ctx.dependency_cache[g][0]
    self.assertEqual('tf__f', f_node.name)
    # The call to g inside f must be rewritten to target the converted g.
    self.assertEqual('tf__g', f_node.body[0].body[0].body[0].value.func.id)
    self.assertEqual('tf__g', g_node.name)

  def test_entity_to_graph_class_hierarchy(self):
    """Converting a subclass also converts its non-whitelisted base."""

    class TestBase(object):

      def __init__(self, x='base'):
        self.x = x

      def foo(self):
        return self.x

      def bar(self):
        return self.x

    class TestSubclass(TestBase):

      def __init__(self, y):
        super(TestSubclass, self).__init__('sub')
        self.y = y

      def foo(self):
        return self.y

      def baz(self):
        return self.y

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    self.assertTrue(TestBase in program_ctx.dependency_cache)
    self.assertTrue(TestSubclass in program_ctx.dependency_cache)
    # The returned nodes will include:
    # <import nodes>, <class node>, <assignment node>
    self.assertEqual('TfTestBase',
                     program_ctx.dependency_cache[TestBase][-2].name)
    self.assertEqual('TfTestSubclass',
                     program_ctx.dependency_cache[TestSubclass][-2].name)

  def test_entity_to_graph_class_hierarchy_whitelisted(self):
    """A whitelisted base (keras Model) is imported, not converted."""

    class TestSubclass(training.Model):

      def __init__(self, y):
        super(TestSubclass, self).__init__()
        self.built = False

      def call(self, x):
        return 3 * x

    program_ctx = self._simple_program_ctx()
    conversion.entity_to_graph(TestSubclass, program_ctx, None, None)

    self.assertTrue(TestSubclass in program_ctx.dependency_cache)
    self.assertFalse(training.Model in program_ctx.dependency_cache)
    # The first node is the generated import of the whitelisted base.
    self.assertEqual(
        'Model', program_ctx.dependency_cache[TestSubclass][0].names[0].name)
    # The returned nodes will include:
    # <import nodes>, <class node>, <assignment node>
    self.assertEqual('TfTestSubclass',
                     program_ctx.dependency_cache[TestSubclass][-2].name)

  def test_entity_to_graph_lambda(self):
    """Lambdas are explicitly unsupported."""
    f = lambda a: a

    with self.assertRaises(NotImplementedError):
      program_ctx = self._simple_program_ctx()
      conversion.entity_to_graph(f, program_ctx, None, None)

  def test_ag_module_cached(self):
    """All conversions must share the same synthetic ag__ module object."""
    def callee():
      return range(3)

    def caller(a):
      return a()

    program_ctx = self._simple_program_ctx()
    _, _, callee_ns = conversion.entity_to_graph(callee, program_ctx, None,
                                                 None)
    _, _, caller_ns = conversion.entity_to_graph(caller, program_ctx, None,
                                                 None)

    self.assertTrue(callee_ns['ag__'] is caller_ns['ag__'])
+
+
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/lang/BUILD b/tensorflow/python/autograph/lang/BUILD
new file mode 100644
index 0000000000..462349cc10
--- /dev/null
+++ b/tensorflow/python/autograph/lang/BUILD
@@ -0,0 +1,40 @@
# Description:
#   Build rules for the AutoGraph "lang" package: user-visible directives and
#   special functions.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

# Catch-all filegroup consumed by TensorFlow's source aggregation rules.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

# Library exposing the directives and special_functions modules; depends on
# operators because special_functions delegates to data_structures.
py_library(
    name = "lang",
    srcs = [
        "directives.py",
        "special_functions.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        "//tensorflow/python/autograph/operators",
    ],
)

py_test(
    name = "special_functions_test",
    srcs = ["special_functions_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":lang",
        "//tensorflow/python:client_testlib",
    ],
)
diff --git a/tensorflow/python/autograph/lang/directives.py b/tensorflow/python/autograph/lang/directives.py
new file mode 100644
index 0000000000..aabe5d9939
--- /dev/null
+++ b/tensorflow/python/autograph/lang/directives.py
@@ -0,0 +1,68 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Directives are special no-op functions that serve as compilation markers.
+
+They provide static information like type hints, compilation and TensorFlow
+overrides.
+
+These serve as annotations in the compiled code, allowing the user some control
+over the compilation process. They have no functional role at runtime.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
# Sentinel default for optional directive arguments, so that an explicitly
# passed None can be distinguished from "argument not provided".
UNSPECIFIED = object()
+
+
def set_element_type(entity, dtype, shape=UNSPECIFIED):
  """Indicates the entity is expected to hold items of the given type/shape.

  The staged TensorFlow ops will reflect and assert this data type. Ignored
  otherwise.

  Args:
    entity: The entity to annotate.
    dtype: TensorFlow dtype value to assert for entity.
    shape: Optional shape to assert for entity.
  """
  # No-op at runtime: the call only carries static information that the
  # AutoGraph compiler recognizes and consumes during conversion.
  del entity
  del dtype
  del shape
+
+
def set_loop_options(
    parallel_iterations=UNSPECIFIED,
    back_prop=UNSPECIFIED,
    swap_memory=UNSPECIFIED,
    maximum_iterations=UNSPECIFIED):
  """Specifies additional arguments to be passed to the enclosing while_loop.

  The parameters apply to and only to the immediately enclosing loop. It only
  has effect if the loop is staged as a TF while_loop; otherwise the parameters
  have no effect.

  Args:
    parallel_iterations: See tf.while_loop.
    back_prop: See tf.while_loop.
    swap_memory: See tf.while_loop.
    maximum_iterations: See tf.while_loop.
  """
  # No-op at runtime: this call is a static marker that the AutoGraph
  # compiler recognizes and strips during conversion.
  del parallel_iterations
  del back_prop
  del swap_memory
  del maximum_iterations
diff --git a/tensorflow/python/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
new file mode 100644
index 0000000000..e4838d1b6d
--- /dev/null
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -0,0 +1,96 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Special functions that only make sense for AutoGraph.
+
+These functions are meant to ensure feature parity between Python and AutoGraph,
+so that the exact same code works in both modes. In general, AutoGraph will
+replace these calls.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import data_structures
+
+
def tensor_list(elements,
                element_dtype=None,
                element_shape=None,
                use_tensor_array=False):
  """Creates a tensor list and populates it with the given elements.

  This function provides a more uniform access to tensor lists and tensor
  arrays, and allows optional initialization.

  Note: this function is a simplified wrapper. If you need greater control,
  it is recommended to use the underlying implementation directly.

  Args:
    elements: Iterable[tf.Tensor, ...], the elements to initially fill the list
      with
    element_dtype: Optional[tf.DType], data type for the elements in the list;
      required if the list is empty
    element_shape: Optional[tf.TensorShape], shape for the elements in the
      list; required if the list is empty
    use_tensor_array: bool, whether to use the more compatible but restrictive
      tf.TensorArray implementation

  Returns:
    Union[tf.Tensor, tf.TensorArray], the new list.

  Raises:
    ValueError: for invalid arguments
  """
  # An empty elements collection is only acceptable when both the dtype and
  # the shape of the elements are given explicitly.
  has_type_info = element_dtype and element_shape
  if not elements and not has_type_info:
    raise ValueError(
        'element_dtype and element_shape are required for empty lists')
  if use_tensor_array:
    factory = data_structures.tf_tensor_array_new
  else:
    factory = data_structures.tf_tensor_list_new
  return factory(elements, element_dtype, element_shape)
+
+
def stack(list_or_tensor, element_dtype=None, strict=True):
  """Stacks the input, if it admits the notion of stacking.

  For example, a list of tensors can be stacked into a larger tensor. This
  function is similar to tf.stack, but it accepts non-lists and lists of
  non-tensors as arguments. In the latter case, the function does nothing.

  Args:
    list_or_tensor: Any
    element_dtype: tf.DType, optional dtype for the elements in the list.
      Required if the input is stackable, and the list is untyped.
    strict: bool, if True an error is raised if the input is not stackable.
      Otherwise the function is a no-op.

  Returns:
    Any, if the input is stackable, the result will be a tf.Tensor. Otherwise,
    if strict=False, the result will be list_or_tensor.

  Raises:
    ValueError: if strict=True and the input is not stackable.
  """

  def _reject_unstackable(x):
    # Fallback for strict mode: any input that cannot be stacked is an error.
    raise ValueError('%s must be stackable when strict=True' % x)

  # In non-strict mode, unstackable inputs are passed through unchanged.
  fallback = _reject_unstackable if strict else (lambda x: x)
  opts = data_structures.ListStackOpts(
      element_dtype=element_dtype, original_call=fallback)
  return data_structures.list_stack(list_or_tensor, opts)
diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
new file mode 100644
index 0000000000..1f1cec18f7
--- /dev/null
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -0,0 +1,70 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for special_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.lang import special_functions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
class SpecialFunctionsTest(test.TestCase):
  """Tests for the special_functions module."""

  def test_tensor_list_from_elements(self):
    """tensor_list should default to the TensorList implementation."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.test_session() as sess:
      self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])

  def test_tensor_list_array_from_elements(self):
    """tensor_list(use_tensor_array=True) should produce a TensorArray."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements, use_tensor_array=True)
    sl = l.stack()
    with self.test_session() as sess:
      self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])

  def test_stack(self):
    """stack passes through unstackable input unless strict=True."""
    self.assertEqual(special_functions.stack(1, strict=False), 1)
    self.assertListEqual(
        special_functions.stack([1, 2, 3], strict=False), [1, 2, 3])
    # TODO(mdan): This should probably forward to tf.stack.
    self.assertTrue(
        isinstance(
            special_functions.stack(
                [constant_op.constant(1),
                 constant_op.constant(2)], strict=False), list))

    # Unstackable input with the default strict=True must raise.
    with self.assertRaises(ValueError):
      special_functions.stack([1, 2, 3])

    # A genuine tensor list stacks into a single tensor.
    t = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_from_tensor(
        t, element_shape=constant_op.constant([], dtype=dtypes.int32))
    self.assertTrue(
        tensor_util.is_tensor(
            special_functions.stack(l, element_dtype=dtypes.float32)))
+
+
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
new file mode 100644
index 0000000000..a116611b64
--- /dev/null
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -0,0 +1,84 @@
# Description:
#   Build rules for the AutoGraph operators package, which implements the
#   runtime counterparts of converted control flow, data structures, builtins
#   and slicing.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

# Catch-all filegroup consumed by TensorFlow's source aggregation rules.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

# Runtime operator implementations invoked by generated code (as ag__.*).
py_library(
    name = "operators",
    srcs = [
        "__init__.py",
        "control_flow.py",
        "data_structures.py",
        "py_builtins.py",
        "slices.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        "//tensorflow/python:array_ops",
        "//tensorflow/python:constant_op",
        "//tensorflow/python:control_flow_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
        "//tensorflow/python:list_ops",
        "//tensorflow/python:tensor_array_ops",
        "//tensorflow/python:tensor_util",
        "//tensorflow/python:variables",
        "//tensorflow/python/autograph/utils",
        "//tensorflow/python/data/ops:dataset_ops",
    ],
)

py_test(
    name = "data_structures_test",
    srcs = ["data_structures_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":operators",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "control_flow_test",
    srcs = ["control_flow_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":operators",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "py_builtins_test",
    srcs = ["py_builtins_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_windows"],
    deps = [
        ":operators",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "slices_test",
    srcs = ["slices_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":operators",
        "//tensorflow/python:client_testlib",
    ],
)
diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
new file mode 100644
index 0000000000..0d3b44b6c4
--- /dev/null
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -0,0 +1,55 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module implements operators that AutoGraph overloads.
+
+Note that "operator" is used loosely here, and includes control structures like
+conditionals and loops, implemented in functional form, using for example
+closures for the body.
+"""
+
+# Naming conventions:
+# * operator names match the name usually used for the respective Python
+# idiom; examples: for_stmt, list_append
+# * operator arguments match either of:
+# - the corresponding Python AST attribute (e.g. the condition of an if
+# statement is called test) if the operator represents an AST construct
+# - the names used in the Python docs, if the operator is a function (e.g.
+# list_ and x for append, see
+# https://docs.python.org/3.7/tutorial/datastructures.html)
+#
+# All operators may accept a final argument named "opts", of a type that
+# subclasses namedtuple and contains any arguments that are only required
+# for some specializations of the operator.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators.control_flow import for_stmt
+from tensorflow.python.autograph.operators.control_flow import while_stmt
+from tensorflow.python.autograph.operators.data_structures import list_append
+from tensorflow.python.autograph.operators.data_structures import list_pop
+from tensorflow.python.autograph.operators.data_structures import list_stack
+from tensorflow.python.autograph.operators.data_structures import ListPopOpts
+from tensorflow.python.autograph.operators.data_structures import ListStackOpts
+from tensorflow.python.autograph.operators.data_structures import new_list
+from tensorflow.python.autograph.operators.py_builtins import float_
+from tensorflow.python.autograph.operators.py_builtins import int_
+from tensorflow.python.autograph.operators.py_builtins import len_
+from tensorflow.python.autograph.operators.py_builtins import print_
+from tensorflow.python.autograph.operators.py_builtins import range_
+from tensorflow.python.autograph.operators.slices import get_item
+from tensorflow.python.autograph.operators.slices import GetItemOpts
+from tensorflow.python.autograph.operators.slices import set_item
diff --git a/tensorflow/python/autograph/operators/control_flow.py b/tensorflow/python/autograph/operators/control_flow.py
new file mode 100644
index 0000000000..6eedd695a7
--- /dev/null
+++ b/tensorflow/python/autograph/operators/control_flow.py
@@ -0,0 +1,227 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow statements: loops, conditionals, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+
+
+def for_stmt(iter_, extra_test, body, init_state):
+  """Functional form of a for statement.
+
+  The loop operates on a state, which includes all symbols that are
+  variant across loop iterations, excluding the iterate as well as the
+  variables local to the loop.
+
+  For example, given the loop below that calculates the geometric and
+  arithmetic means of some numbers:
+
+    geo_mean = 1
+    arith_mean = 0
+    for i in range(n):
+      a = numbers[i]
+      geo_mean *= a
+      arith_mean += a
+
+  The state is represented by the variables geo_mean and arith_mean. The
+  argument for init_state may contain the tuple (1, 0), the body will
+  include the arguments geo_mean and arith_mean and will return a tuple
+  representing the new values for geo_mean and respectively arith_mean.
+
+  Args:
+    iter_: The entity being iterated over.
+    extra_test: Callable with the state as arguments, and boolean return type.
+      An additional loop condition.
+    body: Callable with the iterate and the state as arguments, and
+      state as return type. The actual loop body.
+    init_state: Tuple containing the initial state.
+
+  Returns:
+    Tuple containing the final state.
+  """
+  # Dispatch on the iterated entity: Tensors are iterated by index inside a
+  # staged TF loop, tf.data Datasets via their iterator, and anything else
+  # with a plain Python for loop.
+  if tensor_util.is_tensor(iter_):
+    return _known_len_for_stmt(iter_, extra_test, body, init_state)
+  elif isinstance(iter_, dataset_ops.Dataset):
+    return _dataset_for_stmt(iter_, extra_test, body, init_state)
+  else:
+    return _py_for_stmt(iter_, extra_test, body, init_state)
+
+
+def _py_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that executes a Python for loop."""
+ state = init_state
+ for target in iter_:
+ if not extra_test(*state):
+ break
+ state = body(target, *state)
+
+ # TODO(mdan): Remove this special case.
+ if len(state) == 1:
+ return state[0]
+ return state
+
+
+def _known_len_for_stmt(iter_, extra_test, body, init_state):
+  """Overload of for_stmt that iterates over objects that admit a length."""
+  # Total number of iterations; also bounds the staged loop below.
+  n = py_builtins.len_(iter_)
+
+  def while_body(iterate_index, *state):
+    # The iterate is recovered by indexing, so iter_ must support __getitem__.
+    iterate = iter_[iterate_index]
+    new_state = body(iterate, *state)
+    return (iterate_index + 1,) + new_state
+
+  def while_cond(iterate_index, *state):
+    return gen_math_ops.logical_and(iterate_index < n, extra_test(*state))
+
+  # The index is threaded through the loop state as its first element.
+  results = while_stmt(
+      while_cond,
+      while_body,
+      init_state=(0,) + init_state,
+      extra_deps=(iter_,),
+      opts=dict(maximum_iterations=n))
+  # Dropping the iteration index because it's not syntactically visible.
+  results = results[1:]
+
+  # TODO(mdan): Remove this special case.
+  if len(results) == 1:
+    return results[0]
+  return results
+
+
+def _dataset_for_stmt(ds, extra_test, body, init_state):
+  """Overload of for_stmt that iterates over TF Datasets."""
+  # Because Datasets only expose get_next, in the style of Python iterators,
+  # we are forced to unpack the loop as:
+  #
+  # epoch_number, iterate = ds.get_next()
+  # while epoch_number < 1:
+  #   <body>
+  #   epoch_number, iterate = ds.get_next()
+  #
+  # Each element is tagged with its epoch number (0 for the real pass, 1 for
+  # a second replay); the first epoch-1 element signals exhaustion and stops
+  # the loop without raising OutOfRange.
+  epoch_numbers = dataset_ops.Dataset.range(2)
+  def tag_with(ds, tag):
+    return dataset_ops.Dataset.zip(
+        (dataset_ops.Dataset.from_tensors(tag).repeat(), ds))
+  ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i))
+
+  iterator = ds_with_epoch.make_initializable_iterator()
+  with ops.control_dependencies((iterator.initializer,)):
+    epoch_number, iterate = iterator.get_next()
+
+  def while_body(epoch_number, iterate, *state):
+    new_state = body(iterate, *state)
+    epoch_number, iterate = iterator.get_next()
+    return (epoch_number, iterate) + new_state
+
+  def while_cond(epoch_number, iterate, *state):
+    del iterate
+    return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state))
+
+  results = while_stmt(
+      while_cond,
+      while_body,
+      init_state=(epoch_number, iterate) + init_state,
+      extra_deps=())
+  # Dropping the epoch number and iterate because they are not syntactically
+  # visible.
+  results = results[2:]
+
+  # TODO(mdan): Remove this special case.
+  if len(results) == 1:
+    return results[0]
+  return results
+
+
+def while_stmt(test, body, init_state, extra_deps, opts=None):
+  """Functional form of a while statement.
+
+  The loop operates on a so-called state, which includes all symbols that are
+  variant across loop iterations. In what follows we refer to state as either
+  a tuple of entities that represent an actual state, or a list of arguments
+  of the corresponding types.
+
+  Args:
+    test: Callable with the state as arguments, and boolean return type.
+      The loop condition.
+    body: Callable with the state as arguments, and state as return type.
+      The actual loop body.
+    init_state: Tuple containing the initial state.
+    extra_deps: Tuple containing additional entities on which the loop may
+      depend, such as loop invariants referenced by test. Used
+      exclusively for dispatch control.
+    opts: Optional dict of extra loop parameters.
+
+  Returns:
+    Tuple containing the final state.
+  """
+  # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch.
+  # That could be something as simple as a collection of dispatch rules, with
+  # some prioritization.
+  # A single Tensor anywhere in the state or the extra dependencies forces
+  # staging a TF loop; otherwise the loop runs eagerly in Python.
+  if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
+    return _tf_while_stmt(test, body, init_state, opts)
+  else:
+    return _py_while_stmt(test, body, init_state, opts)
+
+
def _tf_while_stmt(test, body, init_state, opts):
  """Stages a TF while loop for the given condition, body and state."""
  # opts is forwarded verbatim as keyword arguments to tf.while_loop.
  extra_kwargs = opts if opts is not None else {}
  return control_flow_ops.while_loop(test, body, init_state, **extra_kwargs)
+
+
+def _py_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that executes a Python while loop."""
+ del opts
+ state = init_state
+ while test(*state):
+ state = body(*state)
+ return state
+
+
+def if_stmt(cond, body, orelse):
+  """Functional form of an if statement.
+
+  Args:
+    cond: Boolean.
+    body: Callable with no arguments, and outputs of the positive (if) branch
+      as return type.
+    orelse: Callable with no arguments, and outputs of the negative (else)
+      branch as return type.
+
+  Returns:
+    Tuple containing the statement outputs.
+  """
+  # NOTE(review): tf_if_stmt is not underscore-prefixed like the other
+  # specializations in this module; renaming would change the module surface,
+  # so the inconsistency is only noted here.
+  if tensor_util.is_tensor(cond):
+    return tf_if_stmt(cond, body, orelse)
+  else:
+    return _py_if_stmt(cond, body, orelse)
+
+
+def tf_if_stmt(cond, body, orelse):
+  """Overload of if_stmt that stages a TF cond.
+
+  Both branch callables are staged; tf.cond selects which result to produce.
+  """
+  return control_flow_ops.cond(cond, body, orelse)
+
+
+def _py_if_stmt(cond, body, orelse):
+ """Overload of if_stmt that executes a Python if statement."""
+ return body() if cond else orelse()
diff --git a/tensorflow/python/autograph/operators/control_flow_test.py b/tensorflow/python/autograph/operators/control_flow_test.py
new file mode 100644
index 0000000000..bb214b6f16
--- /dev/null
+++ b/tensorflow/python/autograph/operators/control_flow_test.py
@@ -0,0 +1,100 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import control_flow
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ForLoopTest(test.TestCase):
+  """Exercises for_stmt dispatch: Tensor, plain Python, and Dataset inputs."""
+
+  def test_tensor(self):
+    s = control_flow.for_stmt(
+        constant_op.constant([1, 2, 3, 4]),
+        extra_test=lambda s: True,
+        body=lambda i, s: (s + i,),
+        init_state=(0,))
+    with self.cached_session() as sess:
+      self.assertEqual((10,), sess.run(s))
+
+  def test_python(self):
+    s = control_flow.for_stmt(
+        range(5),
+        extra_test=lambda s: True,
+        body=lambda i, s: (s + i,),
+        init_state=(0,))
+    self.assertEqual(10, s)
+
+  def test_dataset(self):
+    # Dataset.range yields int64; cast to int32 to match the int state.
+    to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
+    s = control_flow.for_stmt(
+        dataset_ops.Dataset.range(5).map(to_int32),
+        extra_test=lambda s: True,
+        body=lambda i, s: (s + i,),
+        init_state=(0,))
+    with self.cached_session() as sess:
+      self.assertEqual((10,), sess.run(s))
+
+
+class WhileLoopTest(test.TestCase):
+  """Exercises while_stmt dispatch: staged TF loop vs. eager Python loop."""
+
+  def test_tensor(self):
+    # n being a Tensor in extra_deps forces the TF while_loop path.
+    n = constant_op.constant(5)
+    results = control_flow.while_stmt(
+        test=lambda i, s: i < n,
+        body=lambda i, s: (i + 1, s + i,),
+        init_state=(0, 0),
+        extra_deps=(n,))
+    with self.cached_session() as sess:
+      self.assertEqual((5, 10), sess.run(results))
+
+  def test_python(self):
+    n = 5
+    results = control_flow.while_stmt(
+        test=lambda i, s: i < n,
+        body=lambda i, s: (i + 1, s + i),
+        init_state=(0, 0),
+        extra_deps=(n,))
+    self.assertEqual((5, 10), results)
+
+
+class IfStmtTest(test.TestCase):
+  """Exercises if_stmt dispatch: staged tf.cond vs. eager Python branch."""
+
+  def test_tensor(self):
+    def test_if_stmt(cond):
+      return control_flow.if_stmt(
+          cond=cond,
+          body=lambda: 1,
+          orelse=lambda: -1)
+
+    with self.cached_session() as sess:
+      self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True))))
+      self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False))))
+
+  def test_python(self):
+    self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1))
+    self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
new file mode 100644
index 0000000000..cc0a3c3544
--- /dev/null
+++ b/tensorflow/python/autograph/operators/data_structures.py
@@ -0,0 +1,338 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators specific to data structures: list append, subscripts, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+# TODO(mdan): Once control flow supports objects, repackage as a class.
+
+
def new_list(iterable=None):
  """The list constructor.

  Args:
    iterable: Optional elements to fill the list with.

  Returns:
    A list-like object. The exact return value depends on the initial
    elements.
  """
  elements = tuple(iterable) if iterable else ()

  if elements:
    # A nonempty initializer is assumed to be an ordinary Python lvalue list.
    return _py_list_new(elements)
  # Empty lists default to the staged tensor-list representation.
  return tf_tensor_list_new(elements)
+
+
+def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
+  """Overload of new_list that stages a TensorArray creation."""
+  elements = tuple(ops.convert_to_tensor(el) for el in elements)
+
+  # Infer the element dtype, and check it against element_dtype if given.
+  # TensorArrays require homogeneous dtypes, unlike tensor lists.
+  all_dtypes = set(el.dtype for el in elements)
+  if len(all_dtypes) == 1:
+    inferred_dtype, = tuple(all_dtypes)
+    if element_dtype is not None and element_dtype != inferred_dtype:
+      raise ValueError(
+          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
+              element_dtype, elements, inferred_dtype))
+  elif len(all_dtypes) > 1:
+    raise ValueError(
+        'TensorArray requires all elements to have the same dtype:'
+        ' {}'.format(elements))
+  else:
+    # Empty input: the dtype cannot be inferred and must be specified.
+    if element_dtype is None:
+      raise ValueError('dtype is required to create an empty TensorArray')
+
+  # Same for the element shape; homogeneous shapes are required here too.
+  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
+  if len(all_shapes) == 1:
+    inferred_shape, = tuple(all_shapes)
+    if element_shape is not None and element_shape != inferred_shape:
+      raise ValueError(
+          'incompatible shape; specified: {}, inferred from {}: {}'.format(
+              element_shape, elements, inferred_shape))
+  elif len(all_shapes) > 1:
+    raise ValueError(
+        'TensorArray requires all elements to have the same shape:'
+        ' {}'.format(elements))
+    # TODO(mdan): We may want to allow different shapes with infer_shape=False.
+  else:
+    inferred_shape = None
+
+  if element_dtype is None:
+    element_dtype = inferred_dtype
+  if element_shape is None:
+    element_shape = inferred_shape
+
+  # NOTE(review): element_shape may have just been set from inferred_shape
+  # above, in which case infer_shape evaluates to False here — presumably
+  # intentional (the explicit shape is used instead); confirm.
+  l = tensor_array_ops.TensorArray(
+      dtype=element_dtype,
+      size=len(elements),
+      dynamic_size=True,
+      infer_shape=(element_shape is None),
+      element_shape=element_shape)
+  for i, el in enumerate(elements):
+    l = l.write(i, el)
+  return l
+
+
+def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
+  """Overload of new_list that stages a Tensor list creation."""
+  elements = tuple(ops.convert_to_tensor(el) for el in elements)
+
+  # Infer the element dtype when homogeneous; heterogeneous lists fall back
+  # to dtype=variant, but then an explicit element_dtype is contradictory.
+  all_dtypes = set(el.dtype for el in elements)
+  if len(all_dtypes) == 1:
+    inferred_dtype = tuple(all_dtypes)[0]
+    if element_dtype is not None and element_dtype != inferred_dtype:
+      raise ValueError(
+          'incompatible dtype; specified: {}, inferred from {}: {}'.format(
+              element_dtype, elements, inferred_dtype))
+  else:
+    # Heterogeneous lists are ok.
+    if element_dtype is not None:
+      raise ValueError(
+          'specified dtype {} is inconsistent with that of elements {}'.format(
+              element_dtype, elements))
+    inferred_dtype = dtypes.variant
+
+  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
+  if len(all_shapes) == 1:
+    # NOTE(review): inferred_shape is a Tensor (array_ops.shape), so the
+    # `element_shape != inferred_shape` comparison below may not behave like
+    # a value comparison — confirm intended semantics.
+    inferred_shape = array_ops.shape(elements[0])
+    if element_shape is not None and element_shape != inferred_shape:
+      raise ValueError(
+          'incompatible shape; specified: {}, inferred from {}: {}'.format(
+              element_shape, elements, inferred_shape))
+  else:
+    # Heterogeneous lists are ok.
+    if element_shape is not None:
+      raise ValueError(
+          'specified shape {} is inconsistent with that of elements {}'.format(
+              element_shape, elements))
+    inferred_shape = constant_op.constant(-1)  # unknown shape, by convention
+
+  if element_dtype is None:
+    element_dtype = inferred_dtype
+  if element_shape is None:
+    element_shape = inferred_shape
+
+  # Build the staged list by pushing each element in order.
+  l = list_ops.empty_tensor_list(
+      element_shape=element_shape, element_dtype=element_dtype)
+  for el in elements:
+    l = list_ops.tensor_list_push_back(l, el)
+  return l
+
+
+def _py_list_new(elements):
+ """Overload of new_list that creates a Python list."""
+ return list(elements)
+
+
+def list_append(list_, x):
+  """The list append function.
+
+  Note: it is unspecified whether list_ will be mutated or not. If list_ is
+  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
+  list, it will be. In general, if the list is mutated then the return value
+  should point to the original entity.
+
+  Args:
+    list_: An entity that supports append semantics.
+    x: The element to append.
+
+  Returns:
+    Same as list_, after the append was performed.
+
+  Raises:
+    ValueError: if list_ is not of a known list-like type.
+  """
+  # Dispatch: TensorArray write, variant tensor-list push, or Python append.
+  if isinstance(list_, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_append(list_, x)
+  elif tensor_util.is_tensor(list_):
+    if list_.dtype == dtypes.variant:
+      return _tf_tensor_list_append(list_, x)
+    else:
+      raise ValueError(
+          'tensor lists are expected to be Tensors with dtype=tf.variant,'
+          ' instead found %s' % list_)
+  else:
+    return _py_list_append(list_, x)
+
+
+def _tf_tensor_list_append(list_, x):
+  """Overload of list_append that stages a Tensor list write."""
+  def empty_list_of_elements_like_x():
+    tensor_x = ops.convert_to_tensor(x)
+    return list_ops.empty_tensor_list(
+        element_shape=array_ops.shape(tensor_x),
+        element_dtype=tensor_x.dtype)
+
+  # If the list is still empty, replace it with a fresh list whose element
+  # shape and dtype match x; otherwise keep the existing list. The push is
+  # then staged against the selected list.
+  list_ = control_flow_ops.cond(
+      list_ops.tensor_list_length(list_) > 0,
+      lambda: list_,
+      empty_list_of_elements_like_x,
+  )
+  return list_ops.tensor_list_push_back(list_, x)
+
+
+def _tf_tensorarray_append(list_, x):
+  """Overload of list_append that stages a TensorArray write.
+
+  Writes at index size(), i.e. one past the current last element.
+  """
+  return list_.write(list_.size(), x)
+
+
+def _py_list_append(list_, x):
+ """Overload of list_append that executes a Python list append."""
+ # Revert to the original call.
+ list_.append(x)
+ return list_
+
+
+class ListPopOpts(
+    collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))):
+  """Element dtype/shape annotations consumed by list_pop specializations."""
+  pass
+
+
+def list_pop(list_, i, opts):
+  """The list pop function.
+
+  Note: it is unspecified whether list_ will be mutated or not. If list_ is
+  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
+  list, it will be. In general, if the list is mutated then the return value
+  should point to the original entity.
+
+  Args:
+    list_: An entity that supports pop semantics.
+    i: Optional index to pop from. May be None.
+    opts: A ListPopOpts.
+
+  Returns:
+    Tuple (x, out_list_):
+      out_list_: same as list_, after the removal was performed.
+      x: the removed element value.
+
+  Raises:
+    ValueError: if list_ is not of a known list-like type or the operation is
+    not supported for that type.
+  """
+  assert isinstance(opts, ListPopOpts)
+
+  # Dispatch: TensorArrays cannot pop; variant tensor lists pop from the
+  # back; plain Python lists use list.pop.
+  if isinstance(list_, tensor_array_ops.TensorArray):
+    raise ValueError('TensorArray does not support item removal')
+  elif tensor_util.is_tensor(list_):
+    if list_.dtype == dtypes.variant:
+      return _tf_tensor_list_pop(list_, i, opts)
+    else:
+      raise ValueError(
+          'tensor lists are expected to be Tensors with dtype=tf.variant,'
+          ' instead found %s' % list_)
+  else:
+    return _py_list_pop(list_, i)
+
+
+def _tf_tensor_list_pop(list_, i, opts):
+  """Overload of list_pop that stages a Tensor list pop."""
+  if i is not None:
+    raise NotImplementedError('tensor lists only support removing from the end')
+
+  # The element dtype and shape cannot be recovered from the variant tensor,
+  # so both must be supplied via opts.
+  if opts.element_dtype is None:
+    raise ValueError('cannot pop from a list without knowing its element '
+                     'type; use set_element_type to annotate it')
+  if opts.element_shape is None:
+    raise ValueError('cannot pop from a list without knowing its element '
+                     'shape; use set_element_type to annotate it')
+  list_out, x = list_ops.tensor_list_pop_back(
+      list_, element_dtype=opts.element_dtype)
+  # Restore the static shape lost by the staged pop.
+  x.set_shape(opts.element_shape)
+  return list_out, x
+
+
+def _py_list_pop(list_, i):
+ """Overload of list_pop that executes a Python list append."""
+ if i is None:
+ x = list_.pop()
+ else:
+ x = list_.pop(i)
+ return list_, x
+
+
+# TODO(mdan): Look into reducing duplication between all these containers.
+class ListStackOpts(
+ collections.namedtuple('ListStackOpts',
+ ('element_dtype', 'original_call'))):
+ pass
+
+
+def list_stack(list_, opts):
+  """The list stack function.
+
+  This does not have a direct correspondent in Python. The closest idiom to
+  this is tf.stack or np.stack. It's different from those in the sense that it
+  accepts a Tensor list, rather than a list of tensors. It can also accept
+  TensorArray. When the target is anything else, the dispatcher will rely on
+  opts.original_call for fallback.
+
+  Args:
+    list_: An entity that supports append semantics.
+    opts: A ListStackOpts object.
+
+  Returns:
+    The output of the stack operation, typically a Tensor.
+  """
+  assert isinstance(opts, ListStackOpts)
+
+  # Dispatch: TensorArray stack, variant tensor-list stack, or the original
+  # (unstaged) call for anything else.
+  if isinstance(list_, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_stack(list_)
+  elif tensor_util.is_tensor(list_):
+    if list_.dtype == dtypes.variant:
+      return _tf_tensor_list_stack(list_, opts)
+    else:
+      # No-op for primitive Tensor arguments.
+      return list_
+  else:
+    return _py_list_stack(list_, opts)
+
+
+def _tf_tensorarray_stack(list_):
+  """Overload of list_stack that stages a TensorArray stack."""
+  return list_.stack()
+
+
+def _tf_tensor_list_stack(list_, opts):
+  """Overload of list_stack that stages a Tensor list stack.
+
+  The element dtype cannot be recovered from the variant tensor, so it must
+  be supplied via opts.
+  """
+  if opts.element_dtype is None:
+    raise ValueError('cannot stack a list without knowing its element type;'
+                     ' use set_element_type to annotate it')
+  return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)
+
+
+def _py_list_stack(list_, opts):
+ """Overload of list_stack that executes a Python list append."""
+ # Revert to the original call.
+ return opts.original_call(list_)
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
new file mode 100644
index 0000000000..8532dbe466
--- /dev/null
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -0,0 +1,158 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for data_structures module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class ListTest(test.TestCase):
+  """Tests list operators across tensor lists, TensorArrays and Python."""
+
+  def test_new_list_empty(self):
+    l = data_structures.new_list()
+    # Can't evaluate an empty list.
+    # TODO(mdan): sess.run should allow tf.variant maybe?
+    self.assertTrue(isinstance(l, ops.Tensor))
+
+  def test_new_list_tensor(self):
+    l = data_structures.new_list([3, 4, 5])
+    self.assertAllEqual(l, [3, 4, 5])
+
+  def test_tf_tensor_list_new(self):
+    l = data_structures.tf_tensor_list_new([3, 4, 5])
+    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+    with self.cached_session() as sess:
+      self.assertAllEqual(sess.run(t), [3, 4, 5])
+
+  def test_tf_tensor_list_new_illegal_input(self):
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_list_new([3, 4.0])
+    # TODO(mdan): It might make more sense to type cast in this case.
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_list_new([3, 4], element_dtype=dtypes.float32)
+    # Tensor lists do support heterogeneous lists.
+    self.assertIsNot(data_structures.tf_tensor_list_new([3, [4, 5]]), None)
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_list_new([], element_shape=(2,))
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
+
+  def test_tf_tensor_array_new(self):
+    l = data_structures.tf_tensor_array_new([3, 4, 5])
+    t = l.stack()
+    with self.cached_session() as sess:
+      self.assertAllEqual(sess.run(t), [3, 4, 5])
+
+  def test_tf_tensor_array_new_illegal_input(self):
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_array_new([3, 4.0])
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_array_new([3, 4], element_dtype=dtypes.float32)
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_array_new([3, [4, 5]])
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_array_new([3, 4], element_shape=(2,))
+    with self.assertRaises(ValueError):
+      data_structures.tf_tensor_array_new([], element_shape=(2,))
+    # TAs can infer the shape.
+    self.assertIsNot(
+        data_structures.tf_tensor_array_new([], element_dtype=dtypes.float32),
+        None)
+
+  def test_append_tensor_list(self):
+    l = data_structures.new_list()
+    x = constant_op.constant([1, 2, 3])
+    l = data_structures.list_append(l, x)
+
+    t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
+    with self.cached_session() as sess:
+      self.assertAllEqual(sess.run(t), [[1, 2, 3]])
+
+  def test_append_tensorarray(self):
+    l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
+    l1 = data_structures.list_append(l, 1)
+    l2 = data_structures.list_append(l1, 2)
+    with self.cached_session() as sess:
+      self.assertAllEqual(sess.run(l1.stack()), [1])
+      self.assertAllEqual(sess.run(l2.stack()), [1, 2])
+
+  def test_append_python(self):
+    l = []
+    self.assertAllEqual(data_structures.list_append(l, 1), [1])
+    self.assertAllEqual(data_structures.list_append(l, 2), [1, 2])
+
+  def test_pop_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+    opts = data_structures.ListPopOpts(
+        element_dtype=initial_list.dtype,
+        element_shape=(2,))
+
+    # Tensor lists only support popping from the end (i must be None).
+    with self.assertRaises(NotImplementedError):
+      data_structures.list_pop(l, 0, opts)
+
+    with self.cached_session() as sess:
+      l, x = data_structures.list_pop(l, None, opts)
+      self.assertAllEqual(sess.run(x), [3, 4])
+
+      t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+      self.assertAllEqual(sess.run(t), [[1, 2]])
+
+  def test_pop_python(self):
+    l = [1, 2, 3]
+    opts = data_structures.ListPopOpts(element_dtype=None, element_shape=())
+    self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3))
+    self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2))
+
+  def test_stack_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+    opts = data_structures.ListStackOpts(
+        element_dtype=initial_list.dtype, original_call=None)
+
+    with self.cached_session() as sess:
+      t = data_structures.list_stack(l, opts)
+      self.assertAllEqual(sess.run(t), sess.run(initial_list))
+
+  def test_stack_fallback(self):
+
+    def dummy_function(l):
+      # Lazy person's mock: just transform the argument in a way in which we
+      # can check that this function was indeed called.
+      return [x * 2 for x in l]
+
+    opts = data_structures.ListStackOpts(
+        element_dtype=None, original_call=dummy_function)
+
+    self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/operators/dispatch_context.py b/tensorflow/python/autograph/operators/dispatch_context.py
new file mode 100644
index 0000000000..097002465b
--- /dev/null
+++ b/tensorflow/python/autograph/operators/dispatch_context.py
@@ -0,0 +1,41 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Structures that allow uniform control over the dispatch process."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+# TODO(mdan): This is where macro override controls fit.
+
+
+class DispatchContext(collections.namedtuple(
+    'DispatchContext',
+    ('options',))):
+  """Allows passing additional parameters to the specific implementations.
+
+  Attributes:
+    options: Optional dict of extra arguments that may be required by specific
+      implementations.
+  """
+
+  def option(self, name):
+    # Raises KeyError if the named option was not provided.
+    return self.options[name]
+
+
+# Default context for dispatches that require no extra options.
+NO_CTX = DispatchContext(options={})
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
new file mode 100644
index 0000000000..1d37ae72d3
--- /dev/null
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -0,0 +1,225 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators corresponding to Python builtin functions.
+
+List of built-in functions: https://docs.python.org/3/library/functions.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
+
+
+# Sentinel signaling that an optional argument was not supplied by the caller
+# (see int_, range_, _tf_py_func_print).
+UNDEFINED = object()
+
+
+def overload_of(f):
+  """Returns the AutoGraph overload of builtin `f`, or `f` itself if unsupported."""
+  # NOTE(review): the map's name is misspelled ("FUINCTIONS"); it matches the
+  # definition at the bottom of this module.
+  if f in SUPPORTED_BUILTINS:
+    return BUILTIN_FUINCTIONS_MAP[f.__name__]
+  return f
+
+
+def abs_(x):
+  """Overload of the builtin abs: dispatches on Tensor vs. Python value."""
+  if tensor_util.is_tensor(x):
+    return _tf_abs(x)
+  return _py_abs(x)
+
+
+def _tf_abs(x):
+  # Element-wise absolute value for Tensors.
+  return math_ops.abs(x)
+
+
+def _py_abs(x):
+  # Plain Python fallback.
+  return abs(x)
+
+
+def float_(x=0):
+  """Overload of the builtin float: dispatches on Tensor vs. Python value."""
+  if tensor_util.is_tensor(x):
+    return _tf_float(x)
+  return _py_float(x)
+
+
+def _tf_float(x):
+  # TODO(mdan): We shouldn't assume float32.
+  # String tensors are parsed; all other dtypes are cast.
+  if x.dtype == dtypes.string:
+    return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
+  return math_ops.cast(x, dtype=dtypes.float32)
+
+
+def _py_float(x):
+  return float(x)
+
+
+def int_(x=0, base=UNDEFINED):
+  """Overload of the builtin int: dispatches on Tensor vs. Python value."""
+  if tensor_util.is_tensor(x):
+    return _tf_int(x, base)
+  return _py_int(x, base)
+
+
+def _tf_int(x, base):
+  # Only the default base (10) is supported for Tensor arguments.
+  if base not in (10, UNDEFINED):
+    raise NotImplementedError('base {} not supported for int'.format(base))
+
+  # TODO(mdan): We shouldn't assume int32.
+  if x.dtype == dtypes.string:
+    return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
+  return math_ops.cast(x, dtype=dtypes.int32)
+
+
+def _py_int(x, base):
+  # The builtin int only accepts a base together with a string argument.
+  if base is UNDEFINED:
+    return int(x)
+  return int(x, base)
+
+
+def len_(s):
+  """Overload of the builtin len: supports TensorArrays, Tensor lists, Tensors."""
+  if tensors.is_tensor_array(s):
+    return _tf_tensor_array_len(s)
+  elif tensors.is_tensor_list(s):
+    return _tf_tensor_list_len(s)
+  elif tensor_util.is_tensor(s):
+    return _tf_tensor_len(s)
+  return _py_len(s)
+
+
+def _tf_tensor_array_len(s):
+  # TensorArray exposes its dynamic size directly.
+  return s.size()
+
+
+def _tf_tensor_list_len(s):
+  return list_ops.tensor_list_length(s)
+
+
+def _tf_tensor_len(s):
+  """Overload of len_ for Tensor arguments."""
+  # Statically shaped tensors: length is known ahead of time.
+  if s.shape.ndims and s.shape[0].value is not None:
+    return s.shape[0].value
+
+  # Static shape of unknown dimensions: use dynamic shape but statically
+  # check that it's a scalar.
+  shape = array_ops.shape(s)
+
+  assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+  if shape.shape[0] == 0:
+    raise ValueError(
+        'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
+  if shape.shape[0].value is not None:
+    return array_ops.shape(s)[0]
+
+  # Fully dynamic shape: use ops.
+  rank = array_ops.rank(s)
+
+  def raise_zero_rank_error():
+    # Assert(False, ...) fails at graph execution time with the given message.
+    msg = gen_string_ops.string_join(
+        ['len requires non-zero rank, got ',
+         gen_string_ops.as_string(rank)])
+    with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
+      return constant_op.constant(0, dtype=dtypes.int32)
+
+  return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
+                               raise_zero_rank_error)
+
+
+def _py_len(s):
+  return len(s)
+
+
+def print_(*objects, **kwargs):
+  """Overload of the builtin print, implemented as a py_func."""
+  # Note: Python 2.6 doesn't support explicit keywords after starargs.
+  unknown_kwargs = tuple(
+      set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
+  if unknown_kwargs:
+    raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))
+
+  # TODO(mdan): use logging_ops.Print when py_func is not supported.
+  return _tf_py_func_print(objects, kwargs)
+
+
+def _tf_py_func_print(objects, kwargs):
+  """Overload of print_ as a py_func implementation."""
+  # Drop any kwargs the caller left as UNDEFINED so six.print_ uses its own
+  # defaults for them.
+  override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
+  if 'flush' not in override_kwargs:
+    # Defaulting to flushing the console in graph mode, which helps reduce
+    # garbled output in IPython.
+    override_kwargs['flush'] = True
+
+  def print_wrapper(*vals):
+    if six.PY3:
+      # TensorFlow doesn't seem to generate Unicode when passing strings to
+      # py_func. This causes the print to add a "b'" wrapper to the output,
+      # which is probably never what you want.
+      vals = tuple(
+          v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
+    six.print_(*vals, **override_kwargs)
+
+  # use_dummy_return gives the op an output, presumably so callers can run it
+  # in a session — TODO confirm against py_func.wrap_py_func.
+  return py_func.wrap_py_func(
+      print_wrapper, None, objects, use_dummy_return=True)
+
+
+def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
+  """Overload of the builtin range: uses math_ops.range for Tensor arguments."""
+  if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
+    return _tf_range(start_or_stop, stop, step)
+  return _py_range(start_or_stop, stop, step)
+
+
+def _tf_range(start_or_stop, stop, step):
+  # TODO(mdan): We should optimize this when a full tensor is not required.
+  # Mirrors the builtin's arity: range(stop), range(start, stop),
+  # range(start, stop, step).
+  if step is not UNDEFINED:
+    return math_ops.range(start_or_stop, stop, step)
+  if stop is not UNDEFINED:
+    return math_ops.range(start_or_stop, stop)
+  return math_ops.range(start_or_stop)
+
+
+def _py_range(start_or_stop, stop, step):
+  if step is not UNDEFINED:
+    return range(start_or_stop, stop, step)
+  if stop is not UNDEFINED:
+    return range(start_or_stop, stop)
+  return range(start_or_stop)
+
+
+SUPPORTED_BUILTINS = set((abs, float, int, len, print, range))
+
+if six.PY2:
+ SUPPORTED_BUILTINS.add(xrange)
+
+BUILTIN_FUINCTIONS_MAP = {
+ 'abs': abs_,
+ 'float': float_,
+ 'int': int_,
+ 'len': len_,
+ 'print': print_,
+ 'range': range_,
+ 'xrange': range_,
+}
diff --git a/tensorflow/python/autograph/operators/py_builtins_test.py b/tensorflow/python/autograph/operators/py_builtins_test.py
new file mode 100644
index 0000000000..a021263ffa
--- /dev/null
+++ b/tensorflow/python/autograph/operators/py_builtins_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.autograph.operators import py_builtins
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class PyBuiltinsTest(test.TestCase):
+  """Tests for the Tensor-aware overloads of the Python builtins."""
+
+  def test_abs(self):
+    self.assertEqual(py_builtins.abs_(-1), 1)
+    with self.test_session() as sess:
+      t = py_builtins.abs_(constant_op.constant(-1))
+      self.assertEqual(sess.run(t), 1)
+      t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
+      self.assertAllEqual(sess.run(t), [1, 2, 3])
+
+  def test_float(self):
+    self.assertEqual(py_builtins.float_(10), 10.0)
+    self.assertEqual(py_builtins.float_('10.0'), 10.0)
+    with self.test_session() as sess:
+      t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
+      self.assertEqual(sess.run(t), 1.0)
+      st = py_builtins.float_(constant_op.constant('1.0'))
+      self.assertEqual(sess.run(st), 1.0)
+
+  def test_int(self):
+    self.assertEqual(py_builtins.int_(10.0), 10)
+    # '11' in base 2 is 3.
+    self.assertEqual(py_builtins.int_('11', 2), 3)
+    with self.test_session() as sess:
+      t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
+      self.assertEqual(sess.run(t), 1)
+      st = py_builtins.int_(constant_op.constant('1'))
+      self.assertEqual(sess.run(st), 1)
+      st = py_builtins.int_(constant_op.constant('1'), 10)
+      self.assertEqual(sess.run(st), 1)
+
+  def test_int_unsupported_base(self):
+    # Non-decimal bases are rejected for Tensor arguments.
+    t = constant_op.constant(1, dtype=dtypes.float64)
+    with self.assertRaises(NotImplementedError):
+      py_builtins.int_(t, 2)
+
+  def test_len(self):
+    self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
+    with self.test_session() as sess:
+      # Statically shaped Tensor: len_ returns a Python int directly.
+      t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
+      self.assertEqual(t, 3)
+      ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
+      self.assertEqual(sess.run(ta), 5)
+      tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
+      self.assertEqual(sess.run(tl), 3)
+
+  def test_len_scalar(self):
+    with self.assertRaises(ValueError):
+      py_builtins.len_(constant_op.constant(1))
+
+  def test_len_dynamic_shape(self):
+    with self.test_session() as sess:
+      p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+      t = py_builtins.len_(p)
+      self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
+
+      # Feeding a scalar trips the graph-time rank assertion.
+      with self.assertRaises(errors_impl.InvalidArgumentError):
+        t = py_builtins.len_(p)
+        sess.run(t, {p: 1})
+
+  def test_print_tensors(self):
+    # Capture stdout to verify the py_func-based print output.
+    try:
+      out_capturer = six.StringIO()
+      sys.stdout = out_capturer
+      with self.test_session() as sess:
+        sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
+        self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+    finally:
+      sys.stdout = sys.__stdout__
+
+  def test_print_complex(self):
+    try:
+      out_capturer = six.StringIO()
+      sys.stdout = out_capturer
+      with self.test_session() as sess:
+        sess.run(
+            py_builtins.print_(constant_op.constant('test message'), [1, 2]))
+        self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+    finally:
+      sys.stdout = sys.__stdout__
+
+  def test_range(self):
+    self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
+    self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
+    self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
+
+  def test_range_tensor(self):
+    # A single Tensor argument in any position triggers the TF overload.
+    with self.test_session() as sess:
+      r = py_builtins.range_(constant_op.constant(3))
+      self.assertAllEqual(sess.run(r), [0, 1, 2])
+      r = py_builtins.range_(1, constant_op.constant(3))
+      self.assertAllEqual(sess.run(r), [1, 2])
+      r = py_builtins.range_(2, 0, constant_op.constant(-1))
+      self.assertAllEqual(sess.run(r), [2, 1])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/operators/slices.py b/tensorflow/python/autograph/operators/slices.py
new file mode 100644
index 0000000000..2b7f5ad922
--- /dev/null
+++ b/tensorflow/python/autograph/operators/slices.py
@@ -0,0 +1,142 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators specific to slicing operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+# TODO(mdan): Support extended slices.
+
+
+class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
+  """Options for get_item; element_dtype is required to read Tensor lists."""
+  pass
+
+
+def get_item(target, i, opts):
+  """The slice read operator (i.e. __getitem__).
+
+  Note: it is unspecified whether target will be mutated or not. In general,
+  if target is mutable (like Python lists), it will be mutated.
+
+  Args:
+    target: An entity that supports getitem semantics.
+    i: Index to read from.
+    opts: A GetItemOpts object.
+
+  Returns:
+    The read element.
+
+  Raises:
+    ValueError: if target is not of a supported type.
+  """
+  assert isinstance(opts, GetItemOpts)
+
+  if isinstance(target, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_get_item(target, i)
+  elif tensor_util.is_tensor(target):
+    # Tensor lists are represented as variant-dtype Tensors; scalar string
+    # Tensors support per-character indexing.
+    if target.dtype == dtypes.variant:
+      return _tf_tensor_list_get_item(target, i, opts)
+    elif target.dtype == dtypes.string and target.shape.ndims == 0:
+      return _tf_tensor_string_get_item(target, i)
+    else:
+      return _tf_tensor_get_item(target, i)
+  else:
+    return _py_get_item(target, i)
+
+
+def _tf_tensorarray_get_item(target, i):
+  """Overload of get_item that stages a TensorArray read."""
+  return target.read(i)
+
+
+def _tf_tensor_list_get_item(target, i, opts):
+  """Overload of get_item that stages a Tensor list read."""
+  if opts.element_dtype is None:
+    raise ValueError('cannot retrieve from a list without knowing its '
+                     'element type; use set_element_type to annotate it')
+  x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype)
+  return x
+
+
+def _tf_tensor_get_item(target, i):
+  """Overload of get_item that stages a Tensor (not Tensor list) read."""
+  return target[i]
+
+
+def _tf_tensor_string_get_item(target, i):
+  """Overload of get_item that reads a single character of a string Tensor."""
+  x = gen_string_ops.substr(target, i, 1)
+  return x
+
+
+def _py_get_item(target, i):
+  """Overload of get_item that executes a plain Python indexed read."""
+  return target[i]
+
+
+def set_item(target, i, x):
+  """The slice write operator (i.e. __setitem__).
+
+  Note: it is unspecified whether target will be mutated or not. In general,
+  if target is mutable (like Python lists), it will be mutated.
+
+  Args:
+    target: An entity that supports setitem semantics.
+    i: Index to modify.
+    x: The new element value.
+
+  Returns:
+    Same as target, after the update was performed.
+
+  Raises:
+    ValueError: if target is not of a supported type.
+  """
+  if isinstance(target, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_set_item(target, i, x)
+  elif tensor_util.is_tensor(target):
+    # Unlike get_item, plain (non-variant) Tensors do not support item
+    # assignment; only Tensor lists do.
+    if target.dtype == dtypes.variant:
+      return _tf_tensor_list_set_item(target, i, x)
+    else:
+      raise ValueError(
+          'tensor lists are expected to be Tensors with dtype=tf.variant,'
+          ' instead found %s' % target)
+  else:
+    return _py_set_item(target, i, x)
+
+
+def _tf_tensorarray_set_item(target, i, x):
+  """Overload of set_item that stages a TensorArray write."""
+  return target.write(i, x)
+
+
+def _tf_tensor_list_set_item(target, i, x):
+  """Overload of set_item that stages a Tensor list update."""
+  return list_ops.tensor_list_set_item(target, i, x)
+
+
+def _py_set_item(target, i, x):
+  """Overload of set_item that executes a Python list modification."""
+  # Mutates target in place, then returns it for consistency with the TF paths.
+  target[i] = x
+  return target
diff --git a/tensorflow/python/autograph/operators/slices_test.py b/tensorflow/python/autograph/operators/slices_test.py
new file mode 100644
index 0000000000..d8b8418750
--- /dev/null
+++ b/tensorflow/python/autograph/operators/slices_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import slices
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SlicesTest(test.TestCase):
+  """Tests for the get_item/set_item slice operators."""
+
+  def test_set_item_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+    l = slices.set_item(l, 0, [5, 6])
+
+    with self.cached_session() as sess:
+      t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+      self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
+
+  def test_get_item_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+    t = slices.get_item(
+        l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
+
+    with self.cached_session() as sess:
+      self.assertAllEqual(sess.run(t), [3, 4])
+
+  def test_get_item_tensor_string(self):
+    # Scalar strings index by character; vector strings index by element.
+    initial_str = constant_op.constant('abcd')
+    t = slices.get_item(initial_str, 1,
+                        slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(t), b'b')
+
+    initial_list_str = constant_op.constant(['abcd', 'bcde'])
+    t = slices.get_item(initial_list_str, 1,
+                        slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+    with self.test_session() as sess:
+      self.assertEqual(sess.run(t), b'bcde')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/BUILD b/tensorflow/python/autograph/pyct/BUILD
new file mode 100644
index 0000000000..ddadc6b96e
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/BUILD
@@ -0,0 +1,163 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "pyct",
+ srcs = [
+ "__init__.py",
+ "anno.py",
+ "ast_util.py",
+ "cfg.py",
+ "compiler.py",
+ "inspect_utils.py",
+ "origin_info.py",
+ "parser.py",
+ "pretty_printer.py",
+ "qual_names.py",
+ "templates.py",
+ "transformer.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "@astor_archive//:astor",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ "@termcolor_archive//:termcolor",
+ # TODO(mdan): Remove this dependency.
+ "//tensorflow/python:util",
+ ],
+)
+
+py_test(
+ name = "anno_test",
+ srcs = ["anno_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "ast_util_test",
+ srcs = ["ast_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "cfg_test",
+ srcs = ["cfg_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "compiler_test",
+ srcs = ["compiler_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "inspect_utils_test",
+ srcs = ["inspect_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "origin_info_test",
+ srcs = ["origin_info_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "parser_test",
+ srcs = ["parser_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "pretty_printer_test",
+ srcs = ["pretty_printer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "qual_names_test",
+ srcs = ["qual_names_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "templates_test",
+ srcs = ["templates_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "transformer_test",
+ srcs = ["transformer_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ ":pyct",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
diff --git a/tensorflow/python/autograph/pyct/__init__.py b/tensorflow/python/autograph/pyct/__init__.py
new file mode 100644
index 0000000000..d787e56bbe
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python source code transformation library."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
new file mode 100644
index 0000000000..1a52110ef3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST node annotation support.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import enum
+
+# pylint:disable=g-bad-import-order
+import gast
+# pylint:enable=g-bad-import-order
+
+
+# TODO(mdan): Shorten the names.
+# These names are heavily used, so shorter keys would keep annotation call
+# sites (e.g. anno.getanno(node, KEY)) readable.
+# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.
+
+
+class NoValue(enum.Enum):
+  """Base class for annotation keys whose string values serve as documentation."""
+
+  def __repr__(self):
+    return self.name
+
+
+class Basic(NoValue):
+  """Container for basic annotation keys.
+
+  The enum values are used strictly for documentation purposes.
+  """
+
+  QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
+  SKIP_PROCESSING = (
+      'This node should be preserved as is and not processed any further.')
+  INDENT_BLOCK_REMAINDER = (
+      'When a node is annotated with this, the remainder of the block should'
+      ' be indented below it. The annotation contains a tuple'
+      ' (new_body, name_map), where `new_body` is the new indented block and'
+      ' `name_map` allows renaming symbols.')
+  # NOTE: referenced file is origin_info.py in this package.
+  ORIGIN = ('Information about the source code that converted code originated'
+            ' from. See origin_information.py.')
+
+
+class Static(NoValue):
+  """Container for static analysis annotation keys.
+
+  The enum values are used strictly for documentation purposes.
+  """
+
+  # Deprecated - use reaching definitions instead.
+  # (The deprecation note applies to the two symbol flags below.)
+  # Symbols
+  # These flags are boolean.
+  IS_LOCAL = 'Symbol is local to the function scope being analyzed.'
+  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'
+
+  # Scopes
+  # Scopes are represented by objects of type activity.Scope.
+  SCOPE = 'The scope for the annotated node. See activity.py.'
+  # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
+  ARGS_SCOPE = 'The scope for the argument list of a function call.'
+  COND_SCOPE = 'The scope for the test node of a conditional statement.'
+  BODY_SCOPE = (
+      'The scope for the main body of a statement (True branch for if '
+      'statements, main body for loops).')
+  ORELSE_SCOPE = (
+      'The scope for the orelse body of a statement (False branch for if '
+      'statements, orelse body for loops).')
+
+  # Static analysis annotations.
+  DEFINITIONS = (
+      'Reaching definition information. See reaching_definitions.py.')
+  ORIG_DEFINITIONS = (
+      'The value of DEFINITIONS that applied to the original code before any'
+      ' conversion.')
+  DEFINED_VARS_IN = (
+      'Symbols defined when entering the node. See reaching_definitions.py.')
+  LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
+
+
+# Sentinel default for getanno, so that None remains usable as a real default.
+FAIL = object()
+
+
+def keys(node, field_name='___pyct_anno'):
+  """Returns the frozenset of annotation keys present on node."""
+  if not hasattr(node, field_name):
+    return frozenset()
+  return frozenset(getattr(node, field_name).keys())
+
+
+def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
+  """Reads an annotation; raises (AttributeError/KeyError) if absent and no default."""
+  if (default is FAIL or (hasattr(node, field_name) and
+                          (key in getattr(node, field_name)))):
+    return getattr(node, field_name)[key]
+  else:
+    return default
+
+
+def hasanno(node, key, field_name='___pyct_anno'):
+  return hasattr(node, field_name) and key in getattr(node, field_name)
+
+
+def setanno(node, key, value, field_name='___pyct_anno'):
+  """Creates or overwrites the annotation `key` on node."""
+  annotations = getattr(node, field_name, {})
+  setattr(node, field_name, annotations)
+  annotations[key] = value
+
+  # So that the annotations survive gast_to_ast() and ast_to_gast()
+  if field_name not in node._fields:
+    node._fields += (field_name,)
+
+
+def delanno(node, key, field_name='___pyct_anno'):
+  """Removes annotation `key`; drops the whole field once it becomes empty."""
+  annotations = getattr(node, field_name)
+  del annotations[key]
+  if not annotations:
+    delattr(node, field_name)
+    node._fields = tuple(f for f in node._fields if f != field_name)
+
+
+def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
+  """Copies annotation `key` from from_node to to_node, if present."""
+  if hasanno(from_node, key, field_name=field_name):
+    setanno(
+        to_node,
+        key,
+        getanno(from_node, key, field_name=field_name),
+        field_name=field_name)
+
+
+def dup(node, copy_map, field_name='___pyct_anno'):
+  """Recursively copies annotations in an AST tree.
+
+  Args:
+    node: ast.AST
+    copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
+      key. All annotations with the source key will be copied to identical
+      annotations with the destination key.
+    field_name: str
+  """
+  for n in gast.walk(node):
+    for k in copy_map:
+      if hasanno(n, k, field_name):
+        setanno(n, copy_map[k], getanno(n, k, field_name), field_name)
diff --git a/tensorflow/python/autograph/pyct/anno_test.py b/tensorflow/python/autograph/pyct/anno_test.py
new file mode 100644
index 0000000000..1f873871c6
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/anno_test.py
@@ -0,0 +1,84 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for anno module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.platform import test
+
+
+# TODO(mdan): Consider strong types instead of primitives.
+
+
+class AnnoTest(test.TestCase):
+  """Tests for the AST annotation helpers in anno.py."""
+
+  def test_basic(self):
+    node = ast.Name()
+
+    self.assertEqual(anno.keys(node), set())
+    self.assertFalse(anno.hasanno(node, 'foo'))
+    # Without a default, a missing annotation raises.
+    with self.assertRaises(AttributeError):
+      anno.getanno(node, 'foo')
+
+    anno.setanno(node, 'foo', 3)
+
+    self.assertEqual(anno.keys(node), {'foo'})
+    self.assertTrue(anno.hasanno(node, 'foo'))
+    self.assertEqual(anno.getanno(node, 'foo'), 3)
+    self.assertEqual(anno.getanno(node, 'bar', default=7), 7)
+
+    anno.delanno(node, 'foo')
+
+    self.assertEqual(anno.keys(node), set())
+    self.assertFalse(anno.hasanno(node, 'foo'))
+    with self.assertRaises(AttributeError):
+      anno.getanno(node, 'foo')
+    self.assertIsNone(anno.getanno(node, 'foo', default=None))
+
+  def test_copy(self):
+    # copyanno copies existing keys and silently skips missing ones.
+    node_1 = ast.Name()
+    anno.setanno(node_1, 'foo', 3)
+
+    node_2 = ast.Name()
+    anno.copyanno(node_1, node_2, 'foo')
+    anno.copyanno(node_1, node_2, 'bar')
+
+    self.assertTrue(anno.hasanno(node_2, 'foo'))
+    self.assertFalse(anno.hasanno(node_2, 'bar'))
+
+  def test_duplicate(self):
+    # dup copies mapped keys in place, only on nodes that carry the source key.
+    node = ast.If(
+        test=ast.Num(1),
+        body=[ast.Expr(ast.Name('bar', ast.Load()))],
+        orelse=[])
+    anno.setanno(node, 'spam', 1)
+    anno.setanno(node, 'ham', 1)
+    anno.setanno(node.body[0], 'ham', 1)
+
+    anno.dup(node, {'spam': 'eggs'})
+
+    self.assertTrue(anno.hasanno(node, 'spam'))
+    self.assertTrue(anno.hasanno(node, 'ham'))
+    self.assertTrue(anno.hasanno(node, 'eggs'))
+    self.assertFalse(anno.hasanno(node.body[0], 'eggs'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/ast_util.py b/tensorflow/python/autograph/pyct/ast_util.py
new file mode 100644
index 0000000000..7df3b8858c
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/ast_util.py
@@ -0,0 +1,313 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST manipulation utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+
+
+class CleanCopier(object):
+  """NodeTransformer-like visitor that copies an AST.
+
+  Fields whose names start with '__' are dropped from the copy; annotations
+  are dropped unless their keys are listed in preserve_annos.
+  """
+
+  def __init__(self, preserve_annos):
+    # preserve_annos: Optional[Set[Hashable]], annotation keys to carry over
+    # to the copy; all other annotations are discarded.
+    super(CleanCopier, self).__init__()
+    self.preserve_annos = preserve_annos
+
+  def copy(self, node):
+    """Returns a deep copy of node (excluding some fields, see copy_clean)."""
+
+    # Containers are copied element-wise, preserving list/tuple type.
+    if isinstance(node, list):
+      return [self.copy(n) for n in node]
+    elif isinstance(node, tuple):
+      return tuple(self.copy(n) for n in node)
+    elif not isinstance(node, (gast.AST, ast.AST)):
+      # Assuming everything that's not an AST, list or tuple is a value type
+      # and may simply be assigned.
+      return node
+
+    assert isinstance(node, (gast.AST, ast.AST))
+
+    # Rebuild the node from copies of its declared fields; '__'-prefixed and
+    # absent fields are skipped.
+    new_fields = {}
+    for f in node._fields:
+      if not f.startswith('__') and hasattr(node, f):
+        new_fields[f] = self.copy(getattr(node, f))
+    new_node = type(node)(**new_fields)
+
+    if self.preserve_annos:
+      for k in self.preserve_annos:
+        anno.copyanno(node, new_node, k)
+    return new_node
+
+
+def copy_clean(node, preserve_annos=None):
+  """Creates a deep copy of an AST.
+
+  The copy will not include fields that are prefixed by '__', with the
+  exception of user-specified annotations. Annotations whose keys are not in
+  preserve_annos are dropped from the copy.
+
+  Args:
+    node: ast.AST
+    preserve_annos: Optional[Set[Hashable]], annotation keys to include in the
+      copy
+  Returns:
+    ast.AST, a fresh tree; the original is left unmodified.
+  """
+  return CleanCopier(preserve_annos).copy(node)
+
+
+class SymbolRenamer(gast.NodeTransformer):
+  """Transformer that can rename symbols to simple names.
+
+  Relies on anno.Basic.QN (qualified name) annotations being present, i.e.
+  the tree must have been processed by qual_names.resolve first.
+  """
+
+  def __init__(self, name_map):
+    # name_map: Dict[qual_names.QN, qual_names.QN], old name -> new name.
+    self.name_map = name_map
+
+  def _process(self, node):
+    qn = anno.getanno(node, anno.Basic.QN)
+    if qn in self.name_map:
+      # The renamed symbol becomes a plain Name, even if the original was an
+      # Attribute (e.g. 'a.b' -> 'renamed_a_b').
+      new_node = gast.Name(str(self.name_map[qn]), node.ctx, None)
+      # All annotations get carried over.
+      for k in anno.keys(node):
+        anno.copyanno(node, new_node, k)
+      return new_node
+    return self.generic_visit(node)
+
+  def visit_Name(self, node):
+    return self._process(node)
+
+  def visit_Attribute(self, node):
+    if anno.hasanno(node, anno.Basic.QN):
+      return self._process(node)
+    # Attributes of dynamic objects will not have a QN.
+    return self.generic_visit(node)
+
+
+def rename_symbols(node, name_map):
+  """Renames symbols in an AST. Requires qual_names annotations.
+
+  Args:
+    node: Union[ast.AST, List[ast.AST], Tuple[ast.AST, ...]], tree(s) to
+      rewrite in place (nodes may be replaced).
+    name_map: Dict[qual_names.QN, qual_names.QN], old name -> new name.
+  Returns:
+    Same container shape as node, with matching symbols renamed.
+  """
+  renamer = SymbolRenamer(name_map)
+  if isinstance(node, list):
+    return [renamer.visit(n) for n in node]
+  elif isinstance(node, tuple):
+    return tuple(renamer.visit(n) for n in node)
+  return renamer.visit(node)
+
+
+def keywords_to_dict(keywords):
+  """Converts a list of ast.keyword objects to a dict.
+
+  Args:
+    keywords: List[ast.keyword], e.g. the keywords field of a Call node.
+  Returns:
+    gast.Dict, a dict literal node mapping keyword names (as Str nodes) to
+    the corresponding value expressions.
+  """
+  keys = []
+  values = []
+  for kw in keywords:
+    # NOTE(review): a '**kwargs' keyword has kw.arg == None, which would
+    # produce Str(None) here; assumes callers never pass one — confirm.
+    keys.append(gast.Str(kw.arg))
+    values.append(kw.value)
+  return gast.Dict(keys=keys, values=values)
+
+
+class PatternMatcher(gast.NodeVisitor):
+  """Matches a node against a pattern represented by a node.
+
+  The match result is accumulated in self.matches; once a mismatch is found
+  the visit short-circuits (generic_visit returns immediately).
+  """
+
+  def __init__(self, pattern):
+    # pattern: ast.AST, the template to compare against; '_' names act as
+    # wildcards (see is_wildcard).
+    self.pattern = pattern
+    self.pattern_stack = []
+    self.matches = True
+
+  def compare_and_visit(self, node, pattern):
+    # Descend one level: temporarily swap in the child pattern while the
+    # child node is visited.
+    self.pattern_stack.append(self.pattern)
+    self.pattern = pattern
+    self.generic_visit(node)
+    self.pattern = self.pattern_stack.pop()
+
+  def no_match(self):
+    # Records the failure; the False return lets callers 'return' early.
+    self.matches = False
+    return False
+
+  def is_wildcard(self, p):
+    # A wildcard is the Name '_' (possibly as a singleton list/tuple, which
+    # then matches an entire field, e.g. all call arguments) or the string
+    # '_' for plain string-valued fields.
+    if isinstance(p, (list, tuple)) and len(p) == 1:
+      p, = p
+      if isinstance(p, gast.Name) and p.id == '_':
+        return True
+    if p == '_':
+      return True
+    return False
+
+  def generic_visit(self, node):
+    if not self.matches:
+      return
+
+    pattern = self.pattern
+    for f in node._fields:
+      if f.startswith('__'):
+        continue
+
+      # A field missing on one side but set on the other is a mismatch.
+      if not hasattr(node, f):
+        if hasattr(pattern, f) and getattr(pattern, f):
+          return self.no_match()
+        else:
+          continue
+      if not hasattr(pattern, f):
+        return self.no_match()
+
+      v = getattr(node, f)
+      p = getattr(pattern, f)
+
+      if self.is_wildcard(p):
+        continue
+      if isinstance(v, (list, tuple)):
+        # Non-wildcard sequences must match element-wise, same length.
+        if not isinstance(p, (list, tuple)) or len(v) != len(p):
+          return self.no_match()
+        for v_item, p_item in zip(v, p):
+          self.compare_and_visit(v_item, p_item)
+      elif isinstance(v, (gast.AST, ast.AST)):
+        # Accept subclass relationships in either direction.
+        if not isinstance(v, type(p)) and not isinstance(p, type(v)):
+          return self.no_match()
+        self.compare_and_visit(v, p)
+      else:
+        # Assume everything else is a value type.
+        if v != p:
+          return self.no_match()
+
+
+def matches(node, pattern):
+  """Basic pattern matcher for AST.
+
+  The pattern may contain wildcards represented by the symbol '_'. A node
+  matches a pattern if for every node in the tree, either there is a node of
+  the same type in pattern, or a Name node with id='_'.
+
+  Args:
+    node: ast.AST
+    pattern: Union[ast.AST, str]; a string is parsed as an expression first.
+  Returns:
+    bool
+  """
+  if isinstance(pattern, str):
+    pattern = parser.parse_expression(pattern)
+  matcher = PatternMatcher(pattern)
+  matcher.visit(node)
+  return matcher.matches
+
+
+# TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+def apply_to_single_assignments(targets, values, apply_fn):
+ """Applies a function to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible. In effect, it has the same
+ effect as passing the assigned values in SSA form to apply_fn.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ It uses the visitor pattern to allow subclasses to process single
+ assignments individually.
+
+ Args:
+ targets: Union[List[ast.AST, ...], Tuple[ast.AST, ...], ast.AST, should be
+ used with the targets field of an ast.Assign node
+ values: ast.AST
+ apply_fn: Callable[[ast.AST, ast.AST], None], called with the
+ respective nodes of each single assignment
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ idx = parser.parse_expression(str(i))
+ value_el = gast.Subscript(values, gast.Index(idx), ctx=gast.Load())
+ apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ apply_fn(target, values)
+
+
+def parallel_walk(node, other):
+  """Walks two ASTs in parallel.
+
+  The two trees must have identical structure. Node class names, field
+  presence, sequence lengths and non-AST field values must all agree; only
+  then is each node pair yielded.
+
+  Args:
+    node: Union[ast.AST, Iterable[ast.AST]]
+    other: Union[ast.AST, Iterable[ast.AST]]
+  Yields:
+    Tuple[ast.AST, ast.AST]
+  Raises:
+    ValueError: if the two trees don't have identical structure.
+  """
+  if isinstance(node, (list, tuple)):
+    node_stack = list(node)
+  else:
+    node_stack = [node]
+
+  if isinstance(other, (list, tuple)):
+    other_stack = list(other)
+  else:
+    other_stack = [other]
+
+  # Explicit-stack DFS over both trees in lockstep; the stacks grow and
+  # shrink identically because mismatched shapes raise before extending.
+  while node_stack and other_stack:
+    assert len(node_stack) == len(other_stack)
+    n = node_stack.pop()
+    o = other_stack.pop()
+
+    if (not isinstance(n, (ast.AST, gast.AST)) or
+        not isinstance(o, (ast.AST, gast.AST)) or
+        n.__class__.__name__ != o.__class__.__name__):
+      raise ValueError('inconsistent nodes: {} and {}'.format(n, o))
+
+    yield n, o
+
+    for f in n._fields:
+      n_child = getattr(n, f, None)
+      o_child = getattr(o, f, None)
+      # NOTE(review): a field that is None on one side but set on the other
+      # is silently skipped here rather than reported — confirm intended.
+      if f.startswith('__') or n_child is None or o_child is None:
+        continue
+
+      if isinstance(n_child, (list, tuple)):
+        if (not isinstance(o_child, (list, tuple)) or
+            len(n_child) != len(o_child)):
+          raise ValueError(
+              'inconsistent values for field {}: {} and {}'.format(
+                  f, n_child, o_child))
+        node_stack.extend(n_child)
+        other_stack.extend(o_child)
+
+      elif isinstance(n_child, (gast.AST, ast.AST)):
+        node_stack.append(n_child)
+        other_stack.append(o_child)
+
+      elif n_child != o_child:
+        raise ValueError(
+            'inconsistent values for field {}: {} and {}'.format(
+                f, n_child, o_child))
diff --git a/tensorflow/python/autograph/pyct/ast_util_test.py b/tensorflow/python/autograph/pyct/ast_util_test.py
new file mode 100644
index 0000000000..b1577c466e
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/ast_util_test.py
@@ -0,0 +1,196 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ast_util module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import collections
+import textwrap
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.platform import test
+
+
+class AstUtilTest(test.TestCase):
+  """Tests for the ast_util module."""
+
+  def setUp(self):
+    super(AstUtilTest, self).setUp()
+    # Counts (target_source, value_source) pairs seen by _mock_apply_fn.
+    self._invocation_counts = collections.defaultdict(lambda: 0)
+
+  def test_rename_symbols_basic(self):
+    node = parser.parse_str('a + b')
+    node = qual_names.resolve(node)
+
+    node = ast_util.rename_symbols(
+        node, {qual_names.QN('a'): qual_names.QN('renamed_a')})
+
+    # The renamed id must be a plain str, not a QN object.
+    self.assertIsInstance(node.body[0].value.left.id, str)
+    source = compiler.ast_to_source(node)
+    self.assertEqual(source.strip(), 'renamed_a + b')
+
+  def test_rename_symbols_attributes(self):
+    node = parser.parse_str('b.c = b.c.d')
+    node = qual_names.resolve(node)
+
+    node = ast_util.rename_symbols(
+        node, {qual_names.from_str('b.c'): qual_names.QN('renamed_b_c')})
+
+    source = compiler.ast_to_source(node)
+    self.assertEqual(source.strip(), 'renamed_b_c = renamed_b_c.d')
+
+  def test_rename_symbols_annotations(self):
+    # Annotations must be carried over (by identity) to the renamed node.
+    node = parser.parse_str('a[i]')
+    node = qual_names.resolve(node)
+    anno.setanno(node, 'foo', 'bar')
+    orig_anno = anno.getanno(node, 'foo')
+
+    node = ast_util.rename_symbols(node,
+                                   {qual_names.QN('a'): qual_names.QN('b')})
+
+    self.assertIs(anno.getanno(node, 'foo'), orig_anno)
+
+  def test_copy_clean(self):
+    node = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 1
+    """))
+    setattr(node.body[0], '__foo', 'bar')
+    new_node = ast_util.copy_clean(node)
+    self.assertIsNot(new_node, node)
+    self.assertIsNot(new_node.body[0], node.body[0])
+    # '__'-prefixed fields are not copied.
+    self.assertFalse(hasattr(new_node.body[0], '__foo'))
+
+  def test_copy_clean_preserves_annotations(self):
+    node = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 1
+    """))
+    anno.setanno(node.body[0], 'foo', 'bar')
+    anno.setanno(node.body[0], 'baz', 1)
+    # Only the whitelisted annotation key survives the copy.
+    new_node = ast_util.copy_clean(node, preserve_annos={'foo'})
+    self.assertEqual(anno.getanno(new_node.body[0], 'foo'), 'bar')
+    self.assertFalse(anno.hasanno(new_node.body[0], 'baz'))
+
+  def test_keywords_to_dict(self):
+    keywords = parser.parse_expression('f(a=b, c=1, d=\'e\')').keywords
+    d = ast_util.keywords_to_dict(keywords)
+    # Make sure we generate a usable dict node by attaching it to a variable and
+    # compiling everything.
+    node = parser.parse_str('def f(b): pass').body[0]
+    node.body.append(ast.Return(d))
+    result, _ = compiler.ast_to_object(node)
+    self.assertDictEqual(result.f(3), {'a': 3, 'c': 1, 'd': 'e'})
+
+  def assertMatch(self, target_str, pattern_str):
+    # Helper: asserts target_str matches the wildcard pattern pattern_str.
+    node = parser.parse_expression(target_str)
+    pattern = parser.parse_expression(pattern_str)
+    self.assertTrue(ast_util.matches(node, pattern))
+
+  def assertNoMatch(self, target_str, pattern_str):
+    # Helper: asserts target_str does NOT match pattern_str.
+    node = parser.parse_expression(target_str)
+    pattern = parser.parse_expression(pattern_str)
+    self.assertFalse(ast_util.matches(node, pattern))
+
+  def test_matches_symbols(self):
+    self.assertMatch('foo', '_')
+    self.assertNoMatch('foo()', '_')
+    self.assertMatch('foo + bar', 'foo + _')
+    self.assertNoMatch('bar + bar', 'foo + _')
+    self.assertNoMatch('foo - bar', 'foo + _')
+
+  def test_matches_function_args(self):
+    # A singleton '_' argument list matches any number of arguments.
+    self.assertMatch('super(Foo, self).__init__(arg1, arg2)',
+                     'super(_).__init__(_)')
+    self.assertMatch('super().__init__()', 'super(_).__init__(_)')
+    self.assertNoMatch('super(Foo, self).bar(arg1, arg2)',
+                       'super(_).__init__(_)')
+    self.assertMatch('super(Foo, self).__init__()', 'super(Foo, _).__init__(_)')
+    self.assertNoMatch('super(Foo, self).__init__()',
+                       'super(Bar, _).__init__(_)')
+
+  def _mock_apply_fn(self, target, source):
+    # Records each (target, source) pair as rendered source text.
+    target = compiler.ast_to_source(target)
+    source = compiler.ast_to_source(source)
+    self._invocation_counts[(target.strip(), source.strip())] += 1
+
+  def test_apply_to_single_assignments_dynamic_unpack(self):
+    node = parser.parse_str('a, b, c = d')
+    node = node.body[0]
+    ast_util.apply_to_single_assignments(node.targets, node.value,
+                                         self._mock_apply_fn)
+    self.assertDictEqual(self._invocation_counts, {
+        ('a', 'd[0]'): 1,
+        ('b', 'd[1]'): 1,
+        ('c', 'd[2]'): 1,
+    })
+
+  def test_apply_to_single_assignments_static_unpack(self):
+    node = parser.parse_str('a, b, c = d, e, f')
+    node = node.body[0]
+    ast_util.apply_to_single_assignments(node.targets, node.value,
+                                         self._mock_apply_fn)
+    self.assertDictEqual(self._invocation_counts, {
+        ('a', 'd'): 1,
+        ('b', 'e'): 1,
+        ('c', 'f'): 1,
+    })
+
+  def test_parallel_walk(self):
+    node = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 1
+    """))
+    # Walking a tree against itself yields equal pairs throughout.
+    for child_a, child_b in ast_util.parallel_walk(node, node):
+      self.assertEqual(child_a, child_b)
+
+  def test_parallel_walk_inconsistent_trees(self):
+    node_1 = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 1
+    """))
+    node_2 = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + (a * 2)
+    """))
+    node_3 = parser.parse_str(
+        textwrap.dedent("""
+      def f(a):
+        return a + 2
+    """))
+    with self.assertRaises(ValueError):
+      for _ in ast_util.parallel_walk(node_1, node_2):
+        pass
+    # There is not particular reason to reject trees that differ only in the
+    # value of a constant.
+    # TODO(mdan): This should probably be allowed.
+    with self.assertRaises(ValueError):
+      for _ in ast_util.parallel_walk(node_1, node_3):
+        pass
+
+
+# Standard TensorFlow test entry point.
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
new file mode 100644
index 0000000000..1433f9ac83
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -0,0 +1,815 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow graph (CFG) structure for Python AST representation.
+
+The CFG is a digraph with edges representing valid control flow. Each
+node is associated with exactly one AST node, but not all AST nodes may have
+a corresponding CFG counterpart.
+
+Once built, the CFG itself is immutable, but the values it holds need not be;
+they are usually annotated with information extracted by walking the graph.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+from enum import Enum
+
+# pylint:disable=g-bad-import-order
+import gast
+# pylint:enable=g-bad-import-order
+
+from tensorflow.python.autograph.pyct import compiler
+
+
+class Node(object):
+  """A node in the CFG.
+
+  Although new instances of this class are mutable, the objects that a user
+  finds in the CFG are typically not.
+
+  The nodes represent edges in the CFG graph, and maintain pointers to allow
+  efficient walking in both forward and reverse order. The following property
+  holds for all nodes: "child in node.next" iff "node in child.prev".
+
+  Attributes:
+    next: FrozenSet[Node, ...], the nodes that follow this node, in control
+      flow order
+    prev: FrozenSet[Node, ...], the nodes that precede this node, in reverse
+      control flow order
+    ast_node: ast.AST, the AST node corresponding to this CFG node
+  """
+
+  def __init__(self, next_, prev, ast_node):
+    # next_/prev start as mutable sets during graph construction; freeze()
+    # converts them to frozensets once the graph is complete.
+    self.next = next_
+    self.prev = prev
+    self.ast_node = ast_node
+
+  def freeze(self):
+    """Makes the adjacency sets immutable; called once building is done."""
+    self.next = frozenset(self.next)
+    self.prev = frozenset(self.prev)
+
+  def __repr__(self):
+    # Function and with-item nodes get a short label; everything else is
+    # rendered back to source.
+    if isinstance(self.ast_node, gast.FunctionDef):
+      return 'def %s' % self.ast_node.name
+    elif isinstance(self.ast_node, gast.withitem):
+      return compiler.ast_to_source(self.ast_node.context_expr).strip()
+    return compiler.ast_to_source(self.ast_node).strip()
+
+
+class Graph(
+    collections.namedtuple(
+        'Graph',
+        ['entry', 'exit', 'error', 'index', 'stmt_prev', 'stmt_next'])):
+  """A Control Flow Graph.
+
+  The CFG maintains an index to allow looking up a CFG node by the AST node to
+  which it is associated. The index can also be enumerated in top-down, depth
+  first order.
+
+  Walking the graph in forward or reverse order is supported by double
+  parent-child links.
+
+  Note: the error nodes are not wired to their corresponding finally guards,
+  because these are shared, and wiring them would create a reverse path from
+  normal control flow into the error nodes, which we want to avoid.
+
+  The graph also maintains edges corresponding to higher level statements
+  like for-else loops. A node is considered successor of a statement if there
+  is an edge from a node that is lexically a child of that statement to a node
+  that is not. Statement predecessors are analogously defined.
+
+  Attributes:
+    entry: Node, the entry node
+    exit: FrozenSet[Node, ...], the exit nodes
+    error: FrozenSet[Node, ...], nodes that exit due to an explicitly raised
+      error (errors propagated from function calls are not accounted)
+    index: Dict[ast.Node, Node], mapping AST nodes to the respective CFG
+      node
+    stmt_prev: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST
+      nodes to their predecessor CFG nodes
+    stmt_next: Dict[ast.Node, FrozenSet[Node, ...]], mapping statement AST
+      nodes to their successor CFG nodes
+  """
+
+  def __repr__(self):
+    # Renders the graph in Graphviz dot format, with object ids as node ids.
+    result = 'digraph CFG {\n'
+    for node in self.index.values():
+      result += '  %s [label="%s"];\n' % (id(node), node)
+    for node in self.index.values():
+      for next_ in node.next:
+        result += '  %s -> %s;\n' % (id(node), id(next_))
+    result += '}'
+    return result
+
+
+class _WalkMode(Enum):
+  """Direction in which GraphVisitor traverses the CFG."""
+  FORWARD = 1
+  REVERSE = 2
+
+
+# TODO(mdan): Rename to DataFlowAnalyzer.
+# TODO(mdan): Consider specializations that use gen/kill/transfer abstractions.
+class GraphVisitor(object):
+  """Base class for CFG visitors.
+
+  This implementation is not thread safe.
+
+  The visitor has some facilities to simplify dataflow analyses. In particular,
+  it allows revisiting the nodes at the decision of the subclass. This can be
+  used to visit the graph until the state reaches a fixed point.
+
+  For more details on dataflow analysis, see
+  https://www.seas.harvard.edu/courses/cs252/2011sp/slides/Lec02-Dataflow.pdf
+
+  Note: the literature generally suggests visiting successor nodes only when the
+  state of the current node changed, regardless of whether that successor has
+  ever been visited. This implementation visits every successor at least once.
+
+  Attributes:
+    graph: Graph
+    in_: Dict[Node, Any], stores node-keyed state during a visit
+    out: Dict[Node, Any], stores node-keyed state during a visit
+  """
+
+  def __init__(self, graph):
+    # Note: reset() calls init_state for every node, so subclasses that need
+    # extra attributes in init_state must set them before calling this.
+    self.graph = graph
+    self.reset()
+
+  def init_state(self, node):
+    """State initialization function. Optional to overload.
+
+    An in/out state slot will be created for each node in the graph. Subclasses
+    must overload this to control what that is initialized to.
+
+    Args:
+      node: Node
+    """
+    raise NotImplementedError('Subclasses must implement this.')
+
+  # TODO(mdan): Rename to flow?
+  def visit_node(self, node):
+    """Visitor function.
+
+    Args:
+      node: Node
+    Returns:
+      bool, whether the node should be revisited; subclasses can visit every
+      reachable node exactly once by always returning False
+    """
+    raise NotImplementedError('Subclasses must implement this.')
+
+  def reset(self):
+    """Re-initializes the per-node in/out state slots."""
+    self.in_ = {
+        node: self.init_state(node) for node in self.graph.index.values()
+    }
+    self.out = {
+        node: self.init_state(node) for node in self.graph.index.values()
+    }
+
+  def _visit_internal(self, mode):
+    """Visits the CFG, depth-first."""
+    # NOTE(review): open_.pop(0) with append makes this a FIFO queue, i.e. a
+    # breadth-first walk, despite the docstring saying depth-first — confirm
+    # which order is intended.
+    assert mode in (_WalkMode.FORWARD, _WalkMode.REVERSE)
+    if mode == _WalkMode.FORWARD:
+      open_ = [self.graph.entry]
+    elif mode == _WalkMode.REVERSE:
+      open_ = list(self.graph.exit)
+    closed = set()
+
+    while open_:
+      node = open_.pop(0)
+      closed.add(node)
+
+      should_revisit = self.visit_node(node)
+
+      if mode == _WalkMode.FORWARD:
+        children = node.next
+      elif mode == _WalkMode.REVERSE:
+        children = node.prev
+
+      # Revisiting re-enqueues even already-closed successors, which is how
+      # fixed-point iteration is achieved.
+      for next_ in children:
+        if should_revisit or next_ not in closed:
+          open_.append(next_)
+
+  def visit_forward(self):
+    """Walks the graph from the entry node, following .next edges."""
+    self._visit_internal(_WalkMode.FORWARD)
+
+  def visit_reverse(self):
+    """Walks the graph from the exit nodes, following .prev edges."""
+    self._visit_internal(_WalkMode.REVERSE)
+
+
+class GraphBuilder(object):
+ """Builder that constructs a CFG from a given AST.
+
+ This GraphBuilder facilitates constructing the DAG that forms the CFG when
+ nodes
+ are supplied in lexical order (i.e., top-down, depth first). Under these
+ conditions, it supports building patterns found in typical structured
+ programs.
+
+ This builder ignores the flow generated by exceptions, which are assumed to
+ always be catastrophic and present purely for diagnostic purposes (e.g. to
+ print debug information). Statements like raise and try/catch sections are
+ allowed and will generate control flow edges, but ordinaty statements are
+ assumed not to raise exceptions.
+
+ Finally sections are also correctly interleaved between break/continue/return
+ nodes and their subsequent statements.
+
+ Important concepts:
+ * nodes - nodes refer refer to CFG nodes; AST nodes are qualified explicitly
+ * leaf set - since the graph is constructed gradually, a leaf set maintains
+ the CFG nodes that will precede the node that the builder expects to
+ receive next; when an ordinary node is added, it is connected to the
+ existing leaves and it in turn becomes the new leaf
+ * jump nodes - nodes that should generate edges other than what
+ ordinary nodes would; these correspond to break, continue and return
+ statements
+ * sections - logical delimiters for subgraphs that require special
+ edges; there are various types of nodes, each admitting various
+ types of jump nodes; sections are identified by their corresponding AST
+ node
+ """
+
+ # TODO(mdan): Perhaps detail this in a markdown doc.
+ # TODO(mdan): Add exception support.
+
+  def __init__(self, parent_ast_node):
+    # parent_ast_node: ast.AST, the node the built CFG will be associated
+    # with (stored as self.parent; not itself added to the graph here).
+    self.reset()
+    self.parent = parent_ast_node
+
+  def reset(self):
+    """Resets the state of this factory."""
+    self.head = None
+    self.errors = set()
+    # Insertion-ordered so the CFG index enumerates in lexical order.
+    self.node_index = collections.OrderedDict()
+
+    # TODO(mdan): Too many primitives. Use classes.
+    self.leaves = set()
+
+    # Note: This mechanism requires that nodes are added in lexical order (top
+    # to bottom, depth first).
+    self.active_stmts = set()
+    self.owners = {}  # type: Set[any]
+    self.forward_edges = set()  # type: Tuple[Node, Node] # (from, to)
+
+    self.finally_sections = {}
+    # Dict values represent (entry, exits)
+    self.finally_section_subgraphs = {
+    }  # type: Dict[ast.AST, Tuple[Node, Set[Node]]]
+    # Whether the guard section can be reached from the statement that precedes
+    # it.
+    self.finally_section_has_direct_flow = {}
+    # Finally sections that await their first node.
+    self.pending_finally_sections = set()
+
+    # Exit jumps keyed by the section they affect.
+    self.exits = {}
+
+    # The entry of loop sections, keyed by the section.
+    self.section_entry = {}
+    # Continue jumps keyed by the section they affect.
+    self.continues = {}
+
+    # The entry of conditional sections, keyed by the section.
+    self.cond_entry = {}
+    # Lists of leaf nodes corresponding to each branch in the section.
+    self.cond_leaves = {}
+
+  def _connect_nodes(self, first, second):
+    """Connects nodes to signify that control flows from first to second.
+
+    Args:
+      first: Union[Set[Node, ...], Node]
+      second: Node
+    """
+    if isinstance(first, Node):
+      # Maintain the invariant: child in node.next iff node in child.prev.
+      first.next.add(second)
+      second.prev.add(first)
+      self.forward_edges.add((first, second))
+    else:
+      # A set of sources: connect each one individually.
+      for node in first:
+        self._connect_nodes(node, second)
+
+  def _add_new_node(self, ast_node):
+    """Grows the graph by adding a CFG node following the current leaves."""
+    # NOTE(review): 'is' compares ast_node against the index dict itself, so
+    # this duplicate check can never trigger; likely intended
+    # 'ast_node in self.node_index' — confirm.
+    if ast_node is self.node_index:
+      raise ValueError('%s added twice' % ast_node)
+    node = Node(next_=set(), prev=set(), ast_node=ast_node)
+    self.node_index[ast_node] = node
+    # Record which statements lexically contain this node.
+    self.owners[node] = frozenset(self.active_stmts)
+
+    # The first node added becomes the graph entry.
+    if self.head is None:
+      self.head = node
+
+    for leaf in self.leaves:
+      self._connect_nodes(leaf, node)
+
+    # If any finally section awaits its first node, populate it.
+    for section_id in self.pending_finally_sections:
+      self.finally_section_subgraphs[section_id][0] = node
+    self.pending_finally_sections = set()
+
+    return node
+
+  def begin_statement(self, stmt):
+    """Marks the beginning of a statement.
+
+    Nodes added while the statement is active are recorded as owned by it.
+
+    Args:
+      stmt: Hashable, a key by which the statement can be identified in
+          the CFG's stmt_prev and stmt_next attributes
+    """
+    self.active_stmts.add(stmt)
+
+  def end_statement(self, stmt):
+    """Marks the end of a statement.
+
+    Raises KeyError if stmt was not previously begun.
+
+    Args:
+      stmt: Hashable, a key by which the statement can be identified in
+          the CFG's stmt_prev and stmt_next attributes; must match a key
+          previously passed to begin_statement.
+    """
+    self.active_stmts.remove(stmt)
+
+  def add_ordinary_node(self, ast_node):
+    """Grows the graph by adding an ordinary CFG node.
+
+    Ordinary nodes are followed by the next node, in lexical order, that is,
+    they become the new leaf set.
+
+    Args:
+      ast_node: ast.AST
+    Returns:
+      Node
+    """
+    node = self._add_new_node(ast_node)
+    # The new node is now the sole leaf; the next added node follows it.
+    self.leaves = set((node,))
+    return node
+
+  def _add_jump_node(self, ast_node, guards):
+    """Grows the graph by adding a jump node.
+
+    Jump nodes are added to the current leaf set, and the leaf set becomes
+    empty. If the jump node is the last in a cond section, then it may be added
+    back to the leaf set by a separate mechanism.
+
+    Args:
+      ast_node: ast.AST
+      guards: Tuple[ast.AST, ...], the finally sections active for this node
+    Returns:
+      Node
+    """
+    node = self._add_new_node(ast_node)
+    # Control does not fall through a jump; clear the leaves.
+    self.leaves = set()
+    # The guards themselves may not yet be complete, and will be wired later.
+    self.finally_sections[node] = guards
+    return node
+
+  def _connect_jump_to_finally_sections(self, node):
+    """Connects a jump node to the finally sections protecting it.
+
+    Chains the guards in order: node -> guard1 -> guard2 -> ..., and returns
+    the exits of the last guard (or the node itself when unguarded).
+    """
+    cursor = set((node,))
+    for guard_section_id in self.finally_sections[node]:
+      guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id]
+      self._connect_nodes(cursor, guard_begin)
+      cursor = guard_ends
+    del self.finally_sections[node]
+    # TODO(mdan): Should garbage-collect finally_section_subgraphs.
+    return cursor
+
+  def add_exit_node(self, ast_node, section_id, guards):
+    """Grows the graph by adding an exit node.
+
+    This node becomes an exit for the current section.
+
+    Args:
+      ast_node: ast.AST
+      section_id: Hashable, the node for which ast_node should be considered
+          to be an exit node
+      guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+    """
+    node = self._add_jump_node(ast_node, guards)
+    # Wired to the section's continuation when exit_section is called.
+    self.exits[section_id].add(node)
+
+  def add_continue_node(self, ast_node, section_id, guards):
+    """Grows the graph by adding a reentry node.
+
+    This node causes control flow to go back to the loop section's entry.
+
+    Args:
+      ast_node: ast.AST
+      section_id: Hashable, the node for which ast_node should be considered
+          to be an exit node
+      guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+    """
+    node = self._add_jump_node(ast_node, guards)
+    # Wired back to the loop entry when exit_loop_section is called.
+    self.continues[section_id].add(node)
+
+  def add_error_node(self, ast_node, guards):
+    """Grows the graph by adding an error node.
+
+    This node becomes an exit for the entire graph.
+
+    Args:
+      ast_node: ast.AST
+      guards: Tuple[ast.AST, ...], the finally sections that guard ast_node
+    """
+    node = self._add_jump_node(ast_node, guards)
+    self.errors.add(node)
+    # Nothing follows a raise; the leaf set stays empty.
+    self.leaves = set()
+
+  def enter_section(self, section_id):
+    """Enters a regular section.
+
+    Regular sections admit exit jumps, which end the section.
+
+    Args:
+      section_id: Hashable, the same node that will be used in calls to the
+          ast_node arg passed to add_exit_node
+    """
+    # Sections may not be re-entered while still open.
+    assert section_id not in self.exits
+    self.exits[section_id] = set()
+
+  def exit_section(self, section_id):
+    """Exits a regular section.
+
+    The section's exit jumps (routed through any finally guards) rejoin the
+    leaf set, so they flow into whatever statement follows the section.
+    """
+
+    # Exits are jump nodes, which may be protected.
+    for exit_ in self.exits[section_id]:
+      self.leaves |= self._connect_jump_to_finally_sections(exit_)
+
+    del self.exits[section_id]
+
+  def enter_loop_section(self, section_id, entry_node):
+    """Enters a loop section.
+
+    Loop sections define an entry node. The end of the section always flows
+    back to the entry node. These admit continue jump nodes which also flow to
+    the entry node.
+
+    Args:
+      section_id: Hashable, the same node that will be used in calls to the
+          ast_node arg passed to add_continue_node
+      entry_node: ast.AST, the entry node into the loop (e.g. the test node
+          for while loops)
+    """
+    assert section_id not in self.section_entry
+    assert section_id not in self.continues
+    self.continues[section_id] = set()
+    # The entry is an ordinary node: the loop body follows it.
+    node = self.add_ordinary_node(entry_node)
+    self.section_entry[section_id] = node
+
+  def exit_loop_section(self, section_id):
+    """Exits a loop section, wiring the back edges to the loop entry."""
+    # The fall-through end of the body loops back to the entry.
+    self._connect_nodes(self.leaves, self.section_entry[section_id])
+
+    # continues are jump nodes, which may be protected.
+    for reentry in self.continues[section_id]:
+      guard_ends = self._connect_jump_to_finally_sections(reentry)
+      self._connect_nodes(guard_ends, self.section_entry[section_id])
+
+    # Loop nodes always loop back.
+    self.leaves = set((self.section_entry[section_id],))
+
+    del self.continues[section_id]
+    del self.section_entry[section_id]
+
+  def enter_cond_section(self, section_id):
+    """Enters a conditional section.
+
+    Conditional sections define an entry node, and one or more branches.
+
+    Args:
+      section_id: Hashable, the same node that will be used in calls to the
+          section_id arg passed to new_cond_branch
+    """
+
+    assert section_id not in self.cond_entry
+    assert section_id not in self.cond_leaves
+    # The entry is captured lazily by the first new_cond_branch call.
+    self.cond_leaves[section_id] = []
+
+  def new_cond_branch(self, section_id):
+    """Begins a new branch in a cond section.
+
+    The first call records the split point (the current leaves); each
+    subsequent call stashes the leaves of the branch just finished and rewinds
+    the current leaves back to the split point.
+
+    Args:
+      section_id: Hashable, the same object passed to enter_cond_section.
+    """
+    assert section_id in self.cond_leaves
+
+    if section_id in self.cond_entry:
+      # Subsequent splits move back to the split point, and memorize the
+      # current leaves.
+      self.cond_leaves[section_id].append(self.leaves)
+      self.leaves = self.cond_entry[section_id]
+    else:
+      # If this is the first time we split a section, just remember the split
+      # point.
+      self.cond_entry[section_id] = self.leaves
+
+  def exit_cond_section(self, section_id):
+    """Exits a conditional section.
+
+    Merges the leaves of all finished branches (the current leaves belong to
+    the last branch) so control flow continues from all of them.
+
+    Args:
+      section_id: Hashable, the same object passed to enter_cond_section.
+    """
+    for split in self.cond_leaves[section_id]:
+      self.leaves |= split
+    del self.cond_entry[section_id]
+    del self.cond_leaves[section_id]
+
+ def enter_finally_section(self, section_id):
+ """Enters a finally section."""
+ # TODO(mdan): This, not the caller, should track the active sections.
+ self.finally_section_subgraphs[section_id] = [None, None]
+ if self.leaves:
+ self.finally_section_has_direct_flow[section_id] = True
+ else:
+ self.finally_section_has_direct_flow[section_id] = False
+ self.pending_finally_sections.add(section_id)
+
+  def exit_finally_section(self, section_id):
+    """Exits a finally section."""
+    # NOTE(review): by this point the section should no longer be pending —
+    # presumably the first node added inside it removes it; an empty finally
+    # block would trip this assert. Confirm against the add_* methods.
+    assert section_id not in self.pending_finally_sections, 'Empty finally?'
+    self.finally_section_subgraphs[section_id][1] = self.leaves
+    # If the guard can only be reached by a jump, then it will not flow
+    # into the statement that follows it.
+    if not self.finally_section_has_direct_flow[section_id]:
+      self.leaves = set()
+    del self.finally_section_has_direct_flow[section_id]
+
+  def build(self):
+    """Returns the CFG accumulated so far and resets the builder.
+
+    Returns:
+      Graph
+    """
+    # Freeze the nodes.
+    for node in self.node_index.values():
+      node.freeze()
+
+    # Build the statement edges.
+    stmt_next = {}
+    stmt_prev = {}
+    # First pass: create empty entries for every statement that owns a node
+    # with outgoing edges, so the second pass can assume the keys exist.
+    for node, _ in self.forward_edges:
+      for stmt in self.owners[node]:
+        if stmt not in stmt_next:
+          stmt_next[stmt] = set()
+        if stmt not in stmt_prev:
+          stmt_prev[stmt] = set()
+    # Second pass: an edge exits a statement when its destination is not
+    # owned by that statement, and enters one when its source is not.
+    for first, second in self.forward_edges:
+      stmts_exited = self.owners[first] - self.owners[second]
+      for stmt in stmts_exited:
+        stmt_next[stmt].add(second)
+      stmts_entered = self.owners[second] - self.owners[first]
+      for stmt in stmts_entered:
+        stmt_prev[stmt].add(first)
+    # Freeze the edge sets as well.
+    for stmt in stmt_next:
+      stmt_next[stmt] = frozenset(stmt_next[stmt])
+    for stmt in stmt_prev:
+      stmt_prev[stmt] = frozenset(stmt_prev[stmt])
+
+    # Construct the final graph object.
+    result = Graph(
+        entry=self.head,
+        exit=self.leaves,
+        error=self.errors,
+        index=self.node_index,
+        stmt_prev=stmt_prev,
+        stmt_next=stmt_next)
+
+    # Reset the state.
+    self.reset()
+
+    return result
+
+
+class AstToCfg(gast.NodeVisitor):
+  """Converts an AST to CFGs.
+
+  A separate CFG will be constructed for each function.
+  """
+
+  def __init__(self):
+    super(AstToCfg, self).__init__()
+
+    # Builders for the enclosing functions; self.builder belongs to the
+    # function currently being processed.
+    self.builder_stack = []
+    self.builder = None
+    # Maps each FunctionDef node to its finished Graph.
+    self.cfgs = {}
+
+    # Stack of AST nodes (functions, loops, try statements) that delimit the
+    # targets of jump statements such as break/continue/return.
+    self.lexical_scopes = []
+
+  def _enter_lexical_scope(self, node):
+    self.lexical_scopes.append(node)
+
+  def _exit_lexical_scope(self, node):
+    # Scopes must be exited in strict LIFO order.
+    leaving_node = self.lexical_scopes.pop()
+    assert node == leaving_node
+
+  def _get_enclosing_scopes(self, include, stop_at):
+    """Walks the lexical scope stack from the innermost scope outward.
+
+    Args:
+      include: tuple of AST types to collect along the way (e.g. gast.Try).
+      stop_at: tuple of AST types at which the walk terminates.
+
+    Returns:
+      Tuple (stop_node, included): the first enclosing scope matching stop_at
+      (None if no such scope was found) and the list of matching `include`
+      scopes crossed on the way, innermost first.
+    """
+    included = []
+    for node in reversed(self.lexical_scopes):
+      if isinstance(node, include):
+        included.append(node)
+      if isinstance(node, stop_at):
+        return node, included
+    return None, included
+
+  def _process_basic_statement(self, node):
+    # Visit children first, then record the statement as an ordinary node.
+    self.generic_visit(node)
+    self.builder.add_ordinary_node(node)
+
+  def _process_exit_statement(self, node, *exits_nodes_of_type):
+    """Registers a jump (e.g. return, break) out of an enclosing scope."""
+    # Note: this is safe because we process functions separately.
+    # NOTE(review): despite its name, try_node holds the first enclosing node
+    # matching exits_nodes_of_type (e.g. the FunctionDef for a return);
+    # guards holds the Try statements crossed on the way out.
+    try_node, guards = self._get_enclosing_scopes(
+        include=(gast.Try,),
+        stop_at=tuple(exits_nodes_of_type),
+    )
+    if try_node is None:
+      raise ValueError(
+          '%s that is not enclosed by any of %s' % (node, exits_nodes_of_type))
+    self.builder.add_exit_node(node, try_node, guards)
+
+  def _process_continue_statement(self, node, *loops_to_nodes_of_type):
+    """Registers a continue jump back to an enclosing loop."""
+    # Note: this is safe because we process functions separately.
+    # NOTE(review): as above, try_node is the enclosing loop, not a Try.
+    try_node, guards = self._get_enclosing_scopes(
+        include=(gast.Try,),
+        stop_at=tuple(loops_to_nodes_of_type),
+    )
+    if try_node is None:
+      raise ValueError('%s that is not enclosed by any of %s' %
+                       (node, loops_to_nodes_of_type))
+    self.builder.add_continue_node(node, try_node, guards)
+
+  def visit_FunctionDef(self, node):
+    # We also keep the FunctionDef node in the CFG. This allows us to determine
+    # things like reaching definitions via closure. Note that the function body
+    # will be stored in a separate graph, because function definitions are not
+    # the same as function calls.
+    if self.builder is not None:
+      self.builder.add_ordinary_node(node)
+
+    # Suspend the enclosing function's builder; the body gets its own graph.
+    self.builder_stack.append(self.builder)
+    self.builder = GraphBuilder(node)
+
+    self._enter_lexical_scope(node)
+    self.builder.enter_section(node)
+
+    # The arguments node participates in the graph as the function's entry.
+    self._process_basic_statement(node.args)
+    for stmt in node.body:
+      self.visit(stmt)
+
+    self.builder.exit_section(node)
+    self._exit_lexical_scope(node)
+
+    self.cfgs[node] = self.builder.build()
+    self.builder = self.builder_stack.pop()
+
+  def visit_Lambda(self, node):
+    # TODO(mdan): Treat like FunctionDef? That would be a separate CFG.
+    raise NotImplementedError()
+
+  def visit_Return(self, node):
+    self._process_exit_statement(node, gast.FunctionDef)
+
+  def visit_Expr(self, node):
+    self._process_basic_statement(node)
+
+  def visit_Assign(self, node):
+    self._process_basic_statement(node)
+
+  def visit_AnnAssign(self, node):
+    self._process_basic_statement(node)
+
+  def visit_AugAssign(self, node):
+    self._process_basic_statement(node)
+
+  def visit_Print(self, node):
+    self._process_basic_statement(node)
+
+  def visit_Raise(self, node):
+    # A raise exits the function through the error path, collecting any Try
+    # guards crossed on the way.
+    try_node, guards = self._get_enclosing_scopes(
+        include=(gast.Try,),
+        stop_at=(gast.FunctionDef,),
+    )
+    if try_node is None:
+      raise ValueError('%s that is not enclosed by any FunctionDef' % node)
+    self.builder.add_error_node(node, guards)
+
+  def visit_Assert(self, node):
+    # Ignoring the effect of exceptions.
+    self._process_basic_statement(node)
+
+  def visit_Delete(self, node):
+    self._process_basic_statement(node)
+
+  def visit_If(self, node):
+    # No need to track ifs as lexical scopes, for now.
+    # Lexical scopes are generally tracked in order to be able to resolve the
+    # targets of jump statements like break/continue/etc. Since there is no
+    # statement that can interrupt a conditional, we don't need to track their
+    # lexical scope. That may change in the future.
+    self.builder.begin_statement(node)
+
+    self.builder.enter_cond_section(node)
+    self._process_basic_statement(node.test)
+
+    self.builder.new_cond_branch(node)
+    for stmt in node.body:
+      self.visit(stmt)
+
+    self.builder.new_cond_branch(node)
+    for stmt in node.orelse:
+      self.visit(stmt)
+
+    self.builder.exit_cond_section(node)
+    self.builder.end_statement(node)
+
+  def visit_While(self, node):
+    self.builder.begin_statement(node)
+    self._enter_lexical_scope(node)
+
+    self.builder.enter_section(node)
+
+    # The test node is the loop section's entry, so the body loops back to it.
+    self.builder.enter_loop_section(node, node.test)
+    for stmt in node.body:
+      self.visit(stmt)
+    self.builder.exit_loop_section(node)
+
+    # Note: although the orelse is technically part of the loop node,
+    # the statements inside it don't affect the loop itself. For example, a
+    # break in the loop's orelse will not affect the loop itself.
+    self._exit_lexical_scope(node)
+
+    for stmt in node.orelse:
+      self.visit(stmt)
+
+    self.builder.exit_section(node)
+    self.builder.end_statement(node)
+
+  def visit_For(self, node):
+    self.builder.begin_statement(node)
+    self._enter_lexical_scope(node)
+
+    self.builder.enter_section(node)
+
+    # TODO(mdan): Strictly speaking, this should be node.target + node.iter.
+    # A blind dataflow analysis would have to process both node.target and
+    # node.iter to properly process read and write access.
+    self.builder.enter_loop_section(node, node.iter)
+    for stmt in node.body:
+      self.visit(stmt)
+    self.builder.exit_loop_section(node)
+
+    # Note: although the orelse is technically part of the loop node,
+    # they don't count as loop bodies. For example, a break in the loop's
+    # orelse will affect the parent loop, not the current one.
+    self._exit_lexical_scope(node)
+
+    for stmt in node.orelse:
+      self.visit(stmt)
+
+    self.builder.exit_section(node)
+    self.builder.end_statement(node)
+
+  def visit_Break(self, node):
+    self._process_exit_statement(node, gast.While, gast.For)
+
+  def visit_Continue(self, node):
+    self._process_continue_statement(node, gast.While, gast.For)
+
+  def visit_Try(self, node):
+    self._enter_lexical_scope(node)
+
+    for stmt in node.body:
+      self.visit(stmt)
+    # Unlike loops, the orelse is a simple continuation of the body.
+    for stmt in node.orelse:
+      self.visit(stmt)
+
+    if node.handlers:
+      # TODO(mdan): Should we still support bare try/except? Might be confusing.
+      raise NotImplementedError('exceptions are not yet supported')
+
+    # The finally body is outside the Try scope: jumps within it do not
+    # re-trigger the same finally section.
+    self._exit_lexical_scope(node)
+
+    self.builder.enter_finally_section(node)
+    for stmt in node.finalbody:
+      self.visit(stmt)
+    self.builder.exit_finally_section(node)
+
+  def visit_With(self, node):
+    # TODO(mdan): Mark the context manager's exit call as exit guard.
+    for item in node.items:
+      self._process_basic_statement(item)
+    for stmt in node.body:
+      self.visit(stmt)
+
+
+def build(node):
+ visitor = AstToCfg()
+ visitor.visit(node)
+ return visitor.cfgs
diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py
new file mode 100644
index 0000000000..bd82e70f7d
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/cfg_test.py
@@ -0,0 +1,969 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for cfg module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class CountingVisitor(cfg.GraphVisitor):
+  """Graph visitor that counts how many times each CFG node is visited."""
+
+  def __init__(self, graph):
+    super(CountingVisitor, self).__init__(graph)
+    # Maps each visited AST node to the number of times it was visited.
+    self.counts = {}
+
+  def init_state(self, _):
+    # This visitor tracks no per-node state.
+    return None
+
+  def visit_node(self, node):
+    self.counts[node.ast_node] = self.counts.get(node.ast_node, 0) + 1
+    return False  # visit only once
+
+
+class GraphVisitorTest(test.TestCase):
+
+ def _build_cfg(self, fn):
+ node, _ = parser.parse_entity(fn)
+ cfgs = cfg.build(node)
+ return cfgs, node
+
+ def test_basic_coverage_forward(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ break
+ return a # pylint:disable=unreachable
+ a = 2
+
+ graphs, node = self._build_cfg(test_fn)
+ graph, = graphs.values()
+ visitor = CountingVisitor(graph)
+ visitor.visit_forward()
+ fn_node = node.body[0]
+
+ self.assertEqual(visitor.counts[fn_node.args], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
+ # The return node should be unreachable in forward direction.
+ self.assertTrue(fn_node.body[0].body[2] not in visitor.counts)
+ self.assertEqual(visitor.counts[fn_node.body[1]], 1)
+
+ def test_basic_coverage_reverse(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ break
+ return a # pylint:disable=unreachable
+ a = 2
+
+ graphs, node = self._build_cfg(test_fn)
+ graph, = graphs.values()
+ visitor = CountingVisitor(graph)
+ visitor.visit_reverse()
+ fn_node = node.body[0]
+
+ self.assertEqual(visitor.counts[fn_node.args], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].test], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[0]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[0].body[1]], 1)
+ self.assertTrue(visitor.counts[fn_node.body[0].body[2]], 1)
+ self.assertEqual(visitor.counts[fn_node.body[1]], 1)
+
+
+class AstToCfgTest(test.TestCase):
+
+  def _build_cfg(self, fn):
+    """Parses fn and returns the dict of per-function CFGs."""
+    node, _ = parser.parse_entity(fn)
+    cfgs = cfg.build(node)
+    return cfgs
+
+  def _repr_set(self, node_set):
+    """Returns the frozenset of reprs of the given nodes."""
+    return frozenset(repr(n) for n in node_set)
+
+  def _as_set(self, elements):
+    """Normalizes an expectation (None, str or iterable) into a frozenset."""
+    if elements is None:
+      return frozenset()
+    elif isinstance(elements, str):
+      # A lone string denotes a single expected node, not its characters.
+      return frozenset((elements,))
+    else:
+      return frozenset(elements)
+
+  def assertGraphMatches(self, graph, edges):
+    """Tests whether the CFG contains the specified edges.
+
+    Args:
+      graph: cfg.Graph, the CFG under test.
+      edges: iterable of (prev, node_repr, next_) triples, where node_repr is
+        the repr of the expected node and prev/next_ are its expected
+        predecessor/successor reprs (None, a single str, or an iterable),
+        as accepted by _as_set.
+    """
+    for prev, node_repr, next_ in edges:
+      matched = False
+      # Scan all nodes: reprs are not necessarily unique, so keep looking
+      # until some node with this repr has the expected neighbors.
+      for cfg_node in graph.index.values():
+        if repr(cfg_node) == node_repr:
+          if (self._as_set(prev) == frozenset(map(repr, cfg_node.prev)) and
+              self._as_set(next_) == frozenset(map(repr, cfg_node.next))):
+            matched = True
+            break
+      if not matched:
+        self.fail(
+            'match failed for node "%s" in graph:\n%s' % (node_repr, graph))
+
+ def assertStatementEdges(self, graph, edges):
+ """Tests whether the CFG contains the specified statement edges."""
+ for prev_node_reprs, node_repr, next_node_reprs in edges:
+ matched = False
+ partial_matches = []
+ self.assertSetEqual(
+ frozenset(graph.stmt_next.keys()), frozenset(graph.stmt_prev.keys()))
+ for stmt_ast_node in graph.stmt_next:
+ ast_repr = '%s:%s' % (stmt_ast_node.__class__.__name__,
+ stmt_ast_node.lineno)
+ if ast_repr == node_repr:
+ actual_next = frozenset(map(repr, graph.stmt_next[stmt_ast_node]))
+ actual_prev = frozenset(map(repr, graph.stmt_prev[stmt_ast_node]))
+ partial_matches.append((actual_prev, node_repr, actual_next))
+ if (self._as_set(prev_node_reprs) == actual_prev and
+ self._as_set(next_node_reprs) == actual_next):
+ matched = True
+ break
+ if not matched:
+ self.fail('edges mismatch for %s: %s' % (node_repr, partial_matches))
+
+ def test_straightline(self):
+
+ def test_fn(a):
+ a += 1
+ a = 2
+ a = 3
+ return
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', 'a += 1'),
+ ('a += 1', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', 'return'),
+ ('a = 3', 'return', None),
+ ),
+ )
+
+ def test_straightline_no_return(self):
+
+ def test_fn(a, b):
+ a = b + 1
+ a += max(a)
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a, b', 'a = b + 1'),
+ ('a = b + 1', 'a += max(a)', None),
+ ),
+ )
+
+ def test_unreachable_code(self):
+
+ def test_fn(a):
+ return
+ a += 1 # pylint:disable=unreachable
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', 'return'),
+ ('a', 'return', None),
+ (None, 'a += 1', None),
+ ),
+ )
+
+ def test_if_straightline(self):
+
+ def test_fn(a):
+ if a > 0:
+ a = 1
+ else:
+ a += -1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('(a > 0)', 'a = 1', None),
+ ('(a > 0)', 'a += -1', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', None),),
+ )
+
+ def test_branch_nested(self):
+
+ def test_fn(a):
+ if a > 0:
+ if a > 1:
+ a = 1
+ else:
+ a = 2
+ else:
+ if a > 2:
+ a = 3
+ else:
+ a = 4
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('a', '(a > 0)', ('(a > 1)', '(a > 2)')),
+ ('(a > 0)', '(a > 1)', ('a = 1', 'a = 2')),
+ ('(a > 1)', 'a = 1', None),
+ ('(a > 1)', 'a = 2', None),
+ ('(a > 0)', '(a > 2)', ('a = 3', 'a = 4')),
+ ('(a > 2)', 'a = 3', None),
+ ('(a > 2)', 'a = 4', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'If:2', None),
+ ('(a > 0)', 'If:3', None),
+ ('(a > 0)', 'If:8', None),
+ ),
+ )
+
+ def test_branch_straightline_semi(self):
+
+ def test_fn(a):
+ if a > 0:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (None, 'a', '(a > 0)'),
+ ('a', '(a > 0)', 'a = 1'),
+ ('(a > 0)', 'a = 1', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', None),),
+ )
+
+ def test_branch_return(self):
+
+ def test_fn(a):
+ if a > 0:
+ return
+ else:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', ('return', 'a = 1')),
+ ('(a > 0)', 'a = 1', 'a = 2'),
+ ('(a > 0)', 'return', None),
+ ('a = 1', 'a = 2', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', 'a = 2'),),
+ )
+
+ def test_branch_return_minimal(self):
+
+ def test_fn(a):
+ if a > 0:
+ return
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', 'return'),
+ ('(a > 0)', 'return', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'If:2', None),),
+ )
+
+ def test_while_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
+ ('(a > 0)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'While:2', 'a = 2'),),
+ )
+
+ def test_while_else_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('a = 1', 'a = 2')),
+ ('(a > 0)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'While:2', 'a = 3'),),
+ )
+
+ def test_while_else_continue(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ continue
+ else:
+ a = 0
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'continue', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('continue', 'a = 0')),
+ ('(a > 1)', 'continue', '(a > 0)'),
+ ('a = 0', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'If:3', ('a = 1', '(a > 0)')),
+ ),
+ )
+
+ def test_while_else_break(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ break
+ a = 1
+ else:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('break', 'a = 1')),
+ ('(a > 1)', 'break', 'a = 3'),
+ ('(a > 1)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ (('break', 'a = 2'), 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'If:3', ('a = 1', 'a = 3')),
+ ),
+ )
+
+ def test_while_else_return(self):
+
+ def test_fn(a):
+ while a > 0:
+ if a > 1:
+ return
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', ('(a > 1)', 'a = 2')),
+ ('(a > 0)', '(a > 1)', ('return', 'a = 1')),
+ ('(a > 1)', 'return', None),
+ ('(a > 1)', 'a = 1', '(a > 0)'),
+ ('(a > 0)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'If:3', 'a = 1'),
+ ),
+ )
+
+ def test_while_nested_straightline(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('a = 1', 'a = 2')),
+ ('(a > 1)', 'a = 1', '(a > 1)'),
+ ('(a > 1)', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'While:3', 'a = 2'),
+ ),
+ )
+
+ def test_while_nested_continue(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ if a > 3:
+ continue
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'continue', 'a = 1'), '(a > 1)', ('(a > 3)', 'a = 2')),
+ ('(a > 1)', '(a > 3)', ('continue', 'a = 1')),
+ ('(a > 3)', 'continue', '(a > 1)'),
+ ('(a > 3)', 'a = 1', '(a > 1)'),
+ ('(a > 1)', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'While:3', 'a = 2'),
+ ('(a > 1)', 'If:4', ('a = 1', '(a > 1)')),
+ ),
+ )
+
+ def test_while_nested_break(self):
+
+ def test_fn(a):
+ while a > 0:
+ while a > 1:
+ if a > 2:
+ break
+ a = 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(graph, (
+ (('a', 'a = 2'), '(a > 0)', ('(a > 1)', 'a = 3')),
+ (('(a > 0)', 'a = 1'), '(a > 1)', ('(a > 2)', 'a = 2')),
+ ('(a > 1)', '(a > 2)', ('break', 'a = 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'a = 1', '(a > 1)'),
+ (('(a > 1)', 'break'), 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'a = 3', None),
+ ))
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'While:2', 'a = 3'),
+ ('(a > 0)', 'While:3', 'a = 2'),
+ ('(a > 1)', 'If:4', ('a = 1', 'a = 2')),
+ ),
+ )
+
+ def test_for_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
+ ('range(0, a)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'For:2', 'a = 2'),),
+ )
+
+ def test_for_else_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('a = 1', 'a = 2')),
+ ('range(0, a)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (('a', 'For:2', 'a = 3'),),
+ )
+
+ def test_for_else_continue(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ continue
+ else:
+ a = 0
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'continue', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('continue', 'a = 0')),
+ ('(a > 1)', 'continue', 'range(0, a)'),
+ ('(a > 1)', 'a = 0', 'a = 1'),
+ ('a = 0', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'If:3', ('a = 1', 'range(0, a)')),
+ ),
+ )
+
+ def test_for_else_break(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ break
+ a = 1
+ else:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('break', 'a = 1')),
+ ('(a > 1)', 'break', 'a = 3'),
+ ('(a > 1)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ (('break', 'a = 2'), 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'If:3', ('a = 1', 'a = 3')),
+ ),
+ )
+
+ def test_for_else_return(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ if a > 1:
+ return
+ a = 1
+ else: # pylint:disable=useless-else-on-loop
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), 'range(0, a)', ('(a > 1)', 'a = 2')),
+ ('range(0, a)', '(a > 1)', ('return', 'a = 1')),
+ ('(a > 1)', 'return', None),
+ ('(a > 1)', 'a = 1', 'range(0, a)'),
+ ('range(0, a)', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'If:3', 'a = 1'),
+ ),
+ )
+
+ def test_for_nested_straightline(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'b += 1'), 'range(1, a)', ('b += 1', 'a = 2')),
+ ('range(1, a)', 'b += 1', 'range(1, a)'),
+ ('range(1, a)', 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'For:3', 'a = 2'),
+ ),
+ )
+
+ def test_for_nested_continue(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ if a > 3:
+ continue
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'continue', 'b += 1'), 'range(1, a)',
+ ('(a > 3)', 'a = 2')),
+ ('range(1, a)', '(a > 3)', ('continue', 'b += 1')),
+ ('(a > 3)', 'continue', 'range(1, a)'),
+ ('(a > 3)', 'b += 1', 'range(1, a)'),
+ ('range(1, a)', 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'For:3', 'a = 2'),
+ ('range(1, a)', 'If:4', ('b += 1', 'range(1, a)')),
+ ),
+ )
+
+ def test_for_nested_break(self):
+
+ def test_fn(a):
+ for a in range(0, a):
+ for b in range(1, a):
+ if a > 2:
+ break
+ b += 1
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 2'), 'range(0, a)', ('range(1, a)', 'a = 3')),
+ (('range(0, a)', 'b += 1'), 'range(1, a)', ('(a > 2)', 'a = 2')),
+ ('range(1, a)', '(a > 2)', ('break', 'b += 1')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', 'b += 1', 'range(1, a)'),
+ (('range(1, a)', 'break'), 'a = 2', 'range(0, a)'),
+ ('range(0, a)', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('a', 'For:2', 'a = 3'),
+ ('range(0, a)', 'For:3', 'a = 2'),
+ ('range(1, a)', 'If:4', ('b += 1', 'a = 2')),
+ ),
+ )
+
+ def test_complex(self):
+
+ def test_fn(a):
+ b = 0
+ while a > 0:
+ for b in range(0, a):
+ if a > 2:
+ break
+ if a > 3:
+ if a > 4:
+ continue
+ else:
+ max(a)
+ break
+ b += 1
+ else: # for b in range(0, a):
+ return a
+ a = 2
+ for a in range(1, a):
+ return b
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('b = 0', 'a = 2'), '(a > 0)', ('range(0, a)', 'range(1, a)')),
+ (
+ ('(a > 0)', 'continue', 'b += 1'),
+ 'range(0, a)',
+ ('(a > 2)', 'return a'),
+ ),
+ ('range(0, a)', '(a > 2)', ('(a > 3)', 'break')),
+ ('(a > 2)', 'break', 'a = 2'),
+ ('(a > 2)', '(a > 3)', ('(a > 4)', 'b += 1')),
+ ('(a > 3)', '(a > 4)', ('continue', 'max(a)')),
+ ('(a > 4)', 'max(a)', 'break'),
+ ('max(a)', 'break', 'a = 2'),
+ ('(a > 4)', 'continue', 'range(0, a)'),
+ ('(a > 3)', 'b += 1', 'range(0, a)'),
+ ('range(0, a)', 'return a', None),
+ ('break', 'a = 2', '(a > 0)'),
+ ('(a > 0)', 'range(1, a)', ('return b', 'a = 3')),
+ ('range(1, a)', 'return b', None),
+ ('range(1, a)', 'a = 3', None),
+ ),
+ )
+ self.assertStatementEdges(
+ graph,
+ (
+ ('b = 0', 'While:3', 'range(1, a)'),
+ ('(a > 0)', 'For:4', 'a = 2'),
+ ('range(0, a)', 'If:5', ('(a > 3)', 'a = 2')),
+ ('(a > 2)', 'If:7', ('b += 1', 'a = 2', 'range(0, a)')),
+ ('(a > 3)', 'If:8', ('a = 2', 'range(0, a)')),
+ ('(a > 0)', 'For:17', 'a = 3'),
+ ),
+ )
+
+ def test_finally_straightline(self):
+
+ def test_fn(a):
+ try:
+ a += 1
+ finally:
+ a = 2
+ a = 3
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'a += 1', 'a = 2'),
+ ('a += 1', 'a = 2', 'a = 3'),
+ ('a = 2', 'a = 3', None),
+ ),
+ )
+
+ def test_return_finally(self):
+
+ def test_fn(a):
+ try:
+ return a
+ finally:
+ a = 1
+ a = 2
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'return a', 'a = 1'),
+ ('return a', 'a = 1', None),
+ (None, 'a = 2', None),
+ ),
+ )
+
+ def test_break_finally(self):
+
+ def test_fn(a):
+ while a > 0:
+ try:
+ break
+ finally:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', '(a > 0)', 'break'),
+ ('(a > 0)', 'break', 'a = 1'),
+ ('break', 'a = 1', None),
+ ),
+ )
+
+ def test_continue_finally(self):
+
+ def test_fn(a):
+ while a > 0:
+ try:
+ continue
+ finally:
+ a = 1
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ (('a', 'a = 1'), '(a > 0)', 'continue'),
+ ('(a > 0)', 'continue', 'a = 1'),
+ ('continue', 'a = 1', '(a > 0)'),
+ ),
+ )
+
+ def test_with_straightline(self):
+
+ def test_fn(a):
+ with max(a) as b:
+ a = 0
+ return b
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'max(a)', 'a = 0'),
+ ('max(a)', 'a = 0', 'return b'),
+ ('a = 0', 'return b', None),
+ ),
+ )
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/common_transformers/BUILD b/tensorflow/python/autograph/pyct/common_transformers/BUILD
new file mode 100644
index 0000000000..5e2f8f3ac0
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/common_transformers/BUILD
@@ -0,0 +1,41 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "common_transformers",
+ srcs = [
+ "anf.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ # TODO(aqj) Revisit this dependency direction when pyct is more
+ # modularized
+ "//tensorflow/python/autograph/pyct",
+ ],
+)
+
+py_test(
+ name = "anf_test",
+ srcs = ["anf_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":common_transformers",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/python/autograph/pyct/common_transformers/__init__.py b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf.py b/tensorflow/python/autograph/pyct/common_transformers/anf.py
new file mode 100644
index 0000000000..192621b1cd
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf.py
@@ -0,0 +1,424 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Conversion to A-normal form.
+
+The general idea of A-normal form is that every intermediate value is
+explicitly named with a variable. For more, see
+https://en.wikipedia.org/wiki/A-normal_form.
+
+The specific converters used here are based on Python AST semantics as
+documented at https://greentreesnakes.readthedocs.io/en/latest/.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+import six
+
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.autograph.pyct import transformer
+
+
class DummyGensym(object):
  """Trivial name generator: appends an increasing counter to a stem.

  The first name generated for stem 's' is 's_1001', then 's_1002', etc.
  """

  def __init__(self, entity_info):
    # entity_info is accepted for interface compatibility only. A proper
    # implementation would consult it to guarantee uniqueness against:
    #  * entity_info.namespace
    #  * all the symbols defined in the AST
    #  * the names generated so far
    del entity_info
    self._counter = 1000

  def new_name(self, stem='tmp'):
    """Returns a fresh name of the form '<stem>_<number>'."""
    self._counter += 1
    return '{}_{}'.format(stem, self._counter)
+
+
class AnfTransformer(transformer.Base):
  """Performs the conversion to A-normal form (ANF)."""

  # The algorithm is a postorder recursive tree walk. Any given node A may, in
  # general, require creation of a series B of Assign statements, which compute
  # and explicitly name the intermediate values needed to compute the value of
  # A. If A was already a statement, it can be replaced with the sequence B +
  # [A]. If A was an expression, B needs to be propagated up the tree until a
  # statement is encountered. Since the `ast.NodeTransformer` framework makes
  # no provision for subtraversals returning side information, this class
  # accumulates the sequence B in an instance variable.

  # The only other subtlety is that some Python statements (like `if`) have both
  # expression fields (`test`) and statement list fields (`body` and `orelse`).
  # Any additional assignments needed to name all the intermediate values in the
  # `test` can be prepended to the `if` node, but assignments produced by
  # processing the `body` and the `orelse` need to be kept together with them,
  # and not accidentally lifted out of the `if`.

  def __init__(self, entity_info, gensym_source=None):
    """Creates an ANF transformer.

    Args:
      entity_info: transformer.EntityInfo
      gensym_source: An optional object with the same interface as `DummyGensym`
        for generating unique names
    """
    super(AnfTransformer, self).__init__(entity_info)
    if gensym_source is None:
      self._gensym = DummyGensym(entity_info)
    else:
      self._gensym = gensym_source(entity_info)
    # The sequence B of Assign statements accumulated for the statement
    # currently being processed (see the class comment above).
    self._pending_statements = []

  def _consume_pending_statements(self):
    """Returns the accumulated assignments and resets the accumulator."""
    ans = self._pending_statements
    self._pending_statements = []
    return ans

  def _add_pending_statement(self, stmt):
    """Queues a statement to be emitted before the enclosing statement."""
    self._pending_statements.append(stmt)

  # Node types that never need to be lifted into a temporary.
  _trivial_nodes = (
      # Non-nodes that show up as AST fields
      bool, six.string_types,
      # Leaf nodes that are already in A-normal form
      gast.expr_context, gast.Name, gast.Num, gast.Str, gast.Bytes,
      gast.NameConstant, gast.Ellipsis,
      # Binary operators
      gast.Add, gast.Sub, gast.Mult, gast.Div, gast.Mod, gast.Pow, gast.LShift,
      gast.RShift, gast.BitOr, gast.BitXor, gast.BitAnd, gast.FloorDiv,
      # Unary operators
      gast.Invert, gast.Not, gast.UAdd, gast.USub,
      # Comparison operators
      gast.Eq, gast.NotEq, gast.Lt, gast.LtE, gast.Gt, gast.GtE,
      gast.Is, gast.IsNot, gast.In, gast.NotIn,
  )

  def _is_node_trivial(self, node):
    """Returns True if `node` is already in A-normal form."""
    if node is None:
      return True
    elif isinstance(node, self._trivial_nodes):
      return True
    elif isinstance(node, gast.keyword):
      # A keyword argument is as trivial as the value it carries.
      return self._is_node_trivial(node.value)
    elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
      # These wrappers cannot be named by a temporary themselves; they are
      # trivial iff their contents are.
      return self._are_children_trivial(node)
    return False

  def _are_children_trivial(self, node):
    """Returns True if every AST field of `node` is trivial."""
    for field in node._fields:
      if not field.startswith('__'):
        if not self._is_node_trivial(getattr(node, field)):
          return False
    return True

  def _ensure_node_is_trivial(self, node):
    """Returns a trivial node equivalent to `node`.

    Non-trivial expressions are computed into a fresh temporary via a pending
    Assign statement, and a Name referring to that temporary is returned in
    their place.
    """
    if node is None:
      return node
    elif isinstance(node, self._trivial_nodes):
      return node
    elif isinstance(node, list):
      # If something's field was actually a list, e.g., variadic arguments.
      return [self._ensure_node_is_trivial(n) for n in node]
    elif isinstance(node, gast.keyword):
      node.value = self._ensure_node_is_trivial(node.value)
      return node
    elif isinstance(node, (gast.Starred, gast.withitem, gast.slice)):
      return self._ensure_fields_trivial(node)
    elif isinstance(node, gast.expr):
      temp_name = self._gensym.new_name()
      temp_assign = templates.replace(
          'temp_name = expr', temp_name=temp_name, expr=node)[0]
      self._add_pending_statement(temp_assign)
      answer = templates.replace('temp_name', temp_name=temp_name)[0]
      return answer
    else:
      raise ValueError('Do not know how to treat {}'.format(node))

  def _ensure_fields_trivial(self, node):
    """Trivializes every AST field of `node` in place; returns `node`."""
    for field in node._fields:
      if field.startswith('__'):
        continue
      setattr(node, field, self._ensure_node_is_trivial(getattr(node, field)))
    return node

  def _visit_strict_statement(self, node, trivialize_children=True):
    """Visits a statement, prepending pending assignments.

    Returns a list of statements: the assignments naming the intermediate
    values computed in `node`'s expressions, followed by `node` itself.
    """
    assert not self._pending_statements
    node = self.generic_visit(node)
    if trivialize_children:
      self._ensure_fields_trivial(node)
    results = self._consume_pending_statements()
    results.append(node)
    return results

  def _visit_strict_expression(self, node):
    """Visits an expression, trivializing all its fields."""
    node = self.generic_visit(node)
    self._ensure_fields_trivial(node)
    return node

  # Note on code order: These are listed in the same order as the grammar
  # elements on https://github.com/serge-sans-paille/gast

  # FunctionDef, AsyncFunctionDef, and ClassDef should be correct by default.

  def visit_Return(self, node):
    return self._visit_strict_statement(node)

  def visit_Delete(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  def visit_Assign(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  def visit_AugAssign(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  def visit_Print(self, node):
    return self._visit_strict_statement(node)

  def visit_For(self, node):
    assert not self._pending_statements
    # It's important to visit node.iter first, because any statements created
    # thereby need to live outside the body.
    self.visit(node.iter)
    node.iter = self._ensure_node_is_trivial(node.iter)
    iter_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.iter, but that is both correct and
    # cheap because by this point node.iter is trivial.
    node = self.generic_visit(node)
    assert not self._pending_statements
    iter_stmts.append(node)
    return iter_stmts

  def visit_AsyncFor(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial AsyncFor nodes not supported yet '
             '(need to think through the semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_While(self, node):
    if not self._is_node_trivial(node.test):
      msg = ('While with nontrivial test not supported yet '
             '(need to avoid precomputing the test).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_If(self, node):
    assert not self._pending_statements
    # It's important to visit node.test first, because any statements created
    # thereby need to live outside the body.
    self.visit(node.test)
    node.test = self._ensure_node_is_trivial(node.test)
    condition_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.test, but that is both correct and
    # cheap because by this point node.test is trivial.
    node = self.generic_visit(node)
    assert not self._pending_statements
    condition_stmts.append(node)
    return condition_stmts

  def visit_With(self, node):
    assert not self._pending_statements
    # It's important to visit node.items first, because any statements created
    # thereby need to live outside the body.
    for item in node.items:
      self.visit(item)
    node.items = [self._ensure_node_is_trivial(n) for n in node.items]
    contexts_stmts = self._consume_pending_statements()
    # This generic_visit will revisit node.items, but that is both correct and
    # cheap because by this point node.items is trivial.
    node = self.generic_visit(node)
    assert not self._pending_statements
    contexts_stmts.append(node)
    return contexts_stmts

  def visit_AsyncWith(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial AsyncWith nodes not supported yet '
             '(need to think through the semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Raise(self, node):
    return self._visit_strict_statement(node)

  # Try should be correct by default.

  def visit_Assert(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Assert nodes not supported yet '
             '(need to avoid computing the test when assertions are off, and '
             'avoid computing the irritant when the assertion does not fire).')
      raise ValueError(msg)
    return self.generic_visit(node)

  # Import and ImportFrom should be correct by default.

  def visit_Exec(self, node):
    return self._visit_strict_statement(node)

  # Global and Nonlocal should be correct by default.

  def visit_Expr(self, node):
    return self._visit_strict_statement(node, trivialize_children=False)

  # Pass, Break, and Continue should be correct by default.

  def visit_BoolOp(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial BoolOp nodes not supported yet '
             '(need to preserve short-circuiting semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_BinOp(self, node):
    return self._visit_strict_expression(node)

  def visit_UnaryOp(self, node):
    return self._visit_strict_expression(node)

  def visit_Lambda(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Lambda nodes not supported '
             '(cannot insert statements into lambda bodies).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_IfExp(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial IfExp nodes not supported yet '
             '(need to convert to If statement, to evaluate branches lazily '
             'and insert statements into them).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Dict(self, node):
    return self._visit_strict_expression(node)

  def visit_Set(self, node):
    return self._visit_strict_expression(node)

  def visit_ListComp(self, node):
    msg = ('ListComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_SetComp(self, node):
    msg = ('SetComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_DictComp(self, node):
    msg = ('DictComp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_GeneratorExp(self, node):
    msg = ('GeneratorExp nodes not supported '
           '(need to convert to a form that tolerates '
           'assignment statements in clause bodies).')
    raise ValueError(msg)

  def visit_Await(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Await nodes not supported yet '
             '(need to think through the semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Yield(self, node):
    return self._visit_strict_expression(node)

  def visit_YieldFrom(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial YieldFrom nodes not supported yet '
             '(need to unit-test them in Python 2).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Compare(self, node):
    if len(node.ops) > 1:
      msg = ('Multi-ary compare nodes not supported yet '
             '(need to preserve short-circuiting semantics).')
      raise ValueError(msg)
    return self._visit_strict_expression(node)

  def visit_Call(self, node):
    return self._visit_strict_expression(node)

  def visit_Repr(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial Repr nodes not supported yet '
             '(need to research their syntax and semantics).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_FormattedValue(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial FormattedValue nodes not supported yet '
             '(need to unit-test them in Python 2).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_JoinedStr(self, node):
    if not self._are_children_trivial(node):
      msg = ('Nontrivial JoinedStr nodes not supported yet '
             '(need to unit-test them in Python 2).')
      raise ValueError(msg)
    return self.generic_visit(node)

  def visit_Attribute(self, node):
    return self._visit_strict_expression(node)

  def visit_Subscript(self, node):
    return self._visit_strict_expression(node)

  # Starred and Name are correct by default, because the right thing to do is to
  # just recur.

  def visit_List(self, node):
    node = self.generic_visit(node)
    # List literals in Store context are assignment targets; their elements
    # must not be lifted into temporaries.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_trivial(node)
    return node

  def visit_Tuple(self, node):
    node = self.generic_visit(node)
    # Same as visit_List: leave assignment targets untouched.
    if not isinstance(node.ctx, gast.Store):
      self._ensure_fields_trivial(node)
    return node
+
+
def transform(node, entity_info, gensym_source=None):
  """Converts the given node to A-normal form (ANF).

  The general idea of A-normal form: https://en.wikipedia.org/wiki/A-normal_form

  The specific converters used here are based on Python AST semantics as
  documented at https://greentreesnakes.readthedocs.io/en/latest/.

  Args:
    node: The node to transform.
    entity_info: transformer.EntityInfo. TODO(mdan): What information does this
      argument provide?
    gensym_source: An optional object with the same interface as `DummyGensym`
      for generating unique names.

  Returns:
    The transformed AST. Note that the input `node` may be mutated in place,
    as the transformer updates node fields rather than copying nodes.
  """
  return AnfTransformer(entity_info, gensym_source=gensym_source).visit(node)
diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
new file mode 100644
index 0000000000..ccc7e4ca8f
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py
@@ -0,0 +1,443 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for anf module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.common_transformers import anf
+from tensorflow.python.platform import test
+
+
class DummyGensym(object):
  """Deterministic gensym stub: numbers each stem starting at 1001."""

  def __init__(self, entity_info):
    # Interface compatibility only. A real gensym would inspect entity_info
    # (namespace, symbols defined in the AST, names already handed out) to
    # guarantee uniqueness.
    del entity_info
    self._next_suffix = 1001

  def new_name(self, stem='tmp'):
    """Returns '<stem>_<n>' with n increasing by one on every call."""
    name = '%s_%d' % (stem, self._next_suffix)
    self._next_suffix += 1
    return name
+
+
class AnfTransformerTest(test.TestCase):
  """Tests the ANF transformer against hand-written expected outputs."""

  def _simple_source_info(self):
    """Returns a minimal EntityInfo stub sufficient for the transformer."""
    return transformer.EntityInfo(
        source_code=None,
        source_file=None,
        namespace=None,
        arg_values=None,
        arg_types=None,
        owner_type=None)

  def test_basic(self):
    """Transforming a body already in ANF should preserve its behavior."""
    def test_function():
      a = 0
      return a
    node, _ = parser.parse_entity(test_function)
    node = anf.transform(node.body[0], self._simple_source_info())
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(test_function(), result.test_function())

  def assert_same_ast(self, expected_node, node, msg=None):
    """Asserts that two ASTs render to identical source text."""
    expected_source = compiler.ast_to_source(expected_node, indentation='  ')
    expected_str = textwrap.dedent(expected_source).strip()
    got_source = compiler.ast_to_source(node, indentation='  ')
    got_str = textwrap.dedent(got_source).strip()
    self.assertEqual(expected_str, got_str, msg=msg)

  def assert_body_anfs_as_expected(self, expected_fn, test_fn):
    """Checks that ANF-transforming test_fn yields expected_fn's body."""
    # Testing the code bodies only. Wrapping them in functions so the
    # syntax highlights nicely, but Python doesn't try to execute the
    # statements.
    exp_node, _ = parser.parse_entity(expected_fn)
    node, _ = parser.parse_entity(test_fn)
    node = anf.transform(
        node, self._simple_source_info(), gensym_source=DummyGensym)
    exp_name = exp_node.body[0].name
    # Ignoring the function names in the result because they can't be
    # the same (because both functions have to exist in the same scope
    # at the same time).
    node.body[0].name = exp_name
    self.assert_same_ast(exp_node, node)
    # Check that ANF is idempotent
    node_repeated = anf.transform(
        node, self._simple_source_info(), gensym_source=DummyGensym)
    self.assert_same_ast(node_repeated, node)

  def test_binop_basic(self):

    def test_function(x, y, z):
      a = x + y + z
      return a

    def expected_result(x, y, z):
      tmp_1001 = x + y
      a = tmp_1001 + z
      return a

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_if_basic(self):

    def test_function(a, b, c, e, f, g):
      if a + b + c:
        d = e + f + g
        return d

    def expected_result(a, b, c, e, f, g):
      tmp_1001 = a + b
      tmp_1002 = tmp_1001 + c
      if tmp_1002:
        tmp_1003 = e + f
        d = tmp_1003 + g
        return d

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_nested_binop_and_return(self):

    def test_function(b, c, d, e):
      return (2 * b + c) + (d + e)

    def expected_result(b, c, d, e):
      tmp_1001 = 2 * b
      tmp_1002 = tmp_1001 + c
      tmp_1003 = d + e
      tmp_1004 = tmp_1002 + tmp_1003
      return tmp_1004

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_function_call_and_expr(self):

    def test_function(call_something, a, b, y, z, c, d, e, f, g, h, i):
      call_something(a + b, y * z, kwarg=c + d, *(e + f), **(g + h + i))

    def expected_result(call_something, a, b, y, z, c, d, e, f, g, h, i):
      tmp_1001 = g + h
      tmp_1002 = a + b
      tmp_1003 = y * z
      tmp_1004 = e + f
      tmp_1005 = c + d
      tmp_1006 = tmp_1001 + i
      call_something(tmp_1002, tmp_1003, kwarg=tmp_1005, *tmp_1004, **tmp_1006)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_with_and_print(self):

    def test_function(a, b, c):
      with a + b + c as d:
        print(2 * d + 1)

    def expected_result(a, b, c):
      tmp_1001 = a + b
      tmp_1002 = tmp_1001 + c
      with tmp_1002 as d:
        tmp_1003 = 2 * d
        tmp_1004 = tmp_1003 + 1
        print(tmp_1004)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_nested_multi_value_assign(self):

    def test_function(a, b, c):
      x, y = a, a + b
      (z, y), x = (c, y + b), x + a
      return z, (y, x)

    def expected_result(a, b, c):
      tmp_1001 = a + b
      x, y = a, tmp_1001
      tmp_1002 = y + b
      tmp_1003 = (c, tmp_1002)
      tmp_1004 = x + a
      (z, y), x = tmp_1003, tmp_1004
      tmp_1005 = y, x
      tmp_1006 = z, tmp_1005
      return tmp_1006

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_deeply_nested_multi_value_assign(self):

    def test_function(a):
      [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
      return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]

    def expected_result(a):
      [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
      tmp_1001 = b, c
      tmp_1002 = [d, e]
      tmp_1003 = [tmp_1001, tmp_1002]
      tmp_1004 = f, g
      tmp_1005 = h, i, j
      tmp_1006 = tmp_1003, tmp_1004
      tmp_1007 = [tmp_1005, k]
      tmp_1008 = [tmp_1006, tmp_1007]
      return tmp_1008

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_local_definition_and_binary_compare(self):

    def test_function():
      def foo(a, b):
        return 2 * a < b
      return foo

    def expected_result():
      def foo(a, b):
        tmp_1001 = 2 * a
        tmp_1002 = tmp_1001 < b
        return tmp_1002
      return foo

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_list_literal(self):

    def test_function(a, b, c, d, e, f):
      return [a + b, c + d, e + f]

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = a + b
      tmp_1002 = c + d
      tmp_1003 = e + f
      tmp_1004 = [tmp_1001, tmp_1002, tmp_1003]
      return tmp_1004

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_tuple_literal_and_unary(self):

    def test_function(a, b, c, d, e, f):
      return (a + b, -(c + d), e + f)

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = c + d
      tmp_1002 = a + b
      tmp_1003 = -tmp_1001
      tmp_1004 = e + f
      tmp_1005 = (tmp_1002, tmp_1003, tmp_1004)
      return tmp_1005

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_set_literal(self):

    def test_function(a, b, c, d, e, f):
      return set(a + b, c + d, e + f)

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = a + b
      tmp_1002 = c + d
      tmp_1003 = e + f
      tmp_1004 = set(tmp_1001, tmp_1002, tmp_1003)
      return tmp_1004

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_dict_literal_and_repr(self):

    def test_function(foo, bar, baz):
      return repr({foo + bar + baz: 7 | 8})

    def expected_result(foo, bar, baz):
      tmp_1001 = foo + bar
      tmp_1002 = tmp_1001 + baz
      tmp_1003 = 7 | 8
      tmp_1004 = {tmp_1002: tmp_1003}
      tmp_1005 = repr(tmp_1004)
      return tmp_1005

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_field_read_and_write(self):

    def test_function(a, d):
      a.b.c = d.e.f + 3

    def expected_result(a, d):
      tmp_1001 = a.b
      tmp_1002 = d.e
      tmp_1003 = tmp_1002.f
      tmp_1001.c = tmp_1003 + 3

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_subscript_read_and_write(self):

    def test_function(a, b, c, d, e, f):
      a[b][c] = d[e][f] + 3

    def expected_result(a, b, c, d, e, f):
      tmp_1001 = a[b]
      tmp_1002 = d[e]
      tmp_1003 = tmp_1002[f]
      tmp_1001[c] = tmp_1003 + 3

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_augassign_and_delete(self):

    def test_function(a, x, y, z):
      a += x + y + z
      del a
      del z[y][x]

    def expected_result(a, x, y, z):
      tmp_1001 = x + y
      a += tmp_1001 + z
      del a
      tmp_1002 = z[y]
      del tmp_1002[x]

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_raise_yield_and_raise(self):

    def test_function(a, c, some_computed, exception):
      yield a ** c
      raise some_computed('complicated' + exception)

    def expected_result(a, c, some_computed, exception):
      tmp_1001 = a ** c
      yield tmp_1001
      tmp_1002 = 'complicated' + exception
      tmp_1003 = some_computed(tmp_1002)
      raise tmp_1003

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_with_and_if_with_expressions(self):

    def test_function(foo, bar, function, quux, quozzle, w, x, y, z):
      with foo + bar:
        function(x + y)
      if quux + quozzle:
        function(z / w)

    def expected_result(foo, bar, function, quux, quozzle, w, x, y, z):
      tmp_1001 = foo + bar
      with tmp_1001:
        tmp_1002 = x + y
        function(tmp_1002)
      tmp_1003 = quux + quozzle
      if tmp_1003:
        tmp_1004 = z / w
        function(tmp_1004)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_exec(self):

    def test_function():
      # The point is to test A-normal form conversion of exec
      # pylint: disable=exec-used
      exec('computed' + 5 + 'stuff', globals(), locals())

    def expected_result():
      # pylint: disable=exec-used
      tmp_1001 = 'computed' + 5
      tmp_1002 = tmp_1001 + 'stuff'
      tmp_1003 = globals()
      tmp_1004 = locals()
      exec(tmp_1002, tmp_1003, tmp_1004)

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_simple_while_and_assert(self):

    def test_function(foo, quux):
      while foo:
        assert quux
        foo = foo + 1 * 3

    def expected_result(foo, quux):
      while foo:
        assert quux
        tmp_1001 = 1 * 3
        foo = foo + tmp_1001

    self.assert_body_anfs_as_expected(expected_result, test_function)

  def test_for(self):

    def test_function(compute, something, complicated, foo):
      for foo in compute(something + complicated):
        bar = foo + 1 * 3
      return bar

    def expected_result(compute, something, complicated, foo):
      tmp_1001 = something + complicated
      tmp_1002 = compute(tmp_1001)
      for foo in tmp_1002:
        tmp_1003 = 1 * 3
        bar = foo + tmp_1003
      return bar

    self.assert_body_anfs_as_expected(expected_result, test_function)

  # This test collects several examples where the definition of A-normal form
  # implemented by this transformer is questionable. Mostly it's here to spell
  # out what the definition is in these cases.
  def test_controversial(self):

    def test_function(b, c, d, f):
      a = c + d
      a.b = c + d
      a[b] = c + d
      a += c + d
      a, b = c
      a, b = c, d
      a = f(c)
      a = f(c + d)
      a[b + d] = f.e(c + d)

    def expected_result(b, c, d, f):
      a = c + d
      a.b = c + d  # Should be a.b = tmp?  (Definitely not tmp = c + d)
      a[b] = c + d  # Should be a[b] = tmp?  (Definitely not tmp = c + d)
      a += c + d  # Should be a += tmp?  (Definitely not tmp = c + d)
      a, b = c  # Should be a = c[0], b = c[1]?  Or not?
      a, b = c, d  # Should be a = c, b = d?  Or not?
      a = f(c)
      tmp_1001 = c + d
      a = f(tmp_1001)
      tmp_1002 = b + d
      tmp_1003 = f.e
      tmp_1004 = c + d
      a[tmp_1002] = tmp_1003(tmp_1004)  # Or should be a[tmp1] = tmp2?

    self.assert_body_anfs_as_expected(expected_result, test_function)
+
+
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/pyct/compiler.py b/tensorflow/python/autograph/pyct/compiler.py
new file mode 100644
index 0000000000..9e1b6bdbe8
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/compiler.py
@@ -0,0 +1,141 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converting AST to code.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Use six for compatibility here.
+import atexit
+import imp
+import os
+import tempfile
+
+import astor
+import gast
+
+from tensorflow.python.autograph.pyct import origin_info
+
+
def ast_to_source(node, indentation='  '):
  """Returns the source code of the given AST.

  Args:
    node: The code to render, as a single AST node or a list/tuple of nodes.
    indentation: The string to use for each level of indentation.

  Returns:
    Text, the source code generated from the AST.
  """
  if not isinstance(node, (list, tuple)):
    node = (node,)
  generator = astor.codegen.SourceGenerator(indentation, False,
                                            astor.string_repr.pretty_string)

  for n in node:
    if isinstance(n, gast.AST):
      # astor operates on standard ast nodes; convert from the gast superset.
      n = gast.gast_to_ast(n)
    generator.visit(n)
    generator.result.append('\n')

  # In some versions of Python, literals may appear as actual values. Building
  # a concrete list here (rather than a lazy map iterator) ensures everything
  # is a string and that the fragments can be traversed more than once.
  code = [str(item) for item in generator.result]
  code = astor.source_repr.pretty_source(code).lstrip()

  return code
+
+
def ast_to_object(nodes,
                  indentation='  ',
                  include_source_map=False,
                  source_prefix=None,
                  delete_on_exit=True):
  """Return the Python objects represented by given AST.

  Compiling the AST code this way ensures that the source code is readable by
  e.g. `pdb` or `inspect`.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST]], the code to compile, as an AST
      object.
    indentation: Text, the string to use for indentation.
    include_source_map: bool, whether to attach a source map to the compiled
      object. Also see origin_info.py.
    source_prefix: Optional[Text], string to print as-is into the source file.
    delete_on_exit: bool, whether to delete the temporary file used for
      compilation on exit.

  Returns:
    compiled_nodes: A module object containing the compiled source code.
    source: The source code of the compiled object
  Raises:
    ValueError: If ag_source_map__ is already in the namespace of the compiled
      nodes.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  source = ast_to_source(nodes, indentation=indentation)

  if source_prefix:
    source = source_prefix + '\n' + source

  with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
    module_name = os.path.basename(f.name[:-3])
    f.write(source)

  # `nodes` was normalized to a tuple above, so each node corresponds to one
  # of the trailing top-level entities of the generated source. (The previous
  # `isinstance` re-check here was always true and its else branch dead.)
  indices = range(-len(nodes), 0)

  if include_source_map:
    source_map = origin_info.source_map(nodes, source, f.name, indices)

  # TODO(mdan): Try flush() and delete=False instead.
  if delete_on_exit:
    atexit.register(lambda: os.remove(f.name))
  compiled_nodes = imp.load_source(module_name, f.name)

  # TODO(znado): Clean this up so we don't need to attach it to the namespace.
  # TODO(znado): This does not work for classes because their methods share a
  # namespace.
  # This attaches the source map which is needed for error handling.  Note that
  # api.to_graph copies this source map into an attribute of the function.
  #
  # We need this so the ag_source_map__ variable is available to the call to
  # rewrite_graph_construction_error in the except block inside each function
  # that handles graph construction errors.
  #
  # We cannot get the rewritten function name until it is too late so templating
  # is hard, and this cleanly fixes the
  # issues encountered with nested functions because this is attached to the
  # outermost one.
  if include_source_map:
    # TODO(mdan): This name should be decided by the caller.
    source_map_name = 'ag_source_map__'
    if source_map_name in compiled_nodes.__dict__:
      raise ValueError('cannot convert %s because it has namespace attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled_nodes, source_map_name))
    compiled_nodes.__dict__[source_map_name] = source_map

  return compiled_nodes, source
diff --git a/tensorflow/python/autograph/pyct/compiler_test.py b/tensorflow/python/autograph/pyct/compiler_test.py
new file mode 100644
index 0000000000..6fa289d3cc
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/compiler_test.py
@@ -0,0 +1,108 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for compiler module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+import gast
+
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
+
+
class CompilerTest(test.TestCase):
  """Tests for round-tripping ASTs through compiler's source generation."""

  def test_parser_compile_idempotent(self):

    def test_fn(x):
      a = True
      b = ''
      if a:
        b = x + 1
      return b

    # Parsing a function and compiling it back must reproduce its source text.
    self.assertEqual(
        textwrap.dedent(tf_inspect.getsource(test_fn)),
        tf_inspect.getsource(
            compiler.ast_to_object(
                parser.parse_entity(test_fn)[0].body[0])[0].test_fn))

  def test_ast_to_source(self):
    # A hand-built `if` statement with both branches.
    node = gast.If(
        test=gast.Num(1),
        body=[
            gast.Assign(
                targets=[gast.Name('a', gast.Store(), None)],
                value=gast.Name('b', gast.Load(), None))
        ],
        orelse=[
            gast.Assign(
                targets=[gast.Name('a', gast.Store(), None)],
                value=gast.Str('c'))
        ])

    source = compiler.ast_to_source(node, indentation='  ')
    self.assertEqual(
        textwrap.dedent("""
            if 1:
              a = b
            else:
              a = 'c'
        """).strip(), source.strip())

  def test_ast_to_object(self):
    # A hand-built `def f(a): return a + 1`.
    node = gast.FunctionDef(
        name='f',
        args=gast.arguments(
            args=[gast.Name('a', gast.Param(), None)],
            vararg=None,
            kwonlyargs=[],
            kwarg=None,
            defaults=[],
            kw_defaults=[]),
        body=[
            gast.Return(
                gast.BinOp(
                    op=gast.Add(),
                    left=gast.Name('a', gast.Load(), None),
                    right=gast.Num(1)))
        ],
        decorator_list=[],
        returns=None)

    module, source = compiler.ast_to_object(node)

    expected_source = """
      def f(a):
        return a + 1
    """
    self.assertEqual(
        textwrap.dedent(expected_source).strip(),
        source.strip())
    self.assertEqual(2, module.f(1))
    # The compiled module is backed by a real temp file on disk; its contents
    # must match the source string returned by ast_to_object.
    with open(module.__file__, 'r') as temp_output:
      self.assertEqual(
          textwrap.dedent(expected_source).strip(),
          temp_output.read().strip())
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
new file mode 100644
index 0000000000..eef74599a7
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -0,0 +1,161 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Live entity inspection utilities.
+
+This module contains whatever inspect doesn't offer out of the box.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import types
+
+import six
+
+from tensorflow.python.util import tf_inspect
+
+
def isbuiltin(f):
  """Returns True if f is a Python builtin function or builtin type."""
  # These particular callables are not instances of
  # types.BuiltinFunctionType, so they must be special-cased by identity.
  if f in (range, int, float):
    return True
  return bool(
      isinstance(f, types.BuiltinFunctionType) or tf_inspect.isbuiltin(f))
+
+
def getnamespace(f):
  """Returns the complete namespace of a function.

  Namespace is defined here as the mapping of all non-local variables to values.
  This includes the globals and the closure variables. Note that this captures
  the entire globals collection of the function, and may contain extra symbols
  that it does not actually use.

  Args:
    f: User defined function.
  Returns:
    A dict mapping symbol names to values.
  """
  # Copy, so that mutating the result does not pollute f's real globals.
  namespace = dict(six.get_function_globals(f))
  closure = six.get_function_closure(f)
  freevars = six.get_function_code(f).co_freevars
  if freevars and closure:
    # co_freevars and the closure cells are parallel sequences: zip pairs
    # each free variable name with its captured cell value.
    for name, cell in zip(freevars, closure):
      namespace[name] = cell.cell_contents
  return namespace
+
+
+def _get_unbound_function(m):
+ # TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
+ # The failure case is for tf.keras.Model.
+ if hasattr(m, 'im_func'):
+ return m.im_func
+ return m
+
+
def getdefiningclass(m, owner_class):
  """Resolves the class (e.g. one of the superclasses) that defined a method."""
  # Normalize bound functions to their respective unbound versions.
  unbound = _get_unbound_function(m)
  for base in owner_class.__bases__:
    if hasattr(base, unbound.__name__):
      base_attr = getattr(base, unbound.__name__)
      if _get_unbound_function(base_attr) is unbound:
        return base
      if hasattr(unbound, '__self__') and unbound.__self__ == owner_class:
        # Python 3 class methods only work this way it seems :S
        return base
  return owner_class
+
+
def getmethodclass(m):
  """Resolves a function's owner, e.g. a method's class.

  Note that this returns the object that the function was retrieved from, not
  necessarily the class where it was defined.

  This function relies on Python stack frame support in the interpreter, and
  has the same limitations as inspect.currentframe.

  Limitations. This function will only work correctly if the owned class is
  visible in the caller's global or local variables.

  Args:
    m: A user defined function

  Returns:
    The class that this function was retrieved from, or None if the function
    is not an object or class method, or the class that owns the object or
    method is not visible to m.

  Raises:
    ValueError: if the class could not be resolved for any unexpected reason.
  """

  # Callable objects: return their own class.
  if (not hasattr(m, '__name__') and hasattr(m, '__class__') and
      hasattr(m, '__call__')):
    if isinstance(m.__class__, six.class_types):
      return m.__class__

  # Instance method and class methods: should be bound to a non-null "self".
  # If self is a class, then it's a class method.
  if hasattr(m, '__self__'):
    if m.__self__:
      if tf_inspect.isclass(m.__self__):
        return m.__self__
      return type(m.__self__)

  # Class, static and unbound methods: search all defined classes in any
  # namespace. This is inefficient but a more robust method.
  owners = []
  caller_frame = tf_inspect.currentframe().f_back
  try:
    # TODO(mdan): This doesn't consider cell variables.
    # TODO(mdan): This won't work if the owner is hidden inside a container.
    # Cell variables may be pulled using co_freevars and the closure.
    for v in itertools.chain(caller_frame.f_locals.values(),
                             caller_frame.f_globals.values()):
      if hasattr(v, m.__name__):
        candidate = getattr(v, m.__name__)
        # Py2 methods may be bound or unbound, extract im_func to get the
        # underlying function.
        if hasattr(candidate, 'im_func'):
          candidate = candidate.im_func
        if hasattr(m, 'im_func'):
          m = m.im_func
        if candidate is m:
          owners.append(v)
  finally:
    # Drop the frame reference promptly; holding frames alive can keep
    # everything they reference from being garbage collected.
    del caller_frame

  if owners:
    if len(owners) == 1:
      return owners[0]

    # If multiple owners are found, and are not subclasses, raise an error.
    owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners)
    for o in owner_types:
      if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)):
        return o
    raise ValueError('Found too many owners of %s: %s' % (m, owners))

  return None
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
new file mode 100644
index 0000000000..f3eb027822
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -0,0 +1,277 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for inspect_utils module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import wraps
+
+import six
+
+from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.platform import test
+
+
def decorator(f):
  """Pass-through decorator: returns the decorated function unchanged."""
  return f
+
+
def function_decorator():
  """Returns a no-op decorator that leaves the function unchanged."""
  def passthrough(f):
    return f
  return passthrough
+
+
def wrapping_decorator():
  """Returns a decorator that swaps f for a None-returning replacement.

  The wrapper carries f's metadata (via functools.wraps), which makes the
  substitution invisible to name-based inspection.
  """
  def dec(f):
    def noop_replacement(*unused_args):
      return None

    @wraps(f)
    def wrapper(*args, **kwargs):
      return noop_replacement(*args, **kwargs)
    return wrapper
  return dec
+
+
class TestClass(object):
  """Fixture class exercising every flavor of method for resolution tests."""

  def member_function(self):
    pass

  @decorator
  def decorated_member(self):
    pass

  @function_decorator()
  def fn_decorated_member(self):
    pass

  @wrapping_decorator()
  def wrap_decorated_member(self):
    pass

  @staticmethod
  def static_method():
    pass

  @classmethod
  def class_method(cls):
    pass
+
+
def free_function():
  """Module-level function fixture; intentionally does nothing."""
+
+
def factory():
  # Returns the module-level free_function; used to test globals capture.
  return free_function
+
+
def free_factory():
  """Returns a function defined in a local (function) scope."""
  def local_function():
    pass
  return local_function
+
+
class InspectUtilsTest(test.TestCase):
  """Tests for getnamespace, getmethodclass, getdefiningclass and isbuiltin."""

  def test_getnamespace_globals(self):
    ns = inspect_utils.getnamespace(factory)
    self.assertEqual(ns['free_function'], free_function)

  def test_getnamespace_hermetic(self):

    # Intentionally hiding the global function to make sure we don't overwrite
    # it in the global namespace.
    free_function = object()  # pylint:disable=redefined-outer-name

    def test_fn():
      return free_function

    # The closure cell must shadow the same-named global in the result, and
    # the function's actual globals must remain untouched.
    ns = inspect_utils.getnamespace(test_fn)
    globs = six.get_function_globals(test_fn)
    self.assertTrue(ns['free_function'] is free_function)
    self.assertFalse(globs['free_function'] is free_function)

  def test_getnamespace_locals(self):

    def called_fn():
      return 0

    closed_over_list = []
    closed_over_primitive = 1

    def local_fn():
      closed_over_list.append(1)
      local_var = 1
      return called_fn() + local_var + closed_over_primitive

    # Closed-over names appear in the namespace; plain locals do not.
    ns = inspect_utils.getnamespace(local_fn)
    self.assertEqual(ns['called_fn'], called_fn)
    self.assertEqual(ns['closed_over_list'], closed_over_list)
    self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
    self.assertTrue('local_var' not in ns)

  def test_getmethodclass(self):

    self.assertEqual(
        inspect_utils.getmethodclass(free_function), None)
    self.assertEqual(
        inspect_utils.getmethodclass(free_factory()), None)

    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.member_function),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.fn_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.wrap_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.static_method),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(TestClass.class_method),
        TestClass)

    # Same expectations when retrieved from an instance rather than the class.
    test_obj = TestClass()
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.member_function),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.static_method),
        TestClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.class_method),
        TestClass)

  def test_getmethodclass_locals(self):
    # Exercises the frame-walking fallback: the owner class is only visible
    # in this test method's local variables.

    def local_function():
      pass

    class LocalClass(object):

      def member_function(self):
        pass

      @decorator
      def decorated_member(self):
        pass

      @function_decorator()
      def fn_decorated_member(self):
        pass

      @wrapping_decorator()
      def wrap_decorated_member(self):
        pass

    self.assertEqual(
        inspect_utils.getmethodclass(local_function), None)

    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.member_function),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.fn_decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(LocalClass.wrap_decorated_member),
        LocalClass)

    test_obj = LocalClass()
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.member_function),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.fn_decorated_member),
        LocalClass)
    self.assertEqual(
        inspect_utils.getmethodclass(test_obj.wrap_decorated_member),
        LocalClass)

  def test_getmethodclass_callables(self):
    # Callable objects resolve to their own class.
    class TestCallable(object):

      def __call__(self):
        pass

    c = TestCallable()
    self.assertEqual(inspect_utils.getmethodclass(c), TestCallable)

  def test_getdefiningclass(self):
    class Superclass(object):

      def foo(self):
        pass

      def bar(self):
        pass

      @classmethod
      def class_method(cls):
        pass

    class Subclass(Superclass):

      def foo(self):
        pass

      def baz(self):
        pass

    # Overridden and newly-defined methods resolve to the subclass; inherited
    # ones resolve to the superclass that defined them.
    self.assertTrue(
        inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass)
    self.assertTrue(
        inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass)
    self.assertTrue(
        inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass)
    self.assertTrue(
        inspect_utils.getdefiningclass(Subclass.class_method, Subclass) is
        Superclass)

  def test_isbuiltin(self):
    self.assertTrue(inspect_utils.isbuiltin(range))
    self.assertTrue(inspect_utils.isbuiltin(float))
    self.assertTrue(inspect_utils.isbuiltin(int))
    self.assertTrue(inspect_utils.isbuiltin(len))
    self.assertFalse(inspect_utils.isbuiltin(function_decorator))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
new file mode 100644
index 0000000000..4c7c4165ef
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -0,0 +1,186 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Container for origin source code information before AutoGraph compilation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import tokenize
+
+import gast
+import six
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.util import tf_inspect
+
+
class LineLocation(
    collections.namedtuple('LineLocation', ('filename', 'lineno'))):
  """A source location identified by file and line only.

  Similar to Location, but without column information.

  Attributes:
    filename: Text
    lineno: int, 1-based
  """
+
+
class Location(
    collections.namedtuple('Location', ('filename', 'lineno', 'col_offset'))):
  """A full code location: file, line and column.

  Attributes:
    filename: Text
    lineno: int, 1-based
    col_offset: int
  """

  @property
  def line_loc(self):
    # Projects this location to its coarser, line-level equivalent.
    return LineLocation(self.filename, self.lineno)
+
+
class OriginInfo(
    collections.namedtuple(
        'OriginInfo',
        ('loc', 'function_name', 'source_code_line', 'comment'))):
  """Describes a location in the original (pre-conversion) source code.

  Attributes:
    loc: Location
    function_name: Optional[Text]
    source_code_line: Text
    comment: Optional[Text]
  """

  def as_frame(self):
    """Returns a 4-tuple consistent with the return of traceback.extract_tb."""
    loc = self.loc
    return loc.filename, loc.lineno, self.function_name, self.source_code_line
+
+
+# TODO(mdan): This source map should be a class - easier to refer to.
def source_map(nodes, code, filename, indices_in_code):
  """Creates a source map between an annotated AST and the code it compiles to.

  Args:
    nodes: Iterable[ast.AST, ...]
    code: Text
    filename: Optional[Text]
    indices_in_code: Union[int, Iterable[int, ...]], the positions at which
      nodes appear in code. The parser always returns a module when parsing
      code. This argument indicates the position in that module's body at
      which the corresponding node should appear.

  Returns:
    Dict[LineLocation, OriginInfo], mapping line locations in code to the
    locations indicated by origin annotations in node.
  """
  # Reparse the generated code so each node carries the lineno it has in the
  # generated (rather than the original) source.
  reparsed_nodes = parser.parse_str(code)
  reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]

  resolve(reparsed_nodes, code)
  result = {}

  for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
    # Note: generated code might not be mapped back to its origin.
    # TODO(mdan): Generated code should always be mapped to something.
    origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
    final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
    if origin_info is None or final_info is None:
      continue

    line_loc = LineLocation(filename, final_info.loc.lineno)

    existing_origin = result.get(line_loc)
    if existing_origin is not None:
      # Overlaps may exist because of child nodes, but almost never to
      # different line locations. An exception are decorated functions, where
      # both lines are mapped to the same line in the AST.

      # Line overlaps: keep bottom node.
      if existing_origin.loc.line_loc == origin_info.loc.line_loc:
        if existing_origin.loc.lineno >= origin_info.loc.lineno:
          continue

      # In case of overlaps, keep the leftmost node.
      if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
        continue

    result[line_loc] = origin_info

  return result
+
+
+# TODO(znado): Consider refactoring this into a Visitor.
+# TODO(mdan): Does this work correctly with inner functions?
def resolve(nodes, source, function=None):
  """Adds origin information to all nodes inside the body of function.

  Each node that has a lineno receives an OriginInfo annotation under the key
  anno.Basic.ORIGIN. This function returns None; it mutates nodes in place.

  Args:
    nodes: Union[ast.AST, Iterable[ast.AST, ...]]
    source: Text, the source code string for the function whose body nodes will
      be annotated.
    function: Callable, the function that will have all nodes inside of it
      annotated with an OriginInfo annotation with key anno.Basic.ORIGIN. If
      it is None then only the line numbers and column offset will be set in
      the annotation, with the rest of the information being None.
  """
  if not isinstance(nodes, (list, tuple)):
    nodes = (nodes,)

  if function:
    _, function_lineno = tf_inspect.getsourcelines(function)
    function_filepath = tf_inspect.getsourcefile(function)
  else:
    function_lineno = None
    function_filepath = None

  # TODO(mdan): Pull this to a separate utility.
  # Map each physical line number to the text of the comment it carries.
  code_reader = six.StringIO(source)
  comment_map = {}
  for token in tokenize.generate_tokens(code_reader.readline):
    tok_type, tok_string, loc, _, _ = token
    srow, _ = loc
    if tok_type == tokenize.COMMENT:
      # Strip the leading '#' and surrounding whitespace.
      comment_map[srow] = tok_string.strip()[1:].strip()

  source_lines = source.split('\n')
  for node in nodes:
    for n in gast.walk(node):
      if not hasattr(n, 'lineno'):
        continue

      lineno_in_body = n.lineno

      source_code_line = source_lines[lineno_in_body - 1]
      if function:
        # NOTE(review): lineno_in_body is 1-based and function_lineno is the
        # function's own first line, so this sum looks like it may overshoot
        # by one line — confirm against callers before changing.
        source_lineno = function_lineno + lineno_in_body
        function_name = function.__name__
      else:
        source_lineno = lineno_in_body
        function_name = None

      location = Location(function_filepath, source_lineno, n.col_offset)
      origin = OriginInfo(location, function_name,
                          source_code_line, comment_map.get(source_lineno))
      anno.setanno(n, anno.Basic.ORIGIN, origin)
diff --git a/tensorflow/python/autograph/pyct/origin_info_test.py b/tensorflow/python/autograph/pyct/origin_info_test.py
new file mode 100644
index 0000000000..6b9c30dbd0
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/origin_info_test.py
@@ -0,0 +1,104 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for origin_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import origin_info
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
class OriginInfoTest(test.TestCase):
  """Tests for origin_info.resolve and origin_info.source_map."""

  def test_source_map(self):

    def test_fn(x):
      if x > 0:
        x += 1
      return x

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    # Insert a traced line.
    new_node = parser.parse_str('x = abs(x)').body[0]
    anno.copyanno(fn_node.body[0], new_node, anno.Basic.ORIGIN)
    fn_node.body.insert(0, new_node)

    # Insert an untraced line.
    fn_node.body.insert(0, parser.parse_str('x = 0').body[0])

    modified_source = compiler.ast_to_source(fn_node)

    source_map = origin_info.source_map(fn_node, modified_source,
                                        'test_filename', [0])

    # The function signature maps to itself.
    loc = origin_info.LineLocation('test_filename', 1)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
    self.assertEqual(origin.loc.lineno, 1)

    # The untraced line, inserted second.
    loc = origin_info.LineLocation('test_filename', 2)
    self.assertFalse(loc in source_map)

    # The traced line, inserted first.
    loc = origin_info.LineLocation('test_filename', 3)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, '  if x > 0:')
    self.assertEqual(origin.loc.lineno, 2)

    loc = origin_info.LineLocation('test_filename', 4)
    origin = source_map[loc]
    self.assertEqual(origin.source_code_line, '  if x > 0:')
    self.assertEqual(origin.loc.lineno, 2)

  def test_resolve(self):

    def test_fn(x):
      """Docstring."""
      return x  # comment

    node, source = parser.parse_entity(test_fn)
    fn_node = node.body[0]
    origin_info.resolve(fn_node, source)

    origin = anno.getanno(fn_node, anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, 1)
    self.assertEqual(origin.loc.col_offset, 0)
    self.assertEqual(origin.source_code_line, 'def test_fn(x):')
    self.assertIsNone(origin.comment)

    origin = anno.getanno(fn_node.body[0], anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, 2)
    self.assertEqual(origin.loc.col_offset, 2)
    self.assertEqual(origin.source_code_line, '  """Docstring."""')
    self.assertIsNone(origin.comment)

    # Trailing comments are captured, with the '#' and whitespace stripped.
    origin = anno.getanno(fn_node.body[1], anno.Basic.ORIGIN)
    self.assertEqual(origin.loc.lineno, 3)
    self.assertEqual(origin.loc.col_offset, 2)
    self.assertEqual(origin.source_code_line, '  return x  # comment')
    self.assertEqual(origin.comment, 'comment')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
new file mode 100644
index 0000000000..112ed46a1e
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/parser.py
@@ -0,0 +1,59 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converting code to AST.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+import gast
+
+from tensorflow.python.util import tf_inspect
+
+
def parse_entity(entity):
  """Returns the AST of given entity.

  Args:
    entity: A Python entity (e.g. a function) whose source can be retrieved
      by tf_inspect.getsource.
  Returns:
    A tuple (node, source) of the parsed gast module and the dedented source
    text it was parsed from.
  """
  source = tf_inspect.getsource(entity)
  # Dedent so that methods and nested functions parse as top-level code.
  source = textwrap.dedent(source)
  return parse_str(source), source
+
+
def parse_str(src):
  """Returns the AST of given piece of code, as a gast.Module."""
  # TODO(mdan): This should exclude the module things are autowrapped in.
  return gast.parse(src)
+
+
def parse_expression(src):
  """Returns the AST of given single expression.

  Args:
    src: A piece of code that represents a single Python expression
  Returns:
    A gast.AST object.
  Raises:
    ValueError: if src does not consist of a single Expression.
  """
  node = parse_str(src)
  assert isinstance(node, gast.Module)
  # Bug fix: this condition previously used `and`, which let a single
  # non-expression statement (e.g. 'x = 1') through and crashed below with
  # AttributeError instead of raising the documented ValueError. Either a
  # wrong statement count or a non-Expr node must reject the input.
  if len(node.body) != 1 or not isinstance(node.body[0], gast.Expr):
    raise ValueError(
        'Expected a single expression, found instead %s' % node.body)
  return node.body[0].value
diff --git a/tensorflow/python/autograph/pyct/parser_test.py b/tensorflow/python/autograph/pyct/parser_test.py
new file mode 100644
index 0000000000..d0b465eb73
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/parser_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for parser module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
class ParserTest(test.TestCase):
  """Tests for the parse_entity/parse_str/parse_expression helpers."""

  def test_parse_entity(self):

    def f(x):
      return x + 1

    mod, _ = parser.parse_entity(f)
    self.assertEqual('f', mod.body[0].name)

  def test_parse_str(self):
    mod = parser.parse_str(
        textwrap.dedent("""
            def f(x):
              return x + 1
        """))
    self.assertEqual('f', mod.body[0].name)

  def test_parse_expression(self):
    # 'a.b' parses to an Attribute node: value is the Name, attr the string.
    node = parser.parse_expression('a.b')
    self.assertEqual('a', node.value.id)
    self.assertEqual('b', node.attr)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py
new file mode 100644
index 0000000000..bacc1e4a77
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/pretty_printer.py
@@ -0,0 +1,113 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Print an AST tree in a form more readable than ast.dump."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+import termcolor
+
+
class PrettyPrinter(gast.NodeVisitor):
  """Print AST nodes."""

  def __init__(self, color):
    # color: bool, whether to wrap output in ANSI codes via termcolor.
    self.indent_lvl = 0
    self.result = ''
    self.color = color

  def _color(self, string, color, attrs=None):
    # Applies termcolor formatting only when color output is enabled.
    if self.color:
      return termcolor.colored(string, color, attrs=attrs)
    return string

  def _type(self, node):
    # Node type names render in bold.
    return self._color(node.__class__.__name__, None, ['bold'])

  def _field(self, name):
    return self._color(name, 'blue')

  def _value(self, name):
    return self._color(name, 'magenta')

  def _warning(self, name):
    return self._color(name, 'red')

  def _indent(self):
    # One '| ' guide per nesting level.
    return self._color('| ' * self.indent_lvl, None, ['dark'])

  def _print(self, s):
    # Accumulates output; callers read the final text from self.result.
    self.result += s
    self.result += '\n'

  def generic_visit(self, node, name=None):
    """Renders node and, recursively, all of its fields, one line each."""
    # Fieldless nodes render as 'Type()'; others as 'Type:' followed by an
    # indented field listing.
    if node._fields:
      cont = ':'
    else:
      cont = '()'

    if name:
      self._print('%s%s=%s%s' % (self._indent(), self._field(name),
                                 self._type(node), cont))
    else:
      self._print('%s%s%s' % (self._indent(), self._type(node), cont))

    self.indent_lvl += 1
    for f in node._fields:
      if not hasattr(node, f):
        # Declared but unset fields are flagged rather than skipped.
        self._print('%s%s' % (self._indent(), self._warning('%s=<unset>' % f)))
        continue
      v = getattr(node, f)
      if isinstance(v, list):
        if v:
          self._print('%s%s=[' % (self._indent(), self._field(f)))
          self.indent_lvl += 1
          for n in v:
            self.generic_visit(n)
          self.indent_lvl -= 1
          self._print('%s]' % (self._indent()))
        else:
          self._print('%s%s=[]' % (self._indent(), self._field(f)))
      elif isinstance(v, tuple):
        if v:
          self._print('%s%s=(' % (self._indent(), self._field(f)))
          self.indent_lvl += 1
          for n in v:
            self.generic_visit(n)
          self.indent_lvl -= 1
          self._print('%s)' % (self._indent()))
        else:
          self._print('%s%s=()' % (self._indent(), self._field(f)))
      elif isinstance(v, gast.AST):
        self.generic_visit(v, f)
      elif isinstance(v, str):
        self._print('%s%s=%s' % (self._indent(), self._field(f),
                                 self._value('"%s"' % v)))
      else:
        self._print('%s%s=%s' % (self._indent(), self._field(f),
                                 self._value(v)))
    self.indent_lvl -= 1
+
+
def fmt(node, color=True):
  """Renders node (a single AST or a list/tuple of ASTs) as readable text.

  Args:
    node: the AST, or iterable of ASTs, to print.
    color: bool, whether to include ANSI color codes in the output.

  Returns:
    Text, the pretty-printed representation.
  """
  printer = PrettyPrinter(color)
  roots = node if isinstance(node, (list, tuple)) else (node,)
  for root in roots:
    printer.visit(root)
  return printer.result
diff --git a/tensorflow/python/autograph/pyct/pretty_printer_test.py b/tensorflow/python/autograph/pyct/pretty_printer_test.py
new file mode 100644
index 0000000000..1c76744547
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/pretty_printer_test.py
@@ -0,0 +1,52 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for pretty_printer module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+
+from tensorflow.python.autograph.pyct import pretty_printer
+from tensorflow.python.platform import test
+
+
+class PrettyPrinterTest(test.TestCase):
+
+ def test_format(self):
+ node = ast.FunctionDef(
+ name='f',
+ args=ast.arguments(
+ args=[ast.Name(id='a', ctx=ast.Param())],
+ vararg=None,
+ kwarg=None,
+ defaults=[]),
+ body=[
+ ast.Return(
+ ast.BinOp(
+ op=ast.Add(),
+ left=ast.Name(id='a', ctx=ast.Load()),
+ right=ast.Num(1)))
+ ],
+ decorator_list=[],
+ returns=None)
+ # Just checking for functionality, the color control characters make it
+ # difficult to inspect the result.
+ self.assertIsNotNone(pretty_printer.fmt(node))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/qual_names.py b/tensorflow/python/autograph/pyct/qual_names.py
new file mode 100644
index 0000000000..334cbd7d38
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/qual_names.py
@@ -0,0 +1,257 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for manipulating qualified names.
+
+A qualified name is a uniform way to refer to simple (e.g. 'foo') and composite
+(e.g. 'foo.bar') syntactic symbols.
+
+This is *not* related to the __qualname__ attribute used by inspect, which
+refers to scopes.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+
+
+class Symbol(collections.namedtuple('Symbol', ['name'])):
+  """Represents a Python symbol (a simple identifier, e.g. 'foo')."""
+
+
+class StringLiteral(collections.namedtuple('StringLiteral', ['value'])):
+ """Represents a Python string literal."""
+
+ def __str__(self):
+ return '\'%s\'' % self.value
+
+ def __repr__(self):
+ return str(self)
+
+
+class NumberLiteral(collections.namedtuple('NumberLiteral', ['value'])):
+ """Represents a Python numeric literal."""
+
+ def __str__(self):
+ return '%s' % self.value
+
+ def __repr__(self):
+ return str(self)
+
+
+# TODO(mdan): Use subclasses to remove the has_attr has_subscript booleans.
+class QN(object):
+ """Represents a qualified name."""
+
+ def __init__(self, base, attr=None, subscript=None):
+ if attr is not None and subscript is not None:
+ raise ValueError('A QN can only be either an attr or a subscript, not '
+ 'both: attr={}, subscript={}.'.format(attr, subscript))
+ self._has_attr = False
+ self._has_subscript = False
+
+ if attr is not None:
+ if not isinstance(base, QN):
+ raise ValueError(
+ 'for attribute QNs, base must be a QN; got instead "%s"' % base)
+ if not isinstance(attr, str):
+ raise ValueError('attr may only be a string; got instead "%s"' % attr)
+ self._parent = base
+ # TODO(mdan): Get rid of the tuple - it can only have 1 or 2 elements now.
+ self.qn = (base, attr)
+ self._has_attr = True
+
+ elif subscript is not None:
+ if not isinstance(base, QN):
+ raise ValueError('For subscript QNs, base must be a QN.')
+ self._parent = base
+ self.qn = (base, subscript)
+ self._has_subscript = True
+
+ else:
+ if not isinstance(base, (str, StringLiteral, NumberLiteral)):
+ # TODO(mdan): Require Symbol instead of string.
+ raise ValueError(
+ 'for simple QNs, base must be a string or a Literal object;'
+ ' got instead "%s"' % type(base))
+ assert '.' not in base and '[' not in base and ']' not in base
+ self._parent = None
+ self.qn = (base,)
+
+ def is_symbol(self):
+ return isinstance(self.qn[0], str)
+
+ def is_composite(self):
+ return len(self.qn) > 1
+
+ def has_subscript(self):
+ return self._has_subscript
+
+ def has_attr(self):
+ return self._has_attr
+
+ @property
+ def parent(self):
+ if self._parent is None:
+ raise ValueError('Cannot get parent of simple name "%s".' % self.qn[0])
+ return self._parent
+
+ @property
+ def owner_set(self):
+ """Returns all the symbols (simple or composite) that own this QN.
+
+ In other words, if this symbol was modified, the symbols in the owner set
+ may also be affected.
+
+ Examples:
+ 'a.b[c.d]' has two owners, 'a' and 'a.b'
+ """
+ owners = set()
+ if self.has_attr() or self.has_subscript():
+ owners.add(self.parent)
+ owners.update(self.parent.owner_set)
+ return owners
+
+ @property
+ def support_set(self):
+ """Returns the set of simple symbols that this QN relies on.
+
+ This would be the smallest set of symbols necessary for the QN to
+ statically resolve (assuming properties and index ranges are verified
+ at runtime).
+
+ Examples:
+ 'a.b' has only one support symbol, 'a'
+ 'a[i]' has two support symbols, 'a' and 'i'
+ """
+ # TODO(mdan): This might be the set of Name nodes in the AST. Track those?
+ roots = set()
+ if self.has_attr():
+ roots.update(self.parent.support_set)
+ elif self.has_subscript():
+ roots.update(self.parent.support_set)
+ roots.update(self.qn[1].support_set)
+ else:
+ roots.add(self)
+ return roots
+
+ def __hash__(self):
+ return hash(self.qn + (self._has_attr, self._has_subscript))
+
+ def __eq__(self, other):
+ return (isinstance(other, QN) and self.qn == other.qn and
+ self.has_subscript() == other.has_subscript() and
+ self.has_attr() == other.has_attr())
+
+ def __str__(self):
+ if self.has_subscript():
+ return str(self.qn[0]) + '[' + str(self.qn[1]) + ']'
+ if self.has_attr():
+ return '.'.join(map(str, self.qn))
+ else:
+ return str(self.qn[0])
+
+ def __repr__(self):
+ return str(self)
+
+ def ssf(self):
+ """Simple symbol form."""
+ ssfs = [n.ssf() if isinstance(n, QN) else n for n in self.qn]
+ ssf_string = ''
+ for i in range(0, len(self.qn) - 1):
+ if self.has_subscript():
+ delimiter = '_sub_'
+ else:
+ delimiter = '_'
+ ssf_string += ssfs[i] + delimiter
+ return ssf_string + ssfs[-1]
+
+ def ast(self):
+ # The caller must adjust the context appropriately.
+ if self.has_subscript():
+ return gast.Subscript(self.parent.ast(), gast.Index(self.qn[-1].ast()),
+ None)
+ if self.has_attr():
+ return gast.Attribute(self.parent.ast(), self.qn[-1], None)
+
+ base = self.qn[0]
+ if isinstance(base, str):
+ return gast.Name(base, None, None)
+ elif isinstance(base, StringLiteral):
+ return gast.Str(base.value)
+ elif isinstance(base, NumberLiteral):
+ return gast.Num(base.value)
+ else:
+ assert False, ('the constructor should prevent types other than '
+ 'str, StringLiteral and NumberLiteral')
+
+
+class QnResolver(gast.NodeTransformer):
+  """Annotates nodes with QN information.
+
+  Note: Not using NodeAnnos to avoid circular dependencies.
+  """
+
+  def visit_Name(self, node):
+    # A bare name always resolves to a simple QN, e.g. 'a'.
+    node = self.generic_visit(node)
+    anno.setanno(node, anno.Basic.QN, QN(node.id))
+    return node
+
+  def visit_Attribute(self, node):
+    node = self.generic_visit(node)
+    # Only annotate when the owner itself resolved to a QN; e.g. 'a().b'
+    # has no QN because its value is a function call.
+    if anno.hasanno(node.value, anno.Basic.QN):
+      anno.setanno(node, anno.Basic.QN,
+                   QN(anno.getanno(node.value, anno.Basic.QN), attr=node.attr))
+    return node
+
+  def visit_Subscript(self, node):
+    # TODO(mdan): This may no longer apply if we overload getitem.
+    node = self.generic_visit(node)
+    s = node.slice
+    if not isinstance(s, gast.Index):
+      # TODO(mdan): Support range and multi-dimensional indices.
+      # Continuing silently because some demos use these.
+      return node
+    # Constant indices become literal QNs; e.g. a[3] and a['b'].
+    if isinstance(s.value, gast.Num):
+      subscript = QN(NumberLiteral(s.value.n))
+    elif isinstance(s.value, gast.Str):
+      subscript = QN(StringLiteral(s.value.s))
+    else:
+      # The index may be an expression, case in which a name doesn't make sense.
+      if anno.hasanno(node.slice.value, anno.Basic.QN):
+        subscript = anno.getanno(node.slice.value, anno.Basic.QN)
+      else:
+        return node
+    if anno.hasanno(node.value, anno.Basic.QN):
+      anno.setanno(node, anno.Basic.QN,
+                   QN(anno.getanno(node.value, anno.Basic.QN),
+                      subscript=subscript))
+    return node
+
+
+def resolve(node):
+ return QnResolver().visit(node)
+
+
+def from_str(qn_str):
+ node = parser.parse_expression(qn_str)
+ node = resolve(node)
+ return anno.getanno(node, anno.Basic.QN)
diff --git a/tensorflow/python/autograph/pyct/qual_names_test.py b/tensorflow/python/autograph/pyct/qual_names_test.py
new file mode 100644
index 0000000000..2da4dfd787
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/qual_names_test.py
@@ -0,0 +1,255 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for qual_names module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import textwrap
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.qual_names import resolve
+from tensorflow.python.platform import test
+
+
+class QNTest(test.TestCase):
+  """Unit tests for the QN (qualified name) class."""
+
+  def test_from_str(self):
+    a = QN('a')
+    b = QN('b')
+    a_dot_b = QN(a, attr='b')
+    a_sub_b = QN(a, subscript=b)
+    self.assertEqual(qual_names.from_str('a.b'), a_dot_b)
+    self.assertEqual(qual_names.from_str('a'), a)
+    self.assertEqual(qual_names.from_str('a[b]'), a_sub_b)
+
+  def test_basic(self):
+    # Simple names have no parent and are not composite.
+    a = QN('a')
+    self.assertEqual(a.qn, ('a',))
+    self.assertEqual(str(a), 'a')
+    self.assertEqual(a.ssf(), 'a')
+    self.assertEqual(a.ast().id, 'a')
+    self.assertFalse(a.is_composite())
+    with self.assertRaises(ValueError):
+      _ = a.parent
+
+    a_b = QN(a, attr='b')
+    self.assertEqual(a_b.qn, (a, 'b'))
+    self.assertEqual(str(a_b), 'a.b')
+    self.assertEqual(a_b.ssf(), 'a_b')
+    self.assertEqual(a_b.ast().value.id, 'a')
+    self.assertEqual(a_b.ast().attr, 'b')
+    self.assertTrue(a_b.is_composite())
+    self.assertEqual(a_b.parent.qn, ('a',))
+
+  def test_subscripts(self):
+    a = QN('a')
+    b = QN('b')
+    a_sub_b = QN(a, subscript=b)
+    self.assertEqual(a_sub_b.qn, (a, b))
+    self.assertEqual(str(a_sub_b), 'a[b]')
+    self.assertEqual(a_sub_b.ssf(), 'a_sub_b')
+    self.assertEqual(a_sub_b.ast().value.id, 'a')
+    self.assertEqual(a_sub_b.ast().slice.value.id, 'b')
+    self.assertTrue(a_sub_b.is_composite())
+    self.assertTrue(a_sub_b.has_subscript())
+    self.assertEqual(a_sub_b.parent.qn, ('a',))
+
+    # Subscripts may nest, e.g. a[b[c]].
+    c = QN('c')
+    b_sub_c = QN(b, subscript=c)
+    a_sub_b_sub_c = QN(a, subscript=b_sub_c)
+    self.assertEqual(a_sub_b_sub_c.qn, (a, b_sub_c))
+    self.assertTrue(a_sub_b.is_composite())
+    self.assertTrue(a_sub_b_sub_c.is_composite())
+    self.assertTrue(a_sub_b.has_subscript())
+    self.assertTrue(a_sub_b_sub_c.has_subscript())
+    self.assertEqual(b_sub_c.qn, (b, c))
+    self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
+    self.assertEqual(a_sub_b_sub_c.ssf(), 'a_sub_b_sub_c')
+    self.assertEqual(a_sub_b_sub_c.ast().value.id, 'a')
+    self.assertEqual(a_sub_b_sub_c.ast().slice.value.value.id, 'b')
+    self.assertEqual(a_sub_b_sub_c.ast().slice.value.slice.value.id, 'c')
+    self.assertEqual(b_sub_c.ast().slice.value.id, 'c')
+    self.assertEqual(a_sub_b_sub_c.parent.qn, ('a',))
+    with self.assertRaises(ValueError):
+      QN('a', 'b')
+
+  def test_equality(self):
+    a = QN('a')
+    a2 = QN('a')
+    a_b = QN(a, attr='b')
+    self.assertEqual(a2.qn, ('a',))
+    with self.assertRaises(ValueError):
+      _ = a.parent
+
+    a_b2 = QN(a, attr='b')
+    self.assertEqual(a_b2.qn, (a, 'b'))
+    self.assertEqual(a_b2.parent.qn, ('a',))
+
+    # Equality is structural, not identity-based.
+    self.assertTrue(a2 == a)
+    self.assertFalse(a2 is a)
+
+    self.assertTrue(a_b.parent == a)
+    self.assertTrue(a_b2.parent == a)
+
+    self.assertTrue(a_b2 == a_b)
+    self.assertFalse(a_b2 is a_b)
+    self.assertFalse(a_b2 == a)
+    a_sub_b = QN(a, subscript='b')
+    a_sub_b2 = QN(a, subscript='b')
+    self.assertTrue(a_sub_b == a_sub_b2)
+    self.assertFalse(a_sub_b == a_b)
+
+  def test_nested_attrs_subscripts(self):
+    a = QN('a')
+    b = QN('b')
+    c = QN('c')
+    b_sub_c = QN(b, subscript=c)
+    a_sub_b_sub_c = QN(a, subscript=b_sub_c)
+
+    b_dot_c = QN(b, attr='c')
+    a_sub__b_dot_c = QN(a, subscript=b_dot_c)
+
+    a_sub_b = QN(a, subscript=b)
+    a_sub_b__dot_c = QN(a_sub_b, attr='c')
+
+    a_dot_b = QN(a, attr='b')
+    a_dot_b_sub_c = QN(a_dot_b, subscript=c)
+
+    self.assertEqual(str(a_sub_b_sub_c), 'a[b[c]]')
+    self.assertEqual(str(a_sub__b_dot_c), 'a[b.c]')
+    self.assertEqual(str(a_sub_b__dot_c), 'a[b].c')
+    self.assertEqual(str(a_dot_b_sub_c), 'a.b[c]')
+
+    # All four combinations are pairwise distinct.
+    self.assertNotEqual(a_sub_b_sub_c, a_sub__b_dot_c)
+    self.assertNotEqual(a_sub_b_sub_c, a_sub_b__dot_c)
+    self.assertNotEqual(a_sub_b_sub_c, a_dot_b_sub_c)
+
+    self.assertNotEqual(a_sub__b_dot_c, a_sub_b__dot_c)
+    self.assertNotEqual(a_sub__b_dot_c, a_dot_b_sub_c)
+
+    self.assertNotEqual(a_sub_b__dot_c, a_dot_b_sub_c)
+
+  def test_hashable(self):
+    d = {QN('a'): 'a', QN('b'): 'b'}
+    self.assertEqual(d[QN('a')], 'a')
+    self.assertEqual(d[QN('b')], 'b')
+    self.assertTrue(QN('c') not in d)
+
+  def test_literals(self):
+    a = QN('a')
+    # A string-literal subscript is distinct from a symbol of the same name.
+    a_sub_str_b = QN(a, subscript=QN(qual_names.StringLiteral('b')))
+    a_sub_b = QN(a, subscript=QN('b'))
+
+    self.assertNotEqual(a_sub_str_b, a_sub_b)
+    self.assertNotEqual(hash(a_sub_str_b), hash(a_sub_b))
+
+    a_sub_three = QN(a, subscript=QN(qual_names.NumberLiteral(3)))
+    self.assertEqual(a_sub_three.ast().slice.value.n, 3)
+
+  def test_support_set(self):
+    a = QN('a')
+    b = QN('b')
+    c = QN('c')
+    a_sub_b = QN(a, subscript=b)
+    a_dot_b = QN(a, attr='b')
+    a_dot_b_dot_c = QN(a_dot_b, attr='c')
+    a_dot_b_sub_c = QN(a_dot_b, subscript=c)
+
+    self.assertSetEqual(a.support_set, set((a,)))
+    self.assertSetEqual(a_sub_b.support_set, set((a, b)))
+    self.assertSetEqual(a_dot_b.support_set, set((a,)))
+    self.assertSetEqual(a_dot_b_dot_c.support_set, set((a,)))
+    self.assertSetEqual(a_dot_b_sub_c.support_set, set((a, c)))
+
+
+class QNResolverTest(test.TestCase):
+  """Tests for QnResolver, which annotates AST nodes with their QNs."""
+
+  def assertQNStringIs(self, node, qn_str):
+    """Asserts that `node` carries a QN annotation whose str() is `qn_str`."""
+    self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str)
+
+  def test_resolve(self):
+    samples = """
+      a
+      a.b
+      (c, d.e)
+      [f, (g.h.i)]
+      j(k, l)
+    """
+    nodes = resolve(parser.parse_str(textwrap.dedent(samples)))
+    # Each sample line parses to an Expr statement; inspect its value.
+    nodes = tuple(n.value for n in nodes.body)
+
+    self.assertQNStringIs(nodes[0], 'a')
+    self.assertQNStringIs(nodes[1], 'a.b')
+    self.assertQNStringIs(nodes[2].elts[0], 'c')
+    self.assertQNStringIs(nodes[2].elts[1], 'd.e')
+    self.assertQNStringIs(nodes[3].elts[0], 'f')
+    self.assertQNStringIs(nodes[3].elts[1], 'g.h.i')
+    self.assertQNStringIs(nodes[4].func, 'j')
+    self.assertQNStringIs(nodes[4].args[0], 'k')
+    self.assertQNStringIs(nodes[4].args[1], 'l')
+
+  def test_subscript_resolve(self):
+    samples = """
+      x[i]
+      x[i.b]
+      a.b[c]
+      a.b[x.y]
+      a[z[c]]
+      a[b[c[d]]]
+      a[b].c
+      a.b.c[d].e.f
+      a.b[c[d]].e.f
+      a.b[c[d.e.f].g].h
+    """
+    nodes = resolve(parser.parse_str(textwrap.dedent(samples)))
+    nodes = tuple(n.value for n in nodes.body)
+
+    self.assertQNStringIs(nodes[0], 'x[i]')
+    self.assertQNStringIs(nodes[1], 'x[i.b]')
+    self.assertQNStringIs(nodes[2], 'a.b[c]')
+    self.assertQNStringIs(nodes[3], 'a.b[x.y]')
+    self.assertQNStringIs(nodes[4], 'a[z[c]]')
+    self.assertQNStringIs(nodes[5], 'a[b[c[d]]]')
+    self.assertQNStringIs(nodes[6], 'a[b].c')
+    self.assertQNStringIs(nodes[7], 'a.b.c[d].e.f')
+    self.assertQNStringIs(nodes[8], 'a.b[c[d]].e.f')
+    self.assertQNStringIs(nodes[9], 'a.b[c[d.e.f].g].h')
+
+  def test_function_calls(self):
+    # Function call results have no QN; only their func/subexpressions do.
+    samples = """
+      a.b
+      a.b()
+      a().b
+      z[i]
+      z[i]()
+      z()[i]
+    """
+    nodes = resolve(parser.parse_str(textwrap.dedent(samples)))
+    nodes = tuple(n.value for n in nodes.body)
+    self.assertQNStringIs(nodes[0], 'a.b')
+    self.assertQNStringIs(nodes[1].func, 'a.b')
+    self.assertQNStringIs(nodes[2].value.func, 'a')
+    self.assertQNStringIs(nodes[3], 'z[i]')
+    self.assertQNStringIs(nodes[4].func, 'z[i]')
+    self.assertQNStringIs(nodes[5].value.func, 'z')
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/BUILD b/tensorflow/python/autograph/pyct/static_analysis/BUILD
new file mode 100644
index 0000000000..4a4ccdcbd1
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/BUILD
@@ -0,0 +1,94 @@
+# Build rules for AutoGraph's static analysis passes (pyct/static_analysis).
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+# Library bundling all static analyses (activity, liveness, reaching
+# definitions, live values and type info).
+py_library(
+    name = "static_analysis",
+    srcs = [
+        "activity.py",
+        "annos.py",
+        "live_values.py",
+        "liveness.py",
+        "reaching_definitions.py",
+        "type_info.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/utils",
+        "@gast_archive//:gast",
+    ],
+)
+
+py_test(
+    name = "activity_test",
+    srcs = ["activity_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_windows"],
+    deps = [
+        ":static_analysis",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
+        "@gast_archive//:gast",
+    ],
+)
+
+py_test(
+    name = "live_values_test",
+    srcs = ["live_values_test.py"],
+    srcs_version = "PY2AND3",
+    tags = ["no_windows"],
+    deps = [
+        ":static_analysis",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
+    ],
+)
+
+py_test(
+    name = "liveness_test",
+    srcs = ["liveness_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":static_analysis",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
+    ],
+)
+
+py_test(
+    name = "reaching_definitions_test",
+    srcs = ["reaching_definitions_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":static_analysis",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
+    ],
+)
+
+py_test(
+    name = "type_info_test",
+    srcs = ["type_info_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":static_analysis",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python/autograph/pyct",
+        "//tensorflow/python/autograph/utils",
+    ],
+)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/__init__.py b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
new file mode 100644
index 0000000000..9a82de735d
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/__init__.py
@@ -0,0 +1,33 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Static information resolution.
+
+This module contains utilities to help annotate AST nodes with as much runtime
+information as can possibly be extracted without actually executing the code,
+under the assumption that the context in which the code will run is known.
+
+Overall, the different analyses have the functions listed below:
+
+ * activity: inventories symbols read, written to, params, etc. at different
+ levels
+ * liveness, reaching_definitions: dataflow analyses based on the program's CFG
+ and using the symbol information gathered by activity analysis
+ * live_values, type_info: type and value inference based on dataflow
+ analysis and context information
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
new file mode 100644
index 0000000000..9cb5991322
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -0,0 +1,398 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Activity analysis.
+
+Requires qualified name annotations (see qual_names.py).
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
+# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
+# TODO(alexbw): Ignore named literals (e.g. None)
+
+
+class Scope(object):
+  """Encloses local symbol definition and usage information.
+
+  This can track for instance whether a symbol is modified in the current scope.
+  Note that scopes do not necessarily align with Python's scopes. For example,
+  the body of an if statement may be considered a separate scope.
+
+  Attributes:
+    modified: identifiers modified in this scope
+    created: identifiers created in this scope
+    used: identifiers referenced in this scope
+    params: parameter identifiers, mapped to the entity (e.g. the function
+      node) that owns them
+    returned: identifiers that appear in return statements in this scope
+  """
+
+  def __init__(self, parent, isolated=True, add_unknown_symbols=False):
+    """Create a new scope.
+
+    Args:
+      parent: A Scope or None.
+      isolated: Whether the scope is isolated, that is, whether variables
+          created in this scope should be visible to the parent scope.
+      add_unknown_symbols: Whether to handle attributes and subscripts
+          without having first seen the base name.
+          E.g., analyzing the statement 'x.y = z' without first having seen 'x'.
+    """
+    self.isolated = isolated
+    self.parent = parent
+    self.add_unknown_symbols = add_unknown_symbols
+    self.modified = set()
+    # TODO(mdan): Completely remove this.
+    self.created = set()
+    self.used = set()
+    self.params = {}
+    self.returned = set()
+
+  # TODO(mdan): Rename to `locals`
+  @property
+  def referenced(self):
+    # Non-isolated scopes also expose symbols referenced by their ancestors.
+    if not self.isolated and self.parent is not None:
+      return self.used | self.parent.referenced
+    return self.used
+
+  def __repr__(self):
+    return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created),
+                                        tuple(self.modified))
+
+  def copy_from(self, other):
+    """Recursively copies the contents of this scope from another scope."""
+    # Both scope chains must have the same depth/shape for a faithful copy.
+    if (self.parent is None) != (other.parent is None):
+      raise ValueError('cannot copy scopes of different structures')
+    if other.parent is not None:
+      self.parent.copy_from(other.parent)
+    self.isolated = other.isolated
+    self.modified = copy.copy(other.modified)
+    self.created = copy.copy(other.created)
+    self.used = copy.copy(other.used)
+    self.params = copy.copy(other.params)
+    self.returned = copy.copy(other.returned)
+
+  @classmethod
+  def copy_of(cls, other):
+    """Returns a deep copy of `other`, including its parent chain."""
+    if other.parent is not None:
+      parent = cls.copy_of(other.parent)
+    else:
+      parent = None
+    new_copy = cls(parent)
+    new_copy.copy_from(other)
+    return new_copy
+
+  def merge_from(self, other):
+    """Adds (unions) the contents of `other` into this scope, recursively."""
+    if (self.parent is None) != (other.parent is None):
+      raise ValueError('cannot merge scopes of different structures')
+    if other.parent is not None:
+      self.parent.merge_from(other.parent)
+    self.modified |= other.modified
+    self.created |= other.created
+    self.used |= other.used
+    self.params.update(other.params)
+    self.returned |= other.returned
+
+  def has(self, name):
+    # A symbol is "had" if it was modified here or in any ancestor scope.
+    if name in self.modified:
+      return True
+    elif self.parent is not None:
+      return self.parent.has(name)
+    return False
+
+  def mark_read(self, name):
+    self.used.add(name)
+    # Reads of symbols not created locally propagate to the parent scope.
+    if self.parent is not None and name not in self.created:
+      self.parent.mark_read(name)
+
+  def mark_param(self, name, owner):
+    self.params[name] = owner
+
+  def mark_creation(self, name, writes_create_symbol=False):
+    """Mark a qualified name as created."""
+    if name.is_composite():
+      parent = name.parent
+      # Writing to a composite (e.g. 'a.b') creates a symbol only when
+      # explicitly requested; otherwise only the base symbol matters.
+      if not writes_create_symbol:
+        return
+      else:
+        if not self.has(parent):
+          if self.add_unknown_symbols:
+            self.mark_read(parent)
+          else:
+            raise ValueError('Unknown symbol "%s".' % parent)
+    self.created.add(name)
+
+  def mark_write(self, name):
+    """Marks the given symbol as modified in the current scope."""
+    self.modified.add(name)
+    if self.isolated:
+      self.mark_creation(name)
+    else:
+      if self.parent is None:
+        self.mark_creation(name)
+      else:
+        # In non-isolated scopes, writes to symbols unknown to the parent
+        # create them locally; known symbols are treated as parent writes.
+        if not self.parent.has(name):
+          self.mark_creation(name)
+        self.parent.mark_write(name)
+
+  def mark_returned(self, name):
+    self.returned.add(name)
+    if not self.isolated and self.parent is not None:
+      self.parent.mark_returned(name)
+
+
+class ActivityAnalyzer(transformer.Base):
+ """Annotates nodes with local scope information.
+
+ See Scope.
+
+ The use of this class requires that qual_names.resolve() has been called on
+ the node. This class will ignore nodes have not been
+ annotated with their qualified names.
+ """
+
+ def __init__(self, context, parent_scope=None, add_unknown_symbols=False):
+ super(ActivityAnalyzer, self).__init__(context)
+ self.scope = Scope(parent_scope, None, add_unknown_symbols)
+ self._in_return_statement = False
+ self._in_aug_assign = False
+
+ @property
+ def _in_constructor(self):
+ if len(self.enclosing_entities) > 1:
+ innermost = self.enclosing_entities[-1]
+ parent = self.enclosing_entities[-2]
+ return isinstance(parent, gast.ClassDef) and innermost.name == '__init__'
+ return False
+
+ def _node_sets_self_attribute(self, node):
+ if anno.hasanno(node, anno.Basic.QN):
+ qn = anno.getanno(node, anno.Basic.QN)
+ # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
+ if qn.has_attr and qn.parent.qn == ('self',):
+ return True
+ return False
+
+ def _track_symbol(self,
+ node,
+ composite_writes_alter_parent=False,
+ writes_create_symbol=False):
+ # A QN may be missing when we have an attribute (or subscript) on a function
+ # call. Example: a().b
+ if not anno.hasanno(node, anno.Basic.QN):
+ return
+ qn = anno.getanno(node, anno.Basic.QN)
+
+ if isinstance(node.ctx, gast.Store):
+ self.scope.mark_write(qn)
+ if qn.is_composite and composite_writes_alter_parent:
+ self.scope.mark_write(qn.parent)
+ if writes_create_symbol:
+ self.scope.mark_creation(qn, writes_create_symbol=True)
+ if self._in_aug_assign:
+ self.scope.mark_read(qn)
+ elif isinstance(node.ctx, gast.Load):
+ self.scope.mark_read(qn)
+ elif isinstance(node.ctx, gast.Param):
+ # Param contexts appear in function defs, so they have the meaning of
+ # defining a variable.
+ self.scope.mark_write(qn)
+ self.scope.mark_param(qn, self.enclosing_entities[-1])
+ else:
+ raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
+
+ anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
+
+ if self._in_return_statement:
+ self.scope.mark_returned(qn)
+
+ def _enter_scope(self, isolated):
+ self.scope = Scope(self.scope, isolated=isolated)
+
+ def _exit_scope(self):
+ self.scope = self.scope.parent
+
+ def _process_statement(self, node):
+ self._enter_scope(False)
+ node = self.generic_visit(node)
+ anno.setanno(node, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+ return node
+
+ def visit_Expr(self, node):
+ return self._process_statement(node)
+
+ def visit_Return(self, node):
+ self._in_return_statement = True
+ node = self._process_statement(node)
+ self._in_return_statement = False
+ return node
+
+ def visit_Assign(self, node):
+ return self._process_statement(node)
+
+ def visit_AugAssign(self, node):
+ # Special rules for AugAssign. In Assign, the target is only written,
+ # but in AugAssig (e.g. a += b), the target is both read and written.
+ self._in_aug_assign = True
+ node = self._process_statement(node)
+ self._in_aug_assign = False
+ return node
+
+ def visit_Name(self, node):
+ node = self.generic_visit(node)
+ self._track_symbol(node)
+ return node
+
+ def visit_Attribute(self, node):
+ node = self.generic_visit(node)
+ if self._in_constructor and self._node_sets_self_attribute(node):
+ self._track_symbol(
+ node, composite_writes_alter_parent=True, writes_create_symbol=True)
+ else:
+ self._track_symbol(node)
+ return node
+
+ def visit_Subscript(self, node):
+ node = self.generic_visit(node)
+ # Subscript writes (e.g. a[b] = "value") are considered to modify
+ # both the element itself (a[b]) and its parent (a).
+ self._track_symbol(node)
+ return node
+
+ def visit_Print(self, node):
+ self._enter_scope(False)
+ node.values = self.visit_block(node.values)
+ anno.setanno(node, anno.Static.SCOPE, self.scope)
+ anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
+ self._exit_scope()
+ return node
+
+ def visit_Assert(self, node):
+ return self._process_statement(node)
+
+ def visit_Call(self, node):
+ self._enter_scope(False)
+ node.args = self.visit_block(node.args)
+ node.keywords = self.visit_block(node.keywords)
+ # TODO(mdan): Account starargs, kwargs
+ anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
+ self._exit_scope()
+ node.func = self.visit(node.func)
+ return node
+
+ def _process_block_node(self, node, block, scope_name):
+ self._enter_scope(False)
+ block = self.visit_block(block)
+ anno.setanno(node, scope_name, self.scope)
+ self._exit_scope()
+ return node
+
+ def _process_parallel_blocks(self, parent, children):
+ # Because the scopes are not isolated, processing any child block
+ # modifies the parent state causing the other child blocks to be
+ # processed incorrectly. So we need to checkpoint the parent scope so that
+ # each child sees the same context.
+ before_parent = Scope.copy_of(self.scope)
+ after_children = []
+ for child, scope_name in children:
+ self.scope.copy_from(before_parent)
+ parent = self._process_block_node(parent, child, scope_name)
+ after_child = Scope.copy_of(self.scope)
+ after_children.append(after_child)
+ for after_child in after_children:
+ self.scope.merge_from(after_child)
+ return parent
+
+ def visit_arguments(self, node):
+ return self._process_statement(node)
+
+ def visit_FunctionDef(self, node):
+ # The FunctionDef node itself has a Scope object that tracks the creation
+ # of its name, along with the usage of any decorator accompany it.
+ self._enter_scope(False)
+ node.decorator_list = self.visit_block(node.decorator_list)
+ self.scope.mark_write(qual_names.QN(node.name))
+ anno.setanno(node, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+
+ # A separate Scope tracks the actual function definition.
+ self._enter_scope(True)
+ node.args = self.visit(node.args)
+
+ # Track the body separately. This is for compatibility reasons, it may not
+ # be strictly needed.
+ self._enter_scope(False)
+ node.body = self.visit_block(node.body)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
+ self._exit_scope()
+
+ self._exit_scope()
+ return node
+
+ def visit_With(self, node):
+ self._enter_scope(False)
+ node = self.generic_visit(node)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
+ self._exit_scope()
+ return node
+
+ def visit_withitem(self, node):
+ return self._process_statement(node)
+
+ def visit_If(self, node):
+ self._enter_scope(False)
+ node.test = self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
+ anno.setanno(node.test, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+ node = self._process_parallel_blocks(node,
+ ((node.body, NodeAnno.BODY_SCOPE),
+ (node.orelse, NodeAnno.ORELSE_SCOPE)))
+ return node
+
+ def visit_For(self, node):
+ self._enter_scope(False)
+ node.target = self.visit(node.target)
+ node.iter = self.visit(node.iter)
+ anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+ node = self._process_parallel_blocks(node,
+ ((node.body, NodeAnno.BODY_SCOPE),
+ (node.orelse, NodeAnno.ORELSE_SCOPE)))
+ return node
+
+ def visit_While(self, node):
+ self._enter_scope(False)
+ node.test = self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
+ anno.setanno(node.test, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+ node = self._process_parallel_blocks(node,
+ ((node.body, NodeAnno.BODY_SCOPE),
+ (node.orelse, NodeAnno.ORELSE_SCOPE)))
+ return node
+
+
+def resolve(node, context, parent_scope=None):
+ return ActivityAnalyzer(context, parent_scope).visit(node)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
new file mode 100644
index 0000000000..d4a6ce8ac3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -0,0 +1,508 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for activity module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.qual_names import QN
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+from tensorflow.python.platform import test
+
+
class ScopeTest(test.TestCase):
  """Unit tests for the symbol tracking behavior of activity.Scope."""

  # Note: uses assertIn/assertNotIn rather than assertTrue(x in y), which
  # produces useful failure messages.

  def test_basic(self):
    scope = activity.Scope(None)
    self.assertFalse(scope.has(QN('foo')))

    # Reading a symbol alone does not register it with the scope.
    scope.mark_read(QN('foo'))
    self.assertFalse(scope.has(QN('foo')))

    # Writing does.
    scope.mark_write(QN('foo'))
    self.assertTrue(scope.has(QN('foo')))

    scope.mark_read(QN('bar'))
    self.assertFalse(scope.has(QN('bar')))

  def test_copy_from(self):
    scope = activity.Scope(None)
    scope.mark_write(QN('foo'))

    other = activity.Scope(None)
    other.copy_from(scope)

    self.assertIn(QN('foo'), other.modified)

    # copy_from replaces the destination's state entirely.
    scope.mark_write(QN('bar'))
    scope.copy_from(other)

    self.assertNotIn(QN('bar'), scope.modified)

    # merge_from adds to the destination without clearing existing symbols,
    # and leaves the source untouched.
    scope.mark_write(QN('bar'))
    scope.merge_from(other)

    self.assertIn(QN('bar'), scope.modified)
    self.assertNotIn(QN('bar'), other.modified)

  def test_copy_of(self):
    scope = activity.Scope(None)
    scope.mark_read(QN('foo'))

    self.assertIn(QN('foo'), activity.Scope.copy_of(scope).used)

    # Copies of child scopes preserve the child's own reads.
    child_scope = activity.Scope(scope)
    child_scope.mark_read(QN('bar'))

    self.assertIn(QN('bar'), activity.Scope.copy_of(child_scope).used)

  def test_nesting(self):
    scope = activity.Scope(None)
    scope.mark_write(QN('foo'))
    scope.mark_read(QN('bar'))

    # Child scopes see the parent's symbols...
    child = activity.Scope(scope)
    self.assertTrue(child.has(QN('foo')))
    self.assertTrue(scope.has(QN('foo')))

    # ...but writes in the child do not leak into the parent.
    child.mark_write(QN('bar'))
    self.assertTrue(child.has(QN('bar')))
    self.assertFalse(scope.has(QN('bar')))

  def test_referenced(self):
    scope = activity.Scope(None)
    scope.mark_read(QN('a'))

    child = activity.Scope(scope)
    child.mark_read(QN('b'))

    # A non-isolated scope propagates its references into the parent chain,
    # but only as far as the first isolated ancestor.
    child2 = activity.Scope(child, isolated=False)
    child2.mark_read(QN('c'))

    self.assertIn(QN('c'), child2.referenced)
    self.assertIn(QN('b'), child2.referenced)
    self.assertNotIn(QN('a'), child2.referenced)

    self.assertIn(QN('c'), child.referenced)
    self.assertIn(QN('b'), child.referenced)
    self.assertNotIn(QN('a'), child.referenced)
+
+
class ActivityAnalyzerTest(test.TestCase):
  """End-to-end tests for the scope annotations created by activity.resolve."""

  def _parse_and_analyze(self, test_fn):
    # Parses test_fn, resolves qualified names, then runs activity analysis.
    node, source = parser.parse_entity(test_fn)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file=None,
        namespace={},
        arg_values=None,
        arg_types=None,
        owner_type=None)
    node = qual_names.resolve(node)
    node = activity.resolve(node, entity_info)
    return node, entity_info

  def test_local_markers(self):

    def test_fn(a):  # pylint:disable=unused-argument
      b = c  # pylint:disable=undefined-variable
      while b > 0:
        b -= 1
      return b

    node, _ = self._parse_and_analyze(test_fn)
    self.assertFalse(
        anno.getanno(node.body[0].body[0].value,
                     NodeAnno.IS_LOCAL))  # c in b = c
    self.assertTrue(
        anno.getanno(node.body[0].body[1].test.left,
                     NodeAnno.IS_LOCAL))  # b in b > 0
    self.assertTrue(
        anno.getanno(node.body[0].body[2].value,
                     NodeAnno.IS_LOCAL))  # b in return b

  def assertSymbolSetsAre(self, expected, actual, name):
    """Asserts that `actual` (stringified) equals the `expected` symbol set."""
    expected = set(expected)
    actual = set(str(s) for s in actual)
    self.assertSetEqual(
        expected, actual, 'for symbol set: %s\n'
        '  Expected: %s\n'
        '  Got:      %s\n'
        '  Missing:  %s\n'
        '  Extra:    %s\n' % (name.upper(), expected, actual,
                              expected - actual, actual - expected))

  def assertScopeIsRmc(self, scope, used, modified, created):
    """Assert the scope contains specific used, modified & created variables."""
    self.assertSymbolSetsAre(used, scope.used, 'read')
    self.assertSymbolSetsAre(modified, scope.modified, 'modified')
    # Created is deprecated, we're no longer verifying it.
    # self.assertSymbolSetsAre(created, scope.created, 'created')

  def test_print_statement(self):

    def test_fn(a):
      b = 0
      c = 1
      print(a, b)
      return c

    node, _ = self._parse_and_analyze(test_fn)
    print_node = node.body[0].body[2]
    if isinstance(print_node, gast.Print):
      # Python 2
      print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
    else:
      # Python 3
      assert isinstance(print_node, gast.Expr)
      # The call node should be the one being annotated.
      print_node = print_node.value
      print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
    # We basically need to detect which variables are captured by the call
    # arguments.
    self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ())

  def test_call_args(self):

    def test_fn(a):
      b = 0
      c = 1
      foo(a, b)  # pylint:disable=undefined-variable
      return c

    node, _ = self._parse_and_analyze(test_fn)
    call_node = node.body[0].body[2].value
    # We basically need to detect which variables are captured by the call
    # arguments.
    self.assertScopeIsRmc(
        anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())

  def test_call_args_attributes(self):

    def foo(*_):
      pass

    def test_fn(a):
      a.c = 0
      foo(a.b, a.c)
      return a.d

    node, _ = self._parse_and_analyze(test_fn)
    call_node = node.body[0].body[1].value
    self.assertScopeIsRmc(
        anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
        ('a', 'a.b', 'a.c'),
        (),
        (),
    )

  def test_call_args_subscripts(self):

    def foo(*_):
      pass

    def test_fn(a):
      b = 1
      c = 2
      foo(a[0], a[b])
      return a[c]

    node, _ = self._parse_and_analyze(test_fn)
    call_node = node.body[0].body[2].value
    self.assertScopeIsRmc(
        anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
        ('a', 'a[0]', 'a[b]', 'b'),
        (),
        (),
    )

  def test_while(self):

    def test_fn(a):
      b = a
      while b > 0:
        c = b
        b -= 1
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    while_node = node.body[0].body[1]
    self.assertScopeIsRmc(
        anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
        ('c',))
    self.assertScopeIsRmc(
        anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
        ('b', 'c'), ('a', 'b', 'c'))
    self.assertScopeIsRmc(
        anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())

  def test_for(self):

    def test_fn(a):
      b = a
      for _ in a:
        c = b
        b -= 1
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    for_node = node.body[0].body[1]
    self.assertScopeIsRmc(
        anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
    self.assertScopeIsRmc(
        anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
        ('b', 'c', '_'), ('a', 'b', 'c', '_'))

  def test_if(self):

    def test_fn(x):
      if x > 0:
        x = -x
        y = 2 * x
        z = -y
      else:
        x = 2 * x
        y = -x
        u = -y
      return z, u

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
        ('y', 'z'))
    # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
        ('x', 'y', 'u'), ('y', 'u'))
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
        ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))

  def test_if_attributes(self):

    def test_fn(a):
      if a > 0:
        a.b = -a.c
        d = 2 * a
      else:
        a.b = a.c
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.BODY_SCOPE),
        ('a', 'a.c'),
        ('a.b', 'd'),
        ('d',),
    )
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
        ('a', 'a.c'),
        ('a.b', 'd'),
        ('d',),
    )
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent,
        ('a', 'a.c', 'd'),
        ('a.b', 'd'),
        ('a', 'd'),
    )

  def test_if_subscripts(self):

    def test_fn(a, b, c, e):
      if a > 0:
        a[b] = -a[c]
        d = 2 * a
      else:
        a[0] = e
        d = 1
      return d

    node, _ = self._parse_and_analyze(test_fn)
    if_node = node.body[0].body[0]
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.BODY_SCOPE),
        ('a', 'b', 'c', 'a[c]'),
        ('a[b]', 'd'),
        ('d',),
    )
    # TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
        ('a', 'e'),
        ('a[0]', 'd'),
        ('d',),
    )
    self.assertScopeIsRmc(
        anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
        ('a', 'b', 'c', 'd', 'e', 'a[c]'),
        ('d', 'a[b]', 'a[0]'),
        ('a', 'b', 'c', 'd', 'e'),
    )

  def test_nested_if(self):

    def test_fn(b):
      if b > 0:
        if b < 5:
          a = b
        else:
          a = b * b
      return a

    node, _ = self._parse_and_analyze(test_fn)
    inner_if_node = node.body[0].body[0].body[0]
    self.assertScopeIsRmc(
        anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',),
        ('a',))
    self.assertScopeIsRmc(
        anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',),
        ('a',))

  def test_nested_function(self):

    def test_fn(a):

      def f(x):
        y = x * x
        return y

      b = a
      for i in a:
        c = b
        b -= f(i)
      return b, c

    node, _ = self._parse_and_analyze(test_fn)
    fn_def_node = node.body[0].body[0]

    self.assertScopeIsRmc(
        anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',),
        ('x', 'y'))

  def test_constructor_attributes(self):

    class TestClass(object):

      def __init__(self, a):
        self.b = a
        self.b.c = 1

    node, _ = self._parse_and_analyze(TestClass)
    init_node = node.body[0].body[0]
    self.assertScopeIsRmc(
        anno.getanno(init_node, NodeAnno.BODY_SCOPE),
        ('self', 'a', 'self.b'),
        ('self', 'self.b', 'self.b.c'),
        ('self', 'a', 'self.b'),
    )

  def test_aug_assign_subscripts(self):

    def test_fn(a):
      a[0] += 1

    node, _ = self._parse_and_analyze(test_fn)
    fn_node = node.body[0]
    self.assertScopeIsRmc(
        anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
        ('a', 'a[0]'),
        ('a[0]',),
        ('a',),
    )

  def test_return_vars_are_read(self):

    def test_fn(a, b, c):  # pylint: disable=unused-argument
      return c

    node, _ = self._parse_and_analyze(test_fn)
    fn_node = node.body[0]
    self.assertScopeIsRmc(
        anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
        ('c',),
        (),
        ('a', 'b', 'c'),
    )

  def test_aug_assign(self):

    def test_fn(a, b):
      a += b

    node, _ = self._parse_and_analyze(test_fn)
    fn_node = node.body[0]
    self.assertScopeIsRmc(
        anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
        ('a', 'b'),
        # Fixed: was ('a') — a plain string, not a tuple — which only worked
        # because assertSymbolSetsAre calls set() on it.
        ('a',),
        ('a', 'b'),
    )

  def test_aug_assign_rvalues(self):

    a = dict(bar=3)

    def foo():
      return a

    def test_fn(x):
      foo()['bar'] += x

    node, _ = self._parse_and_analyze(test_fn)
    fn_node = node.body[0]
    self.assertScopeIsRmc(
        anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
        ('foo', 'x'),
        (),
        ('x',),
    )

  def test_params_created(self):

    def test_fn(a, b):  # pylint: disable=unused-argument
      return b

    node, _ = self._parse_and_analyze(test_fn)
    fn_node = node.body[0]
    # Fixed: the modified/created arguments were (('')) and (('a', 'b')) —
    # redundant parens around a string; replaced with proper tuples.
    self.assertScopeIsRmc(
        anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (),
        ('a', 'b'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/annos.py b/tensorflow/python/autograph/pyct/static_analysis/annos.py
new file mode 100644
index 0000000000..5eefecf278
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/annos.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Annotations used by the static analyzer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from enum import Enum
+
+
+# TODO(mdan): Remove.
+
+
class NoValue(Enum):
  """Base class for annotation enums whose repr is just the member name."""

  def __repr__(self):
    # The default Enum repr includes the class name and the (long, descriptive)
    # string value; the bare member name keeps annotation dumps compact.
    return self.name
+
+
class NodeAnno(NoValue):
  """Additional annotations used by the static analyzer.

  These are in addition to the basic annotations declared in anno.py.

  Note: the string values appear to serve as documentation only; the enum
  members themselves are used as annotation keys (see their use in the
  activity analyzer).
  """

  # Symbols
  # These flags are boolean.
  IS_LOCAL = 'Symbol is local to the function scope being analyzed.'
  IS_PARAM = 'Symbol is a parameter to the function being analyzed.'
  IS_MODIFIED_SINCE_ENTRY = (
      'Symbol has been explicitly replaced in the current function scope.')

  # Scopes
  # Scopes are represented by objects of type activity.Scope.
  ARGS_SCOPE = 'The scope for the argument list of a function call.'
  COND_SCOPE = 'The scope for the test node of a conditional statement.'
  BODY_SCOPE = (
      'The scope for the main body of a statement (True branch for if '
      'statements, main body for loops).')
  ORELSE_SCOPE = (
      'The scope for the orelse body of a statement (False branch for if '
      'statements, orelse body for loops).')
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
new file mode 100644
index 0000000000..48b442f3bd
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -0,0 +1,137 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Live value resolution.
+
+Live values are extracted from the known execution context.
+
+Requires activity and reaching definitions analyses.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
+
# TODO(aqj): Do we need this? Do other builtins fail in similar ways?
# See b/114389775 for a related bug in pyct.
# These symbols are legal in Python, but don't appear in the namespace.
+_special_symbols = {'range': range}
+
+
class LiveValueResolver(transformer.Base):
  """Annotates nodes with live values.

  A "live value" is the actual Python object that a name or attribute refers
  to in the execution context (namespace, literals, or argument values).
  Annotations set: 'live_val' (the object), 'fqn' (fully qualified name as a
  tuple of strings), and 'parent_type' for attributes.
  """

  def __init__(self, context, literals):
    super(LiveValueResolver, self).__init__(context)
    # Mapping of symbol name -> value for symbols treated as literals; these
    # take precedence over the namespace.
    self.literals = literals

  def visit_ClassDef(self, node):
    """Annotates a class definition with its object from the namespace."""
    self.generic_visit(node)
    anno.setanno(node, 'live_val', self.entity_info.namespace[node.name])
    return node

  def visit_Name(self, node):
    """Resolves the live value of a loaded name, if statically known."""
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      defs = anno.getanno(node, anno.Static.DEFINITIONS, ())

      is_defined = bool(defs)
      has_single_def = len(defs) == 1

      # Symbols with no reaching definition must come from outside the
      # function: literals, the namespace, or special builtins.
      if not is_defined:
        if node.id in self.literals:
          anno.setanno(node, 'live_val', self.literals[node.id])
        elif node.id in self.entity_info.namespace:
          obj = self.entity_info.namespace[node.id]
          anno.setanno(node, 'live_val', obj)
          if hasattr(obj, '__name__'):
            anno.setanno(node, 'fqn', (obj.__name__,))
          elif hasattr(obj, '__class__'):
            obj_class = obj.__class__
            anno.setanno(node, 'fqn',
                         (obj_class.__module__, obj_class.__name__))
          else:
            # If the symbol value is for example a primitive, then it will not
            # have a name.
            pass
        elif node.id in _special_symbols:
          anno.setanno(node, 'live_val', _special_symbols[node.id])
        else:
          pass
          # TODO(mdan): Should we raise an error here?
          # Can encounter this when:
          #  * a symbol truly lacks reference
          #  * a symbol is new, like the new name of a function we just renamed.
      else:
        pass
        # TODO(mdan): Attempt to trace its value through the local chain.
        # TODO(mdan): Use type annotations as fallback.

      # A symbol whose single definition is the enclosing function's parameter
      # can be resolved from the recorded argument values.
      if has_single_def:
        def_, = defs
        if def_.param_of is self.enclosing_entities[0]:
          if node.id in self.entity_info.arg_values:
            obj = self.entity_info.arg_values[node.id]
            anno.setanno(node, 'live_val', obj)
            anno.setanno(node, 'fqn', (obj.__class__.__name__,))
    return node

  def visit_Attribute(self, node):
    """Resolves an attribute's live value from its parent object or type."""
    self.generic_visit(node)
    if anno.hasanno(node.value, 'live_val'):
      assert anno.hasanno(node.value, 'fqn')
      parent_object = anno.getanno(node.value, 'live_val')

      anno.setanno(node, 'parent_type', type(parent_object))
      anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
      # The attribute may be missing on the parent object; see example below.
      if hasattr(parent_object, node.attr):
        # This can happen when the attribute's creation and use depend on the
        # same static condition, for example:
        #
        #  if cond:
        #    foo.bar = baz
        #  if cond:
        #    x = foo.bar
        #
        anno.setanno(node, 'live_val', getattr(parent_object, node.attr))

      # TODO(mdan): Investigate the role built-in annotations can play here.
    elif anno.hasanno(node.value, 'type'):
      parent_type = anno.getanno(node.value, 'type')
      if hasattr(parent_type, node.attr):
        # This should hold for static members like methods.
        # This would not hold for dynamic members like function attributes.
        # For the dynamic case, we simply leave the node without an annotation,
        # and let downstream consumers figure out what to do.
        anno.setanno(node, 'parent_type', parent_type)
        anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
        anno.setanno(node, 'fqn',
                     anno.getanno(node.value, 'type_fqn') + (node.attr,))
    elif isinstance(node.value, gast.Name):
      stem_name = node.value
      # All nonlocal symbols should be fully resolved.
      assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name
      # TODO(mdan): Figure out what to do when calling attribute on local object
      # Maybe just leave as-is?
    return node
+
+
def resolve(node, context, literals):
  """Annotates `node` with live values resolved from `context` and `literals`.

  Args:
    node: ast.AST to annotate.
    context: transformer context object.
    literals: dict mapping symbol names to literal values.
  Returns:
    The annotated AST.
  """
  resolver = LiveValueResolver(context, literals)
  return resolver.visit(node)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
new file mode 100644
index 0000000000..882c380b78
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values_test.py
@@ -0,0 +1,132 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for live_values module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
class LiveValuesResolverTest(test.TestCase):
  """Tests for the live value annotations produced by live_values.resolve."""

  # Note: the deprecated assertEquals alias (removed in Python 3.12) has been
  # replaced with assertEqual throughout.

  def _parse_and_analyze(self,
                         test_fn,
                         namespace,
                         literals=None,
                         arg_types=None):
    # Runs the full pipeline needed by live value resolution: qualified names,
    # CFG, activity, reaching definitions, then live values twice, around type
    # inference.
    literals = literals or {}
    node, source = parser.parse_entity(test_fn)
    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file=None,
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        owner_type=None)
    node = qual_names.resolve(node)
    graphs = cfg.build(node)
    node = activity.resolve(node, entity_info)
    node = reaching_definitions.resolve(node, entity_info, graphs,
                                        reaching_definitions.Definition)
    node = live_values.resolve(node, entity_info, literals)
    node = type_info.resolve(node, entity_info)
    node = live_values.resolve(node, entity_info, literals)
    return node

  def test_literals(self):

    a = None

    def test_fn():
      return a

    node = self._parse_and_analyze(test_fn, {}, literals={'a': 'bar'})
    retval_node = node.body[0].body[0].value
    self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))

  def test_primitive_values(self):

    a = None

    def test_fn():
      return a

    node = self._parse_and_analyze(test_fn, {'a': True})
    retval_node = node.body[0].body[0].value
    if six.PY2:
      self.assertEqual(
          anno.getanno(retval_node, 'fqn'), ('__builtin__', 'bool'))
    else:
      self.assertEqual(anno.getanno(retval_node, 'fqn'), ('builtins', 'bool'))

  def test_namespace(self):

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = self._parse_and_analyze(test_fn, {'foo': foo})
    func_node = node.body[0].body[0].value.func
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))

  def test_attribute_names(self):

    def test_fn():
      return constant_op.constant(0)

    node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
    func_node = node.body[0].body[0].value.func
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))

  def test_attributes_with_type_hints(self):

    class TestClass(object):

      def member(self):
        pass

      def test_fn(self):
        return self.member()

    node = self._parse_and_analyze(
        TestClass.test_fn, {'constant_op': constant_op},
        arg_types={'self': (TestClass.__name__, TestClass)})
    func_node = node.body[0].body[0].value.func
    self.assertEqual(TestClass.member, anno.getanno(func_node, 'live_val'))
    self.assertEqual(TestClass, anno.getanno(func_node, 'parent_type'))
    self.assertEqual(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
new file mode 100644
index 0000000000..41c903beb9
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
@@ -0,0 +1,200 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Live variable analysis.
+
+This analysis attaches a set containing the live symbols that are live at the
+exit of control flow statements.
+
+Requires activity analysis.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
+
+
class Analyzer(cfg.GraphVisitor):
  """CFG visitor that performs liveness analysis at statement level.

  Standard backward dataflow: for each CFG node,
  live_in = gen | (live_out - kill), where gen is the set of symbols the node
  reads and kill is the set it writes.
  """

  def __init__(self, graph):
    super(Analyzer, self).__init__(graph)
    # This allows communicating that nodes generate extra symbols,
    # e.g. those that a function definition closes over.
    self.extra_gen = {}

  def init_state(self, _):
    # Initial liveness set for every node: no symbols live.
    return set()

  def visit_node(self, node):
    """Runs one dataflow step; returns True if the node's in-state changed."""
    prev_live_in = self.in_[node]

    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)

      gen = node_scope.used | self.extra_gen.get(node.ast_node, frozenset())
      # TODO(mdan): verify whether composites' parents need to be added.
      # E.g. if x.y is live whether x needs to be added. Theoretically the
      # activity analysis should have both so that wouldn't be needed.
      kill = node_scope.modified

      # live_out is the union of the in-states of all successors (backward
      # analysis over the forward CFG edges).
      live_out = set()
      for n in node.next:
        live_out |= self.in_[n]
      live_in = gen | (live_out - kill)

    else:
      # Nodes that don't have a scope annotation are assumed not to touch any
      # symbols.
      # This Name node below is a literal name, e.g. False
      assert isinstance(node.ast_node,
                        (gast.Name, gast.Continue, gast.Break)), type(
                            node.ast_node)
      live_in = prev_live_in
      live_out = live_in

    self.in_[node] = live_in
    self.out[node] = live_out

    # TODO(mdan): Move this to the superclass?
    return prev_live_in != live_in
+
+
class WholeTreeAnalyzer(transformer.Base):
  """Runs liveness analysis on each of the functions defined in the AST.

  If a function defines other local functions, those will have separate CFGs.
  However, dataflow analysis needs to tie up these CFGs to properly emulate the
  effect of closures. In the case of liveness, the parent function's live
  variables must account for the variables that are live at the entry of each
  subfunction. For example:

    def foo():
      # baz is live here
      def bar():
        print(baz)

  This analyzer runs liveness analysis on each individual function, accounting
  for the effect above.
  """

  def __init__(self, source_info, graphs):
    super(WholeTreeAnalyzer, self).__init__(source_info)
    # Mapping of ast.FunctionDef -> cfg.Graph.
    self.graphs = graphs
    # Analyzer of the function currently being walked; None at module level.
    self.current_analyzer = None
    # Mapping of ast.FunctionDef -> Analyzer, for later consumption.
    self.analyzers = {}

  def visit_FunctionDef(self, node):
    parent_analyzer = self.current_analyzer
    subgraph = self.graphs[node]

    # Postorder tree processing makes this a bit complicated:
    #  1. construct an analyzer object and put it on stack
    #  2. recursively walk the subtree; this will initialize the analyzer's
    #     in_ state properly (done in a block below)
    #  3. run the final analysis
    analyzer = Analyzer(subgraph)
    self.current_analyzer = analyzer
    node = self.generic_visit(node)
    analyzer.visit_reverse()

    if parent_analyzer is not None:
      # Wire the state between the two subgraphs' analyzers.
      child_in_state = analyzer.in_[subgraph.entry]
      # Exception: symbols modified in the child function are local to it.
      body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
      for qn in body_scope.modified:
        # Note: a function modifying the symbol doesn't make that symbol
        # live at the function's entry. In fact when that happens it is
        # probably a case of undefined assignment, like this:
        #
        #   bar = 0
        #   def foo():
        #     print(bar)  # bar is undefined here!
        #     bar = 1
        #
        # Hence we use discard and not remove below.
        child_in_state.discard(qn)
      # Fixed: dropped the stray trailing comma in `frozenset(child_in_state,)`
      # (harmless, but misleading — it suggested a tuple argument).
      parent_analyzer.extra_gen[node] = frozenset(child_in_state)

    self.analyzers[node] = analyzer
    self.current_analyzer = parent_analyzer
    return node

  def visit_Nonlocal(self, node):
    # Fixed: visitor dispatch uses the AST node class name (gast.Nonlocal), so
    # the previous lowercase `visit_nonlocal` was never invoked and the guard
    # below was dead code.
    raise NotImplementedError()

  def visit_Global(self, node):
    # Fixed: same dispatch issue as visit_Nonlocal (node class is gast.Global).
    raise NotImplementedError()
+
+
class Annotator(transformer.Base):
  """AST visitor that annotates each control flow block with live symbols.

  Consumes the per-function analyzers produced by WholeTreeAnalyzer and
  attaches LIVE_VARS_OUT annotations to control flow statements.
  """

  # Note: additional nodes may be added as needed.

  def __init__(self, source_info, cross_function_analyzer):
    super(Annotator, self).__init__(source_info)
    self.cross_function_analyzer = cross_function_analyzer
    # Analyzer for the function currently being walked.
    self.current_analyzer = None

  def visit_FunctionDef(self, node):
    # Swap in this function's analyzer for the duration of its body; restore
    # the parent's afterwards so nested functions are handled correctly.
    parent_analyzer = self.current_analyzer
    self.current_analyzer = self.cross_function_analyzer.analyzers[node]

    node = self.generic_visit(node)
    self.current_analyzer = parent_analyzer
    return node

  def _aggregate_successors_live_in(self, node):
    # The symbols live at the exit of a statement are the union of the symbols
    # live at the entry of each of its successor statements.
    successors = self.current_analyzer.graph.stmt_next[node]
    node_live_out = set()
    for s in successors:
      node_live_out.update(self.current_analyzer.in_[s])
    anno.setanno(node, anno.Static.LIVE_VARS_OUT, frozenset(node_live_out))
    node = self.generic_visit(node)
    return node

  def visit_If(self, node):
    return self._aggregate_successors_live_in(node)

  def visit_For(self, node):
    return self._aggregate_successors_live_in(node)

  def visit_While(self, node):
    return self._aggregate_successors_live_in(node)
+
+
def resolve(node, source_info, graphs):
  """Resolves the live symbols at the exit of control flow statements.

  Args:
    node: ast.AST
    source_info: transformer.SourceInfo
    graphs: Dict[ast.FunctionDef, cfg.Graph]
  Returns:
    ast.AST
  """
  # First pass: run per-function liveness analysis, wiring nested functions.
  whole_tree = WholeTreeAnalyzer(source_info, graphs)
  node = whole_tree.visit(node)
  # Second pass: attach LIVE_VARS_OUT annotations to control flow statements.
  annotator = Annotator(source_info, whole_tree)
  return annotator.visit(node)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
new file mode 100644
index 0000000000..0d5f369e92
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -0,0 +1,149 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for liveness module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import liveness
+from tensorflow.python.platform import test
+
+
+class LivenessTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn):
+    # Runs the passes liveness depends on (qual_names, activity, CFG
+    # construction), then the liveness analysis under test.
+    node, source = parser.parse_entity(test_fn)
+    entity_info = transformer.EntityInfo(
+        source_code=source,
+        source_file=None,
+        namespace={},
+        arg_values=None,
+        arg_types=None,
+        owner_type=None)
+    node = qual_names.resolve(node)
+    node = activity.resolve(node, entity_info)
+    graphs = cfg.build(node)
+    liveness.resolve(node, entity_info, graphs)
+    return node
+
+  def assertHasLiveOut(self, node, expected):
+    # `expected` may be a single symbol name, a tuple of names, or falsy for
+    # the empty set; symbols are compared by their string representation.
+    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+    live_out_str = set(str(v) for v in live_out)
+    if not expected:
+      expected = ()
+    if not isinstance(expected, tuple):
+      expected = (expected,)
+    self.assertSetEqual(live_out_str, set(expected))
+
+  def test_stacked_if(self):
+
+    def test_fn(x, a):
+      if a > 0:
+        x = 0
+      if a > 1:
+        x = 1
+      return x
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasLiveOut(fn_body[0], ('a', 'x'))
+    self.assertHasLiveOut(fn_body[1], 'x')
+
+  def test_stacked_if_else(self):
+
+    def test_fn(x, a):
+      if a > 0:
+        x = 0
+      if a > 1:
+        x = 1
+      else:
+        x = 2
+      return x
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasLiveOut(fn_body[0], 'a')
+    self.assertHasLiveOut(fn_body[1], 'x')
+
+  def test_for_basic(self):
+
+    def test_fn(x, a):
+      for i in range(a):
+        x += i
+      return x
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasLiveOut(fn_body[0], 'x')
+
+  def test_attributes(self):
+
+    def test_fn(x, a):
+      if a > 0:
+        x.y = 0
+      return x.y
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasLiveOut(fn_body[0], ('x.y', 'x'))
+
+  def test_nested_functions(self):
+
+    def test_fn(a, b):
+      if b:
+        a = []
+
+      def foo():
+        return a
+
+      foo()
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasLiveOut(fn_body[0], 'a')
+
+  def test_nested_functions_isolation(self):
+
+    def test_fn(b):
+      if b:
+        a = 0 # pylint:disable=unused-variable
+
+      def child():
+        max(a) # pylint:disable=used-before-assignment
+        a = 1
+        return a
+
+      child()
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasLiveOut(fn_body[0], 'max')
+
+
+if __name__ == '__main__':
+  test.main()  # Discovers and runs the test cases defined above.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
new file mode 100644
index 0000000000..9aaf318a9f
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
@@ -0,0 +1,301 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Reaching definition analysis.
+
+This analysis attaches a set of a Definition objects to each symbol, one
+for each distinct definition that may reach it. The Definition objects are
+mutable and may be used by subsequent analyses to further annotate data like
+static type and value information.
+The analysis also attaches the set of the symbols defined at the entry of
+control flow statements.
+
+Requires activity analysis.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import annos
+
+
+class Definition(object):
+  """Definition objects describe a unique definition of a variable.
+
+  Subclasses of this may be used by passing an appropriate factory function to
+  resolve.
+
+  Attributes:
+    param_of: Optional[ast.AST]
+  """
+
+  def __init__(self):
+    # Set to the owning function's node when this definition is a parameter.
+    self.param_of = None
+
+  def __repr__(self):
+    # Includes id() because definitions are distinguished by object identity.
+    return '%s[%d]' % (self.__class__.__name__, id(self))
+
+
+class _NodeState(object):
+  """Abstraction for the state of the CFG walk for reaching definition analysis.
+
+  This is a value type. Only implements the strictly necessary operators.
+
+  Attributes:
+    value: Dict[qual_names.QN, Set[Definition, ...]], the defined symbols and
+      their possible definitions
+  """
+
+  def __init__(self, init_from=None):
+    # Copy-constructs from another _NodeState, or seeds from a dict of
+    # symbol -> single Definition; defaults to the empty state.
+    if init_from:
+      if isinstance(init_from, _NodeState):
+        self.value = {
+            s: set(other_infos) for s, other_infos in init_from.value.items()
+        }
+      elif isinstance(init_from, dict):
+        self.value = {s: set((init_from[s],)) for s in init_from}
+      else:
+        assert False, init_from
+    else:
+      self.value = {}
+
+  def __eq__(self, other):
+    # Equal iff the same symbols map to the same definition sets.
+    if frozenset(self.value.keys()) != frozenset(other.value.keys()):
+      return False
+    ret = all(self.value[s] == other.value[s] for s in self.value)
+    return ret
+
+  def __ne__(self, other):
+    return not self.__eq__(other)
+
+  def __or__(self, other):
+    # Union: merges the definition sets per symbol; used to join the states
+    # of converging CFG paths.
+    assert isinstance(other, _NodeState)
+    result = _NodeState(self)
+    for s, other_infos in other.value.items():
+      if s in result.value:
+        result.value[s].update(other_infos)
+      else:
+        result.value[s] = set(other_infos)
+    return result
+
+  def __sub__(self, other):
+    # Difference against a plain set of symbols: removes those symbols
+    # entirely (the "kill" step of the gen/kill dataflow equations).
+    assert isinstance(other, set)
+    result = _NodeState(self)
+    for s in other:
+      result.value.pop(s, None)
+    return result
+
+  def __repr__(self):
+    return 'NodeState[%s]=%s' % (id(self), repr(self.value))
+
+
+class Analyzer(cfg.GraphVisitor):
+  """CFG visitor that determines reaching definitions at statement level."""
+
+  def __init__(self, graph, definition_factory):
+    # Factory creating a fresh Definition (or subclass) for each new binding.
+    self._definition_factory = definition_factory
+    super(Analyzer, self).__init__(graph)
+    # This allows communicating that nodes have extra reaching definitions,
+    # e.g. those that a function closes over.
+    self.extra_in = {}
+
+    # Caches the _NodeState generated by each CFG node; see visit_node.
+    self.gen_map = {}
+
+  def init_state(self, _):
+    return _NodeState()
+
+  def visit_node(self, node):
+    # Standard forward dataflow step: in = union of predecessors' out (plus
+    # any closure-provided extras); out = gen | (in - kill).
+    prev_defs_out = self.out[node]
+
+    defs_in = _NodeState(self.extra_in.get(node.ast_node, None))
+    for n in node.prev:
+      defs_in |= self.out[n]
+
+    if anno.hasanno(node.ast_node, anno.Static.SCOPE):
+      node_scope = anno.getanno(node.ast_node, anno.Static.SCOPE)
+      # The definition objects created by each node must be singletons because
+      # their ids are used in equality checks.
+      if node not in self.gen_map:
+        node_symbols = {}
+        for s in node_scope.modified:
+          def_ = self._definition_factory()
+          if s in node_scope.params:
+            def_.param_of = node_scope.params[s]
+          node_symbols[s] = def_
+        self.gen_map[node] = _NodeState(node_symbols)
+
+      gen = self.gen_map[node]
+      kill = node_scope.modified
+      defs_out = gen | (defs_in - kill)
+
+    else:
+      # Nodes that don't have a scope annotation are assumed not to touch any
+      # symbols.
+      # This Name node below is a literal name, e.g. False
+      # This can also happen if activity.py forgot to annotate the node with a
+      # scope object.
+      assert isinstance(
+          node.ast_node,
+          (gast.Name, gast.Break, gast.Continue, gast.Raise)), (node.ast_node,
+                                                                node)
+      defs_out = defs_in
+
+    self.in_[node] = defs_in
+    self.out[node] = defs_out
+
+    # TODO(mdan): Move this to the superclass?
+    # True means the state changed, i.e. the fixed point is not yet reached.
+    return prev_defs_out != defs_out
+
+
+class TreeAnnotator(transformer.Base):
+ """AST visitor that annotates each symbol name with its reaching definitions.
+
+ Simultaneously, the visitor runs the dataflow analysis on each function node,
+ accounting for the effect of closures. For example:
+
+ def foo():
+ bar = 1
+ def baz():
+ # bar = 1 reaches here
+ """
+
+ def __init__(self, source_info, graphs, definition_factory):
+ super(TreeAnnotator, self).__init__(source_info)
+ self.definition_factory = definition_factory
+ self.graphs = graphs
+ self.current_analyzer = None
+ self.current_cfg_node = None
+
+ def visit_FunctionDef(self, node):
+ parent_analyzer = self.current_analyzer
+ subgraph = self.graphs[node]
+
+ # Preorder tree processing:
+ # 1. if this is a child function, the parent was already analyzed and it
+ # has the proper state value for the subgraph's entry
+ # 2. analyze the current function body
+ # 2. recursively walk the subtree; child functions will be processed
+ analyzer = Analyzer(subgraph, self.definition_factory)
+ if parent_analyzer is not None:
+ # Wire the state between the two subgraphs' analyzers.
+ parent_out_state = parent_analyzer.out[parent_analyzer.graph.index[node]]
+ # Exception: symbols modified in the child function are local to it
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+ parent_out_state -= body_scope.modified
+ analyzer.extra_in[node.args] = parent_out_state
+
+ # Complete the analysis for the local function and annotate its body.
+ analyzer.visit_forward()
+
+ # Recursively process any remaining subfunctions.
+ self.current_analyzer = analyzer
+ # Note: not visiting name, decorator_list and returns because they don't
+ # apply to this anlysis.
+ # TODO(mdan): Should we still process the function name?
+ node.args = self.visit(node.args)
+ node.body = self.visit_block(node.body)
+ self.current_analyzer = parent_analyzer
+
+ return node
+
+ def visit_nonlocal(self, node):
+ raise NotImplementedError()
+
+ def visit_global(self, node):
+ raise NotImplementedError()
+
+ def visit_Name(self, node):
+ if self.current_analyzer is None:
+ # Names may appear outside function defs - for example in class
+ # definitions.
+ return node
+
+ analyzer = self.current_analyzer
+ cfg_node = self.current_cfg_node
+
+ assert cfg_node is not None, 'name node outside of any statement?'
+
+ qn = anno.getanno(node, anno.Basic.QN)
+ if isinstance(node.ctx, gast.Load):
+ anno.setanno(node, anno.Static.DEFINITIONS,
+ tuple(analyzer.in_[cfg_node].value.get(qn, ())))
+ else:
+ anno.setanno(node, anno.Static.DEFINITIONS,
+ tuple(analyzer.out[cfg_node].value.get(qn, ())))
+
+ return node
+
+ def _aggregate_predecessors_defined_in(self, node):
+ preds = self.current_analyzer.graph.stmt_prev[node]
+ node_defined_in = set()
+ for p in preds:
+ node_defined_in |= set(self.current_analyzer.out[p].value.keys())
+ anno.setanno(node, anno.Static.DEFINED_VARS_IN, frozenset(node_defined_in))
+
+ def visit_If(self, node):
+ self._aggregate_predecessors_defined_in(node)
+ return self.generic_visit(node)
+
+ def visit_For(self, node):
+ self._aggregate_predecessors_defined_in(node)
+
+ # Manually accounting for the shortcoming described in
+ # cfg.AstToCfg.visit_For.
+ parent = self.current_cfg_node
+ self.current_cfg_node = self.current_analyzer.graph.index[node.iter]
+ node.target = self.visit(node.target)
+ self.current_cfg_node = parent
+
+ node.iter = self.visit(node.iter)
+ node.body = self.visit_block(node.body)
+ node.orelse = self.visit_block(node.orelse)
+
+ return node
+
+ def visit_While(self, node):
+ self._aggregate_predecessors_defined_in(node)
+ return self.generic_visit(node)
+
+ def visit(self, node):
+ parent = self.current_cfg_node
+
+ if (self.current_analyzer is not None and
+ node in self.current_analyzer.graph.index):
+ self.current_cfg_node = self.current_analyzer.graph.index[node]
+ node = super(TreeAnnotator, self).visit(node)
+
+ self.current_cfg_node = parent
+ return node
+
+
+def resolve(node, source_info, graphs, definition_factory):
+  """Resolves reaching definitions for each symbol.
+
+  Args:
+    node: ast.AST
+    source_info: transformer.SourceInfo
+    graphs: Dict[ast.FunctionDef, cfg.Graph]
+    definition_factory: Callable[[], Definition]
+  Returns:
+    ast.AST
+  """
+  # The annotator both runs the per-function dataflow analysis and copies
+  # its results onto the AST nodes; annotation happens in place.
+  visitor = TreeAnnotator(source_info, graphs, definition_factory)
+  node = visitor.visit(node)
+  return node
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
new file mode 100644
index 0000000000..373a2cb38f
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -0,0 +1,263 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for reaching_definitions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.platform import test
+
+
+class DefinitionInfoTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn):
+    # Runs the prerequisite passes (qual_names, activity, CFG construction)
+    # before the reaching definitions analysis under test.
+    node, source = parser.parse_entity(test_fn)
+    entity_info = transformer.EntityInfo(
+        source_code=source,
+        source_file=None,
+        namespace={},
+        arg_values=None,
+        arg_types=None,
+        owner_type=None)
+    node = qual_names.resolve(node)
+    node = activity.resolve(node, entity_info)
+    graphs = cfg.build(node)
+    node = reaching_definitions.resolve(node, entity_info, graphs,
+                                        reaching_definitions.Definition)
+    return node
+
+  def assertHasDefs(self, node, num):
+    # Asserts that exactly `num` distinct definitions reach `node`.
+    defs = anno.getanno(node, anno.Static.DEFINITIONS)
+    self.assertEqual(len(defs), num)
+    for r in defs:
+      self.assertIsInstance(r, reaching_definitions.Definition)
+
+  def assertHasDefinedIn(self, node, expected):
+    # `expected` may be a single symbol name, a tuple of names, or falsy for
+    # the empty set.
+    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
+    defined_in_str = set(str(v) for v in defined_in)
+    if not expected:
+      expected = ()
+    if not isinstance(expected, tuple):
+      expected = (expected,)
+    self.assertSetEqual(defined_in_str, set(expected))
+
+  def assertSameDef(self, first, second):
+    # Definitions are compared by object identity, not by value.
+    self.assertHasDefs(first, 1)
+    self.assertHasDefs(second, 1)
+    self.assertIs(
+        anno.getanno(first, anno.Static.DEFINITIONS)[0],
+        anno.getanno(second, anno.Static.DEFINITIONS)[0])
+
+  def assertNotSameDef(self, first, second):
+    self.assertHasDefs(first, 1)
+    self.assertHasDefs(second, 1)
+    self.assertIsNot(
+        anno.getanno(first, anno.Static.DEFINITIONS)[0],
+        anno.getanno(second, anno.Static.DEFINITIONS)[0])
+
+  def test_conditional(self):
+
+    def test_fn(a, b):
+      a = []
+      if b:
+        a = []
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasDefs(fn_body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[1].test, 1)
+    self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[2].value, 2)
+
+    self.assertHasDefinedIn(fn_body[1], ('a', 'b'))
+
+  def test_while(self):
+
+    def test_fn(a):
+      max(a)
+      while True:
+        a = a
+        a = a
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasDefs(fn_body[0].value.args[0], 1)
+    self.assertHasDefs(fn_body[1].body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[1].body[1].targets[0], 1)
+    self.assertHasDefs(fn_body[1].body[1].value, 1)
+    # The loop does have an invariant test, but the CFG doesn't know that.
+    self.assertHasDefs(fn_body[1].body[0].value, 2)
+    self.assertHasDefs(fn_body[2].value, 2)
+
+  def test_while_else(self):
+
+    def test_fn(x, i):
+      y = 0
+      while x:
+        x += i
+        if i:
+          break
+      else:
+        y = 1
+      return x, y
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasDefs(fn_body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[1].test, 2)
+    self.assertHasDefs(fn_body[1].body[0].target, 1)
+    self.assertHasDefs(fn_body[1].body[1].test, 1)
+    self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1)
+    self.assertHasDefs(fn_body[2].value.elts[0], 2)
+    self.assertHasDefs(fn_body[2].value.elts[1], 2)
+
+  def test_for_else(self):
+
+    def test_fn(x, i):
+      y = 0
+      for i in x:
+        x += i
+        if i:
+          break
+        else:
+          continue
+      else:
+        y = 1
+      return x, y
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    self.assertHasDefs(fn_body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[1].target, 1)
+    self.assertHasDefs(fn_body[1].body[0].target, 1)
+    self.assertHasDefs(fn_body[1].body[1].test, 1)
+    self.assertHasDefs(fn_body[1].orelse[0].targets[0], 1)
+    self.assertHasDefs(fn_body[2].value.elts[0], 2)
+    self.assertHasDefs(fn_body[2].value.elts[1], 2)
+
+  def test_nested_functions(self):
+
+    def test_fn(a, b):
+      a = []
+      if b:
+        a = []
+
+        def foo():
+          return a
+
+        foo()
+
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+    def_of_a_in_if = fn_body[1].body[0].targets[0]
+
+    self.assertHasDefs(fn_body[0].targets[0], 1)
+    self.assertHasDefs(fn_body[1].test, 1)
+    self.assertHasDefs(def_of_a_in_if, 1)
+    self.assertHasDefs(fn_body[2].value, 2)
+
+    inner_fn_body = fn_body[1].body[1].body
+    # The closure's read of `a` must see the definition made in the if body.
+    self.assertSameDef(inner_fn_body[0].value, def_of_a_in_if)
+
+  def test_nested_functions_isolation(self):
+
+    def test_fn(a):
+      a = 0
+
+      def child():
+        a = 1
+        return a
+
+      child()
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    parent_return = fn_body[3]
+    child_return = fn_body[1].body[1]
+    # The assignment `a = 1` makes `a` local to `child`.
+    self.assertNotSameDef(parent_return.value, child_return.value)
+
+  def test_function_call_in_with(self):
+
+    def foo(_):
+      pass
+
+    def test_fn(a):
+      with foo(a):
+        return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    # `foo` is a free name with no definition inside test_fn.
+    self.assertHasDefs(fn_body[0].items[0].context_expr.func, 0)
+    self.assertHasDefs(fn_body[0].items[0].context_expr.args[0], 1)
+
+  def test_mutation_subscript(self):
+
+    def test_fn(a):
+      l = []
+      l[0] = a
+      return l
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    creation = fn_body[0].targets[0]
+    mutation = fn_body[1].targets[0].value
+    use = fn_body[2].value
+    # Subscript mutation does not create a new definition of `l`.
+    self.assertSameDef(creation, mutation)
+    self.assertSameDef(creation, use)
+
+  def test_replacement(self):
+
+    def foo(a):
+      return a
+
+    def test_fn(a):
+      a = foo(a)
+      return a
+
+    node = self._parse_and_analyze(test_fn)
+    fn_body = node.body[0].body
+
+    param = node.body[0].args.args[0]
+    source = fn_body[0].value.args[0]
+    target = fn_body[0].targets[0]
+    retval = fn_body[1].value
+    self.assertSameDef(param, source)
+    # The assignment replaces the parameter's definition of `a`.
+    self.assertNotSameDef(source, target)
+    self.assertSameDef(target, retval)
+
+
+if __name__ == '__main__':
+  test.main()  # Discovers and runs the test cases defined above.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_info.py b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
new file mode 100644
index 0000000000..edb2ef0e27
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info.py
@@ -0,0 +1,213 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Type resolution.
+
+This analyzer uses known live values to further infer object types. This
+may include for instance constructed objects and object member functions.
+
+In addition, the analyzer also handles user annotations made in the code (for
+example, the autograph.set_element_type function).
+
+Requires annotations generated by LiveValuesResolver.
+"""
+
+# TODO(mdan): This would be more robust with a CFG.
+# Situations with multiple reaching modifications (e.g. modified inside and
+# outside a control flow statement) should be more robustly detected and
+# analyzed.
+
+# TODO(mdan): Look into using Python AST's type annotation fields instead.
+# It would be desirable to use that mechanism if we can.
+# Some caveats to consider: We may need to annotate other nodes like
+# Attribute. It may also not be feasible for us to faithfully to replicate
+# PY3's type annotations where it isn't available. It would also require us
+# to design rigorous type definitions that can accommodate Python types
+# as well as TensorFLow dtypes and shapes.
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.util import tf_inspect
+
+
+# TODO(mdan): Remove the duplication between this and activity.py.
+# In particular, the symbol definitions we track here could as well be tracked
+# there because they follow the same rules for visibility.
+# TODO(mdan): Use a CFG based Defined analysis instead.
+class Scope(object):
+  """Tracks symbol value references.
+
+  Attributes:
+    values: A dict mapping string to gast.Node, containing the value that was
+        most recently assigned to the symbol.
+  """
+
+  def __init__(self, parent):
+    """Create a new scope.
+
+    Args:
+      parent: A Scope or None.
+    """
+    self.parent = parent
+    self.values = {}
+
+  def __repr__(self):
+    return 'Scope[%s]' % self.values.keys()
+
+  def copy(self):
+    # Shallow copy: the values dict is duplicated, the parent is shared.
+    s = Scope(self.parent)
+    s.values = self.values.copy()
+    return s
+
+  def setval(self, name, value):
+    self.values[name] = value
+
+  def hasval(self, name):
+    # Checks this scope first, then the enclosing scopes recursively.
+    return (name in self.values or
+            (self.parent is not None and self.parent.hasval(name)))
+
+  def getval(self, name):
+    # Innermost definition wins; raises KeyError if no scope defines `name`.
+    if name in self.values:
+      return self.values[name]
+    if self.parent is not None:
+      return self.parent.getval(name)
+    raise KeyError(name)
+
+
+class TypeInfoResolver(transformer.Base):
+  """Annotates symbols with type information where possible.
+
+  Nodes currently annotated:
+    * Call (helps detect class constructors)
+    * Attribute (helps resolve object methods)
+  """
+
+  def __init__(self, context):
+    super(TypeInfoResolver, self).__init__(context)
+    # Scope of locally-known symbol values; pushed/popped around blocks.
+    self.scope = Scope(None)
+
+  def visit_FunctionDef(self, node):
+    # Each function gets a child scope so its assignments don't leak out.
+    self.scope = Scope(self.scope)
+    node = self.generic_visit(node)
+    self.scope = self.scope.parent
+    return node
+
+  def _visit_block(self, block):
+    # Wraps a statement block (e.g. an if/loop body) in its own child scope.
+    self.scope = Scope(self.scope)
+    block = self.visit_block(block)
+    self.scope = self.scope.parent
+    return block
+
+  def visit_For(self, node):
+    self.generic_visit(node.target)
+    self.generic_visit(node.iter)
+    node.body = self._visit_block(node.body)
+    node.orelse = self._visit_block(node.orelse)
+    return node
+
+  def visit_While(self, node):
+    self.generic_visit(node.test)
+    node.body = self._visit_block(node.body)
+    node.orelse = self._visit_block(node.orelse)
+    return node
+
+  def visit_If(self, node):
+    self.generic_visit(node.test)
+    node.body = self._visit_block(node.body)
+    node.orelse = self._visit_block(node.orelse)
+    return node
+
+  def _process_function_arg(self, arg_node):
+    qn = anno.getanno(arg_node, anno.Basic.QN)
+    arg_name = str(qn)
+    self.scope.setval(qn, arg_node)
+    # Only the top-level entity's args can carry user-provided type hints.
+    if (len(self.enclosing_entities) == 1 and
+        arg_name in self.entity_info.arg_types):
+      # Forge a node to hold the type information, so that method calls on
+      # it can resolve the type.
+      type_string, type_obj = self.entity_info.arg_types[arg_name]
+      anno.setanno(arg_node, 'type', type_obj)
+      anno.setanno(arg_node, 'type_fqn', tuple(type_string.split('.')))
+
+  def visit_arg(self, node):
+    # NOTE(review): node.arg is passed straight to _process_function_arg,
+    # which reads a QN annotation from it — presumably gast normalizes args
+    # to annotated Name nodes here; verify against the gast version in use.
+    self._process_function_arg(node.arg)
+    return node
+
+  def visit_Name(self, node):
+    self.generic_visit(node)
+    if isinstance(node.ctx, gast.Param):
+      self._process_function_arg(node)
+    elif isinstance(node.ctx, gast.Load):
+      qn = anno.getanno(node, anno.Basic.QN)
+      if self.scope.hasval(qn):
+        # E.g. if we had
+        # a = b
+        # then for future references to `a` we should have definition = `b`
+        definition = self.scope.getval(qn)
+        anno.copyanno(definition, node, 'type')
+        anno.copyanno(definition, node, 'type_fqn')
+
+        # TODO(mdan): Remove this when the directives module is in.
+        anno.copyanno(definition, node, 'element_type')
+        anno.copyanno(definition, node, 'element_shape')
+    return node
+
+  def _process_variable_assignment(self, target, value):
+    # Constructors
+    if isinstance(value, gast.Call):
+      func = value.func
+      if anno.hasanno(func, 'live_val'):
+        func_obj = anno.getanno(func, 'live_val')
+        if tf_inspect.isclass(func_obj):
+          anno.setanno(value, 'is_constructor', True)
+          anno.setanno(value, 'type', func_obj)
+          anno.setanno(value, 'type_fqn', anno.getanno(func, 'fqn'))
+          # TODO(mdan): Raise an error if constructor has side effects.
+          # We can have a whitelist of no-side-effects constructors.
+          # We can also step inside the constructor and further analyze.
+
+    if isinstance(target, (gast.Name, gast.Attribute)):
+      target_symbol = anno.getanno(target, anno.Basic.QN)
+      self.scope.setval(target_symbol, value)
+    elif isinstance(target, gast.Subscript):
+      # Subscript writes don't rebind the symbol; nothing to track.
+      pass
+    else:
+      raise ValueError('assignment target has unknown type: %s' % target)
+
+  def visit_With(self, node):
+    # `with foo() as x` is treated like the assignment `x = foo()`.
+    for item in node.items:
+      if item.optional_vars is not None:
+        ast_util.apply_to_single_assignments((item.optional_vars,),
+                                             item.context_expr,
+                                             self._process_variable_assignment)
+    self.generic_visit(node)
+    return node
+
+  def visit_Assign(self, node):
+    self.generic_visit(node)
+    # Unpacks tuple/list assignments into individual target/value pairs.
+    ast_util.apply_to_single_assignments(node.targets, node.value,
+                                         self._process_variable_assignment)
+    return node
+
+
+def resolve(node, context):
+  """Annotates `node` in place with type information; returns the node.
+
+  Args:
+    node: ast.AST
+    context: the transformer context, forwarded to transformer.Base —
+        presumably a transformer.EntityInfo; verify against callers.
+  Returns:
+    ast.AST
+  """
+  return TypeInfoResolver(context).visit(node)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
new file mode 100644
index 0000000000..34ba3d2f13
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_info_test.py
@@ -0,0 +1,207 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for type_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import cfg
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.autograph.pyct.static_analysis import activity
+from tensorflow.python.autograph.pyct.static_analysis import live_values
+from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions
+from tensorflow.python.autograph.pyct.static_analysis import type_info
+from tensorflow.python.client import session
+from tensorflow.python.platform import test
+from tensorflow.python.training import training
+
+
+class ScopeTest(test.TestCase):
+
+  def test_basic(self):
+    # A fresh scope contains nothing until setval is called.
+    scope = type_info.Scope(None)
+    self.assertFalse(scope.hasval('foo'))
+
+    scope.setval('foo', 'bar')
+    self.assertTrue(scope.hasval('foo'))
+
+    self.assertFalse(scope.hasval('baz'))
+
+  def test_nesting(self):
+    # Child scopes see parent values; parents do not see child values.
+    scope = type_info.Scope(None)
+    scope.setval('foo', '')
+
+    child = type_info.Scope(scope)
+    self.assertTrue(child.hasval('foo'))
+    self.assertTrue(scope.hasval('foo'))
+
+    child.setval('bar', '')
+    self.assertTrue(child.hasval('bar'))
+    self.assertFalse(scope.hasval('bar'))
+
+
+class TypeInfoResolverTest(test.TestCase):
+
+ def _parse_and_analyze(self,
+ test_fn,
+ namespace,
+ arg_types=None):
+ node, source = parser.parse_entity(test_fn)
+ entity_info = transformer.EntityInfo(
+ source_code=source,
+ source_file=None,
+ namespace=namespace,
+ arg_values=None,
+ arg_types=arg_types,
+ owner_type=None)
+ node = qual_names.resolve(node)
+ graphs = cfg.build(node)
+ node = activity.resolve(node, entity_info)
+ node = reaching_definitions.resolve(node, entity_info, graphs,
+ reaching_definitions.Definition)
+ node = live_values.resolve(node, entity_info, {})
+ node = type_info.resolve(node, entity_info)
+ node = live_values.resolve(node, entity_info, {})
+ return node
+
+ def test_constructor_detection(self):
+
+ def test_fn():
+ opt = training.GradientDescentOptimizer(0.1)
+ return opt
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ call_node = node.body[0].body[0].value
+ self.assertEquals(training.GradientDescentOptimizer,
+ anno.getanno(call_node, 'type'))
+ self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
+ anno.getanno(call_node, 'type_fqn'))
+
+ def test_class_members_of_detected_constructor(self):
+
+ def test_fn():
+ opt = training.GradientDescentOptimizer(0.1)
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertEquals(training.GradientDescentOptimizer.minimize,
+ anno.getanno(method_call, 'live_val'))
+
+ def test_class_members_in_with_stmt(self):
+
+ def test_fn(x):
+ with session.Session() as sess:
+ sess.run(x)
+
+ node = self._parse_and_analyze(test_fn, {'session': session})
+ constructor_call = node.body[0].body[0].items[0].context_expr
+ self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
+ self.assertEquals((session.__name__, 'Session'),
+ anno.getanno(constructor_call, 'type_fqn'))
+
+ method_call = node.body[0].body[0].body[0].value.func
+ self.assertEquals(session.Session.run, anno.getanno(method_call,
+ 'live_val'))
+
+ def test_constructor_data_dependent(self):
+
+ def test_fn(x):
+ if x > 0:
+ opt = training.GradientDescentOptimizer(0.1)
+ else:
+ opt = training.GradientDescentOptimizer(0.01)
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_parameter_class_members(self):
+
+ def test_fn(opt):
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(test_fn, {})
+ method_call = node.body[0].body[0].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_parameter_class_members_with_value_hints(self):
+
+ def test_fn(opt):
+ opt.minimize(0)
+
+ node = self._parse_and_analyze(
+ test_fn, {},
+ arg_types={
+ 'opt': (training.GradientDescentOptimizer.__name__,
+ training.GradientDescentOptimizer)
+ })
+
+ method_call = node.body[0].body[0].value.func
+ self.assertEquals(training.GradientDescentOptimizer.minimize,
+ anno.getanno(method_call, 'live_val'))
+
+ def test_function_variables(self):
+
+ def bar():
+ pass
+
+ def test_fn():
+ foo = bar
+ foo()
+
+ node = self._parse_and_analyze(test_fn, {'bar': bar})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_nested_members(self):
+
+ def test_fn():
+ foo = training.GradientDescentOptimizer(0.1)
+ foo.bar.baz()
+
+ node = self._parse_and_analyze(test_fn, {'training': training})
+ method_call = node.body[0].body[1].value.func
+ self.assertFalse(anno.hasanno(method_call, 'live_val'))
+
+ def test_nested_unpacking(self):
+
+ class Foo(object):
+ pass
+
+ class Bar(object):
+ pass
+
+ def test_fn():
+ a, (b, c) = (Foo(), (Bar(), Foo()))
+ return a, b, c
+
+ node = self._parse_and_analyze(test_fn, {'Foo': Foo, 'Bar': Bar})
+ a, b, c = node.body[0].body[1].value.elts
+ self.assertEquals(anno.getanno(a, 'type'), Foo)
+ self.assertEquals(anno.getanno(b, 'type'), Bar)
+ self.assertEquals(anno.getanno(c, 'type'), Foo)
+ self.assertFalse(anno.hasanno(a, 'live_val'))
+ self.assertFalse(anno.hasanno(b, 'live_val'))
+ self.assertFalse(anno.hasanno(c, 'live_val'))
+
+
+if __name__ == '__main__':
+  test.main()  # Discovers and runs the test cases defined above.
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
new file mode 100644
index 0000000000..68c2a35fac
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -0,0 +1,277 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST conversion templates.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import ast
+import textwrap
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import ast_util
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import qual_names
+
+
class ReplaceTransformer(gast.NodeTransformer):
  """Replaces placeholder AST nodes with caller-supplied replacement nodes.

  Placeholders are identified by name: Name ids, Attribute attrs, keyword
  args, and FunctionDef names that match a key in `replacements` are swapped
  for the corresponding replacement node(s).
  """

  def __init__(self, replacements):
    """Create a new ReplaceTransformer.

    Args:
      replacements: A mapping from placeholder names to (lists of) AST nodes
          that these placeholders will be replaced by.
    """
    self.replacements = replacements
    # NOTE(review): in_replacements is set here but never read in this class.
    self.in_replacements = False
    # Annotations that survive the copy made in _prepare_replacement
    # (see ast_util.copy_clean's preserve_annos argument).
    self.preserved_annos = {
        anno.Basic.ORIGIN,
        anno.Basic.SKIP_PROCESSING,
        anno.Static.ORIG_DEFINITIONS,
    }

  def _prepare_replacement(self, replaced, key):
    """Prepares a replacement AST that's safe to swap in for a node.

    The replacement is deep-copied so the same template value may be spliced
    into multiple locations without aliasing.

    Args:
      replaced: ast.AST, the node being replaced
      key: Hashable, the key of the replacement AST
    Returns:
      ast.AST, the replacement AST (always normalized to a list of nodes)
    """
    repl = self.replacements[key]

    new_nodes = ast_util.copy_clean(repl, preserve_annos=self.preserved_annos)
    if isinstance(new_nodes, gast.AST):
      new_nodes = [new_nodes]

    return new_nodes

  def visit_Expr(self, node):
    # When replacing a placeholder with an entire statement, the replacement
    # must stand on its own and not be wrapped in an Expr.
    new_value = self.visit(node.value)
    if new_value is node.value:
      return node
    return new_value

  def visit_keyword(self, node):
    # A keyword placeholder may only be replaced by keyword node(s).
    if node.arg not in self.replacements:
      return self.generic_visit(node)

    repl = self._prepare_replacement(node, node.arg)
    if isinstance(repl, gast.keyword):
      return repl
    elif (repl and isinstance(repl, (list, tuple)) and
          all(isinstance(r, gast.keyword) for r in repl)):
      return repl
    # TODO(mdan): We may allow replacing with a string as well.
    # For example, if one wanted to replace foo with bar in foo=baz, then
    # we could allow changing just node arg, so that we end up with bar=baz.
    raise ValueError(
        'a keyword argument may only be replaced by another keyword or a '
        'non-empty list of keywords. Found: %s' % repl)

  def visit_FunctionDef(self, node):
    # A placeholder used as a function name is replaced by the id of a
    # Name replacement node; the FunctionDef node itself is kept.
    node = self.generic_visit(node)
    if node.name not in self.replacements:
      return node

    repl = self.replacements[node.name]
    if not isinstance(repl, (gast.Name, ast.Name)):
      raise ValueError(
          'a function name can only be replaced by a Name node. Found: %s' %
          repl)
    node.name = repl.id
    return node

  def _check_has_context(self, node):
    # Guard: every expression node spliced into the tree must carry a ctx.
    if not node.ctx:
      raise ValueError('node %s is missing ctx value' % node)

  def _check_inner_children_have_context(self, node):
    # Recursively verifies that a subtree has ctx set everywhere it is
    # required; raises ValueError otherwise. Mirrors the node types handled
    # by _set_inner_child_context below.
    if isinstance(node, gast.Attribute):
      self._check_inner_children_have_context(node.value)
      self._check_has_context(node)
    elif isinstance(node, (gast.Tuple, gast.List)):
      for e in node.elts:
        self._check_inner_children_have_context(e)
      self._check_has_context(node)
    elif isinstance(node, gast.Dict):
      for e in node.keys:
        self._check_inner_children_have_context(e)
      for e in node.values:
        self._check_inner_children_have_context(e)
    elif isinstance(node, gast.Subscript):
      self._check_inner_children_have_context(node.value)
      self._check_inner_children_have_context(node.slice)
    elif isinstance(node, gast.Slice):
      self._check_inner_children_have_context(node.lower)
      if node.upper:
        self._check_inner_children_have_context(node.upper)
      if node.step:
        self._check_inner_children_have_context(node.step)
    elif isinstance(node, gast.Name):
      self._check_has_context(node)
    elif isinstance(node, (gast.Str, gast.Num)):
      pass
    else:
      raise ValueError('unexpected node type "%s"' % node)

  def _set_inner_child_context(self, node, ctx):
    # Recursively assigns ctx to a replacement subtree. Only the syntactic
    # positions that legally take the target context receive it; inner value
    # expressions (e.g. the object of an attribute access) are set to Load.
    if isinstance(node, gast.Attribute):
      self._set_inner_child_context(node.value, gast.Load())
      node.ctx = ctx
    elif isinstance(node, (gast.Tuple, gast.List)):
      for e in node.elts:
        self._set_inner_child_context(e, ctx)
      node.ctx = ctx
    elif isinstance(node, gast.Name):
      node.ctx = ctx
    elif isinstance(node, gast.Call):
      self._set_inner_child_context(node.func, ctx)
      # We may be able to override these to Load(), but for now it's simpler
      # to just assert that they're set.
      for a in node.args:
        self._check_inner_children_have_context(a)
      for k in node.keywords:
        self._check_inner_children_have_context(k.value)
    elif isinstance(node, gast.Dict):
      # We may be able to override these to Load(), but for now it's simpler
      # to just assert that they're set.
      for e in node.keys:
        self._check_inner_children_have_context(e)
      for e in node.values:
        self._check_inner_children_have_context(e)
    elif isinstance(node, gast.Subscript):
      self._set_inner_child_context(node.value, ctx)
      self._check_inner_children_have_context(node.slice)
    elif isinstance(node, (gast.Str, gast.Num)):
      pass
    else:
      raise ValueError('unexpected node type "%s"' % node)

  def visit_Attribute(self, node):
    # A placeholder used as an attribute name is replaced by the id of a
    # Name replacement node; the Attribute node itself is kept.
    node = self.generic_visit(node)
    if node.attr not in self.replacements:
      return node

    repl = self.replacements[node.attr]
    if not isinstance(repl, gast.Name):
      raise ValueError(
          'An attribute can only be replaced by a Name node. Found: %s' % repl)
    node.attr = repl.id
    return node

  def visit_Name(self, node):
    # Core replacement: swap a Name placeholder for its replacement node(s),
    # propagating the placeholder's context (Load/Store/...) onto them.
    if node.id not in self.replacements:
      return node

    new_nodes = self._prepare_replacement(node, node.id)

    # Preserve the target context.
    for n in new_nodes:
      # Note: Tuple/List nodes take the first branch for their elements and
      # then fall through to the else branch below, which sets the ctx of
      # the Tuple/List node itself.
      if isinstance(n, (gast.Tuple, gast.List)):
        for e in n.elts:
          self._set_inner_child_context(e, node.ctx)
      if isinstance(n, gast.Attribute):
        # For attributes, the inner Name node receives the context, while the
        # outer ones have it set to Load.
        self._set_inner_child_context(n, node.ctx)
      else:
        n.ctx = node.ctx

    # A single replacement node is returned unwrapped, matching the arity
    # NodeTransformer expects for expression positions.
    if len(new_nodes) == 1:
      new_nodes, = new_nodes

    return new_nodes
+
+
def _convert_to_ast(n):
  """Normalizes a replacement value of a known data type into AST form.

  Strings become Name nodes, qualified names are expanded via their own
  ast() method, and containers are converted element-wise. Anything else is
  assumed to already be an AST node and passed through.
  """
  if isinstance(n, str):
    # The ctx is deliberately left unset here; ReplaceTransformer.visit_Name
    # fills it in from the placeholder being replaced.
    return gast.Name(id=n, ctx=None, annotation=None)
  if isinstance(n, qual_names.QN):
    return n.ast()
  if isinstance(n, (list, tuple)):
    converted = [_convert_to_ast(item) for item in n]
    return converted if isinstance(n, list) else tuple(converted)
  return n
+
+
def replace(template, **replacements):
  """Replaces placeholders in a Python template.

  AST Name and Tuple nodes always receive the context that is inferred from
  the template. However, when replacing more complex nodes (that can
  potentially contain Name children), the caller is responsible for setting
  the appropriate context.

  Args:
    template: A string representing Python code. Any symbol name that appears
      in the template code can be used as a placeholder.
    **replacements: A mapping from placeholder names to (lists of) AST nodes
      that these placeholders will be replaced by. String values are also
      supported as a shorthand for AST Name nodes with the respective ID.

  Returns:
    An AST node or list of AST nodes with the replacements made. If the
    template was a function, a list will be returned. If the template was a
    node, the same node will be returned. If the template was a string, an
    AST node will be returned (a `Module` node in the case of a multi-line
    string, an `Expr` node otherwise).

  Raises:
    ValueError: if the arguments are incorrect.
  """
  if not isinstance(template, str):
    raise ValueError('Expected string template, got %s' % type(template))
  # dedent allows templates to be written as indented multi-line literals.
  tree = parser.parse_str(textwrap.dedent(template))
  # Normalize shorthand values (strings, QNs, containers) to AST in place.
  # **replacements is a fresh dict, so the caller's mapping is not mutated.
  for k in replacements:
    replacements[k] = _convert_to_ast(replacements[k])
  results = ReplaceTransformer(replacements).visit(tree).body
  if isinstance(results, list):
    return [qual_names.resolve(r) for r in results]
  return qual_names.resolve(results)
+
+
def replace_as_expression(template, **replacements):
  """Variant of replace that generates expressions, instead of code blocks."""
  nodes = replace(template, **replacements)
  if len(nodes) != 1:
    raise ValueError(
        'single expression expected; for more general templates use replace')
  resolved = qual_names.resolve(nodes[0])

  if isinstance(resolved, gast.Expr):
    # Unwrap the expression-statement to expose the underlying expression.
    return resolved.value
  if isinstance(resolved, gast.Name):
    return resolved

  raise ValueError(
      'the template is expected to generate an expression or a name node;'
      ' instead found %s' % resolved)
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
new file mode 100644
index 0000000000..66268cfaad
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -0,0 +1,213 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for templates module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+
+import gast
+
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import templates
+from tensorflow.python.platform import test
+
+
class TemplatesTest(test.TestCase):
  """Tests for templates.replace and templates.replace_as_expression."""

  def test_replace_tuple(self):
    template = """
      def test_fn(a, c):
        return b,
    """

    node = templates.replace(template, b=('a', 'c'))[0]
    result, _ = compiler.ast_to_object(node)

    # assertEqual: assertEquals is a deprecated alias of assertEqual.
    self.assertEqual((2, 3), result.test_fn(2, 3))

  def test_replace_variable(self):
    template = """
      def test_fn(a):
        a += 1
        a = 2 * a + 1
        return b
    """

    node = templates.replace(template, a='b')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_function_name(self):
    template = """
      def fname(a):
        a += 1
        a = 2 * a + 1
        return a
    """

    node = templates.replace(template, fname='test_fn')[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(7, result.test_fn(2))

  def test_replace_code_block(self):
    template = """
      def test_fn(a):
        block
        return a
    """

    node = templates.replace(
        template,
        block=[
            gast.Assign([
                gast.Name('a', None, None)
            ], gast.BinOp(gast.Name('a', None, None), gast.Add(), gast.Num(1))),
        ] * 2)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn(1))

  def test_replace_attribute(self):
    template = """
      def test_fn(a):
        return a.foo
    """

    node = templates.replace(template, foo='b')[0]
    result, _ = compiler.ast_to_object(node)
    mod = imp.new_module('test')
    mod.b = 3
    self.assertEqual(3, result.test_fn(mod))

    with self.assertRaises(ValueError):
      templates.replace(template, foo=1)

  def test_replace_attribute_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(
        template,
        foo=parser.parse_expression('a.b.c'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
    self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)

  def test_replace_list_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_tuple_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
    self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)

  def test_replace_complex_context(self):
    template = """
      def test_fn(foo):
        foo = 0
    """

    node = templates.replace(
        template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
    self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
    function_call_arg = node.body[0].targets[0].value.args[0]
    self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
    self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)

  def test_replace_call_keyword(self):
    template = """
      def test_fn():
        def f(a, d, f):
          return a + d + f
        return f(1, kws=None)
    """

    source = parser.parse_expression('f(d=3, f=5)')
    node = templates.replace(template, kws=source.keywords)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(9, result.test_fn())

    # Each invalid replacement needs its own assertRaises block; statements
    # after the first raise inside a single block would never execute.
    with self.assertRaises(ValueError):
      templates.replace(template, kws=[])
    with self.assertRaises(ValueError):
      templates.replace(template, kws=1)

  def test_replace_name_with_call(self):
    template = """
      def test_fn():
        b = 5
        def g(a):
          return 3 * a
        def f():
          return g
        return foo
    """

    source = parser.parse_expression('f()(b)')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(15, result.test_fn())

  def test_replace_name_with_dict(self):
    template = """
      def test_fn():
        return foo['bar']
    """

    source = parser.parse_expression('{\'bar\': 3}')
    node = templates.replace(template, foo=source)[0]
    result, _ = compiler.ast_to_object(node)
    self.assertEqual(3, result.test_fn())

  def test_replace_as_expression(self):
    # This test previously lacked the `test_` prefix and never ran; it also
    # compared `node is gast.Call` (identity against the class object) and
    # read args off `node.func`. Fixed to exercise replace_as_expression.
    template = """
      foo(a)
    """

    node = templates.replace_as_expression(template, foo='bar', a='baz')
    self.assertIsInstance(node, gast.Call)
    self.assertEqual(node.func.id, 'bar')
    self.assertEqual(node.args[0].id, 'baz')

  def test_replace_as_expression_restrictions(self):
    template = """
      foo(a)
      bar(b)
    """
    # Multi-statement templates are rejected by replace_as_expression.
    with self.assertRaises(ValueError):
      templates.replace_as_expression(template)
+
+
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/pyct/testing/BUILD b/tensorflow/python/autograph/pyct/testing/BUILD
new file mode 100644
index 0000000000..c244cbd747
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/testing/BUILD
@@ -0,0 +1,48 @@
# Build targets for AutoGraph pyct testing helpers (random code generation).

licenses(["notice"])  # Apache 2.0

load("//tensorflow:tensorflow.bzl", "py_test")

# Source-tree glob consumed by TensorFlow's packaging checks.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

# Library exposing codegen.py, the random AST generator used for fuzzing.
py_library(
    name = "testing",
    srcs = [
        "codegen.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/autograph/utils",
        "@gast_archive//:gast",
    ],
)

# Fuzz-style smoke test; expensive, so run manually only (see tags).
py_test(
    name = "codegen_test",
    size = "large",
    srcs = ["codegen_test.py"],
    srcs_version = "PY2AND3",
    tags = [
        "manual",
        "no_windows",
        "nomsan",
        "notap",
    ],
    deps = [
        ":testing",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python/autograph/pyct",
        "@gast_archive//:gast",
    ],
)
diff --git a/tensorflow/python/autograph/pyct/testing/codegen.py b/tensorflow/python/autograph/pyct/testing/codegen.py
new file mode 100644
index 0000000000..78b24390c3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/testing/codegen.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Random code generation for testing/fuzzing."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import string
+
+import gast
+import numpy as np
+
+from tensorflow.python.autograph.pyct import templates
+
+
class NodeSampler(object):
  """Base class for weighted random sampling over a fixed set of choices.

  Subclasses populate sample_map with a {choice: weight} mapping; sample()
  draws one choice with probability proportional to its weight.
  """

  # Overridden by subclasses with a {choice: weight} dict.
  sample_map = None

  def sample(self):
    """Draws one key from sample_map, weighted by its value."""
    candidates = list(self.sample_map.keys())
    weights = np.array(list(self.sample_map.values()), dtype='float32')
    return np.random.choice(candidates, p=weights / np.sum(weights))
+
+
class StatementSampler(NodeSampler):
  """Sampling weights for statement-level AST node types."""

  sample_map = {
      gast.Assign: 10,
      gast.Print: 1,
      gast.If: 2,
      gast.While: 2,
      gast.For: 0,  # disabled: no generate_For implemented
  }
+
+
class ExpressionSampler(NodeSampler):
  """Sampling weights for expression-level AST node types."""

  sample_map = {
      gast.UnaryOp: 1,
      gast.BinOp: 8,
      gast.Name: 1,
      gast.Call: 0,  # disabled: generate_Call is not implemented
  }
+
+
class CompareSampler(NodeSampler):
  """Uniform sampling weights over comparison operators."""

  sample_map = {
      gast.Eq: 1,
      gast.NotEq: 1,
      gast.Lt: 1,
      gast.LtE: 1,
      gast.Gt: 1,
      gast.GtE: 1,
      gast.Is: 1,
      gast.IsNot: 1,
  }
+
+
class BinaryOpSampler(NodeSampler):
  """Uniform sampling weights over arithmetic binary operators."""

  sample_map = {
      gast.Add: 1,
      gast.Sub: 1,
      gast.Mult: 1,
      gast.Div: 1,
      gast.FloorDiv: 1,
      gast.Mod: 1,
      gast.Pow: 1,
  }
+
+
class UnaryOpSampler(NodeSampler):
  """Sampling weights over unary operators; only negation is enabled."""

  sample_map = {gast.USub: 1, gast.UAdd: 0}
+
+
class NameSampler(NodeSampler):
  """Chooses between creating a new name and reusing an existing one."""

  sample_map = {
      'new': 1,
      'existing': 1,
  }
+
+
# Upper bounds on the number of statements sampled per control-flow branch
# and per generated function body, respectively (see CodeGenerator).
N_CONTROLFLOW_STATEMENTS = 10
N_FUNCTIONDEF_STATEMENTS = 10
+
+
class CodeGenerator(object):
  """Generate random syntactically-valid Python ASTs."""

  def __init__(self, max_depth=3, depth=0):
    """Args:
      max_depth: Maximum nesting depth of control-flow statements.
      depth: Initial depth counter; normally left at 0.
    """
    self.max_depth = max_depth
    self.depth = depth

  def generate_statement(self):
    """Generate a statement node, dispatching to the correct class method."""
    desired_node = StatementSampler().sample()
    self.depth += 1

    # Enforce some constraints on generating statements.
    # E.g., if statements need at least 3 readable variables.
    # If we fail to satisfy our constraints, draw another sample.
    if desired_node in (gast.While, gast.For, gast.If):
      if self.depth > self.max_depth:
        # Bug fix: undo the increment before retrying. The early return used
        # to skip the decrement at the bottom of this method, so every
        # rejected control-flow sample permanently inflated self.depth and
        # eventually suppressed nested control flow altogether.
        self.depth -= 1
        return self.generate_statement()

    # Go get the generator method and run it
    method = 'generate_' + desired_node.__name__
    visitor = getattr(self, method)
    node = visitor()
    self.depth -= 1
    return node

  def sample_node_list(self, low, high, generator):
    """Generate a list of statements of random length.

    Args:
      low: Fewest number of statements to generate.
      high: Highest number of statements to generate.
      generator: Function to call to generate nodes.

    Returns:
      A list of statements.
    """
    statements = []
    for _ in range(np.random.randint(low, high)):
      statements.append(generator())
    return statements

  def generate_Name(self, ctx=None):
    """Generate a Name node with a random 5-character identifier.

    Args:
      ctx: Optional gast context node (Load/Store/Param). Defaults to a fresh
        gast.Load() per call. (Previously a single gast.Load() instance was
        shared via a mutable default argument.)
    """
    if ctx is None:
      ctx = gast.Load()
    variable_name = '_' + ''.join(
        random.choice(string.ascii_lowercase) for _ in range(4))
    return gast.Name(variable_name, ctx=ctx, annotation=None)

  def generate_BinOp(self):
    """Generate a binary operation between two fresh Name operands."""
    # TODO(alexbw): convert to generate_expression when we get to limit
    # expression depth.
    op = BinaryOpSampler().sample()()
    return gast.BinOp(self.generate_Name(), op, self.generate_Name())

  def generate_Compare(self):
    """Generate a comparison between two fresh Name operands."""
    op = CompareSampler().sample()()
    return gast.Compare(self.generate_Name(), [op], [self.generate_Name()])

  def generate_UnaryOp(self):
    """Generate a unary operation on a fresh Name operand."""
    operand = self.generate_Name()
    op = UnaryOpSampler().sample()()
    return gast.UnaryOp(op, operand)

  def generate_expression(self):
    """Generate an expression node, dispatching via ExpressionSampler."""
    desired_node = ExpressionSampler().sample()
    # Go get the generator method and run it
    method = 'generate_' + desired_node.__name__
    generator = getattr(self, method)
    return generator()

  def generate_Assign(self):
    """Generate an Assign node."""
    # Generate left-hand side
    target_node = self.generate_Name(gast.Store())
    # Generate right-hand side
    value_node = self.generate_expression()
    # Put it all together
    node = gast.Assign(targets=[target_node], value=value_node)
    return node

  def generate_If(self):
    """Generate an If node."""
    test = self.generate_Compare()

    # Generate true branch statements
    body = self.sample_node_list(
        low=1,
        high=N_CONTROLFLOW_STATEMENTS // 2,
        generator=self.generate_statement)

    # Generate false branch statements
    orelse = self.sample_node_list(
        low=1,
        high=N_CONTROLFLOW_STATEMENTS // 2,
        generator=self.generate_statement)

    node = gast.If(test, body, orelse)
    return node

  def generate_While(self):
    """Generate a While node."""

    test = self.generate_Compare()
    body = self.sample_node_list(
        low=1, high=N_CONTROLFLOW_STATEMENTS, generator=self.generate_statement)
    orelse = []  # not generating else statements

    node = gast.While(test, body, orelse)
    return node

  def generate_Call(self):
    # Call generation is not supported; its sampler weight is 0 so this is
    # never reached through generate_expression.
    raise NotImplementedError

  def generate_Return(self):
    """Generate a Return of a random expression."""
    return gast.Return(self.generate_expression())

  def generate_Print(self):
    """Generate a print(...) statement via the AST templating system."""
    return templates.replace('print(x)', x=self.generate_expression())[0]

  def generate_FunctionDef(self):
    """Generate a FunctionDef node."""

    # Generate the arguments, register them as available
    arg_vars = self.sample_node_list(
        low=2, high=10, generator=lambda: self.generate_Name(gast.Param()))
    args = gast.arguments(arg_vars, None, [], [], None, [])

    # Generate the function body
    body = self.sample_node_list(
        low=1, high=N_FUNCTIONDEF_STATEMENTS, generator=self.generate_statement)
    body.append(self.generate_Return())
    fn_name = self.generate_Name().id
    node = gast.FunctionDef(fn_name, args, body, (), None)
    return node
+
+
def generate_random_functiondef():
  """Convenience wrapper: build one random FunctionDef with default limits."""
  generator = CodeGenerator()
  return generator.generate_FunctionDef()
diff --git a/tensorflow/python/autograph/pyct/testing/codegen_test.py b/tensorflow/python/autograph/pyct/testing/codegen_test.py
new file mode 100644
index 0000000000..71665be039
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/testing/codegen_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for type_info module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct.testing import codegen
+from tensorflow.python.platform import test
+
+
class CodeGenTest(test.TestCase):
  """Smoke tests for the random code generator."""

  def test_codegen_gens(self):
    # Fixed seed keeps the fuzzing deterministic across runs.
    np.random.seed(0)
    for _ in range(1000):
      generated_node = codegen.generate_random_functiondef()
      fn = compiler.ast_to_object(generated_node)
      self.assertIsNotNone(
          fn, 'Generated invalid AST that could not convert to source.')
+
+
# Standard TensorFlow test entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py
new file mode 100644
index 0000000000..520f5038da
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/transformer.py
@@ -0,0 +1,487 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A node transformer that includes utilities for SCT."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import gast
+import six
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import compiler
+from tensorflow.python.autograph.pyct import pretty_printer
+
+
class AutographParseError(SyntaxError):
  """Error raised when source code cannot be parsed for conversion."""
  pass
+
+
+# TODO(mdan): Use namedtuple.
+class EntityInfo(object):
+ """Contains information about a Python entity. Immutable.
+
+ Examples of entities include functions and classes.
+
+ Attributes:
+ source_code: The entity's source code.
+ source_file: The entity's source file.
+ namespace: Dict[str, ], containing symbols visible to the entity
+ (excluding parameters).
+ arg_values: dict[str->*], containing parameter values, if known.
+ arg_types: dict[str->*], containing parameter types, if known.
+ owner_type: The surrounding class type of the function, if present.
+ """
+
+ # TODO(mdan): Remove the default and update tests.
+ def __init__(self, source_code, source_file, namespace, arg_values, arg_types,
+ owner_type):
+ self.source_code = source_code
+ self.source_file = source_file
+ self.namespace = namespace
+ self.arg_values = {} if arg_values is None else arg_values
+ self.arg_types = {} if arg_types is None else arg_types
+ self.owner_type = owner_type
+
+
class _StateStack(object):
  """Typed stack abstraction.

  This class provides syntactic sugar for a stack of objects of known
  type. It allows accessing attributes of the object at the top of the stack
  directly against this object, which allows for very terse syntax.

  For example, this code:

    stack = _StateStack(Foo)
    stack.enter()
    stack.bar

  Is equivalent to:

    stack = []
    stack.append(Foo())
    foo = stack[-1]
    foo.bar

  See _State for more on how this is used.

  Attributes:
    type: Any, the type of objects that this stack holds
    level: int, the current stack depth
    value: Any, the instance of the object at the top of the stack
  """

  def __init__(self, type_):
    # Because we override __setattr__, we need to attach these attributes using
    # the superclass' setattr.
    object.__setattr__(self, 'type', type_)
    object.__setattr__(self, '_stack', [])
    # Push one default-constructed instance so `value` and delegated
    # attribute access are valid immediately after construction.
    self.enter()

  def enter(self):
    # Push a fresh, default-constructed instance of the tracked type.
    self._stack.append(self.type())

  def exit(self):
    # Pop and return the instance at the top of the stack.
    return self._stack.pop()

  @property
  def level(self):
    # Stack depth; starts at 1 because __init__ calls enter().
    return len(self._stack)

  @property
  def value(self):
    # The instance at the top of the stack.
    return self._stack[-1]

  def __getattr__(self, key):
    # Delegate attribute reads to the top-of-stack instance. Only called for
    # names not found on _StateStack itself (type, _stack, methods).
    return getattr(self._stack[-1], key)

  def __setattr__(self, key, value):
    # Delegate ALL attribute writes to the top-of-stack instance; own fields
    # were attached via object.__setattr__ in __init__.
    setattr(self._stack[-1], key, value)
+
+
class _State(object):
  """Supporting class for nested scope variable space for converter.Base.

  This structure offers syntactic sugar over a dict of stacks of objects
  of known type. These structures are useful to keep state during AST walks.
  Multiple different scopes can be tracked in parallel. For example:

    s = _State()

    s[foo].enter()
    s[bar].enter()  # this will not affect s[foo]

  Element access has special semantics:
    * keys are a data type
    * element values are _StateStack(type=key) objects
    * missing elements are automatically added, similarly to defaultdict

  For example, the following block :

    _State s
    s[Foo]

  Is equivalent to:

    s = {}
    if Foo not in s:
      s[Foo] = Foo()
    s[Foo]

  See Base for how it's used.
  """

  def __init__(self):
    self._value = {}

  def __getitem__(self, key):
    # Lazily create the per-type stack on first access, defaultdict-style.
    try:
      return self._value[key]
    except KeyError:
      stack = _StateStack(key)
      self._value[key] = stack
      return stack
+
+
+class Base(gast.NodeTransformer):
+ """Base class for general-purpose code transformers transformers.
+
+ This is an extension of ast.NodeTransformer that provides a few additional
+ functions, like state tracking within the scope of arbitrary node, helpers
+ for processing code blocks, debugging, mapping of transformed code to
+ original code, and others.
+
+ Scope-local state tracking: to keep state across nodes, at the level of
+ (possibly nested) scopes, use enter/exit_local_scope and set/get_local.
+ You must call enter/exit_local_scope manually, but the transformer detects
+ when they are not properly paired.
+
+ The transformer allows keeping state across calls to visit_* that is local to
+ arbitrary nodes and their descendants, using the self.state attribute.
+ Multiple independent scopes are allowed and automatically constructed.
+
+ For example, to keep track of the If node that encloses any Name node, one can
+ write:
+
+ class FooType(object):
+
+ def __init__(self):
+ self.foo_property = None
+
+ class DummyTransformer(Base):
+
+ def visit_If(self, node):
+ self.state[FooType].enter()
+ self.state[FooType].foo_property = node
+
+ def visit_Name(self, node):
+ self.state[FooType].foo_property # will hold the innermost enclosing if
+ """
+
+ # TODO(mdan): Document all extra features.
+
+ def __init__(self, entity_info):
+ """Initialize the transformer. Subclasses should call this.
+
+ Args:
+ entity_info: An EntityInfo object.
+ """
+ self._lineno = 0
+ self._col_offset = 0
+ self.entity_info = entity_info
+ self._enclosing_entities = []
+
+ # A stack that allows keeping mutable, scope-local state where scopes may be
+ # nested. For example, it can be used to track the usage of break
+ # statements in each loop, where loops may be nested.
+ self._local_scope_state = []
+ self.enter_local_scope()
+
+ # Allows scoping of local variables to keep state across calls to visit_*
+ # methods. Multiple scope hierchies may exist and are keyed by tag. A scope
+ # is valid at one or more nodes and all its children. Scopes created in
+ # child nodes supersede their parent. Scopes are isolated from one another.
+ self.state = _State()
+
+ @property
+ def enclosing_entities(self):
+ return tuple(self._enclosing_entities)
+
+ @property
+ def local_scope_level(self):
+ return len(self._local_scope_state)
+
+ def enter_local_scope(self, inherit=None):
+ """Deprecated. Use self.state instead.
+
+ Marks entry into a new local scope.
+
+ Args:
+ inherit: Optional enumerable of variable names to copy from the
+ parent scope.
+ """
+ scope_entered = {}
+ if inherit:
+ this_scope = self._local_scope_state[-1]
+ for name in inherit:
+ if name in this_scope:
+ scope_entered[name] = this_scope[name]
+ self._local_scope_state.append(scope_entered)
+
+ def exit_local_scope(self, keep=None):
+ """Deprecated. Use self.state instead.
+
+ Marks exit from the current local scope.
+
+ Args:
+ keep: Optional enumerable of variable names to copy into the
+ parent scope.
+ Returns:
+ A dict containing the scope that has just been exited.
+ """
+ scope_left = self._local_scope_state.pop()
+ if keep:
+ this_scope = self._local_scope_state[-1]
+ for name in keep:
+ if name in scope_left:
+ this_scope[name] = scope_left[name]
+ return scope_left
+
  def set_local(self, name, value):
    """Deprecated. Use self.state instead.

    Binds a value to a name in the innermost local scope.
    """
    self._local_scope_state[-1][name] = value
+
  def get_local(self, name, default=None):
    """Deprecated. Use self.state instead.

    Looks a name up in the innermost local scope only; parent scopes are not
    searched.
    """
    return self._local_scope_state[-1].get(name, default)
+
  def debug_print(self, node):
    """Helper method useful for debugging.

    Pretty-prints the node (only when not running under -O) and returns it
    unchanged, so the call can be inserted inline in a visit_* method.
    """
    if __debug__:
      print(pretty_printer.fmt(node))
    return node
+
+ def visit_block(self, nodes, before_visit=None, after_visit=None):
+ """A more powerful version of generic_visit for statement blocks.
+
+ An example of a block is the body of an if statement.
+
+ This function allows specifying a postprocessing callback (the
+ after_visit argument) argument which can be used to move nodes to a new
+ destination. This is done by after_visit by returning a non-null
+ second return value, e.g. return new_node, new_destination.
+
+ For example, a transformer could perform the following move:
+
+ foo()
+ bar()
+ baz()
+
+ foo()
+ if cond:
+ bar()
+ baz()
+
+ The above could be done with a postprocessor of this kind:
+
+ def after_visit(node):
+ if node_is_function_call(bar):
+ new_container_node = build_cond()
+ new_container_node.body.append(node)
+ return new_container_node, new_container_node.body
+ else:
+ # Once we set a new destination, all subsequent items will be
+ # moved to it, so we don't need to explicitly handle baz.
+ return node, None
+
+ Args:
+ nodes: enumerable of AST node objects
+ before_visit: optional callable that is called before visiting each item
+ in nodes
+ after_visit: optional callable that takes in an AST node and
+ returns a tuple (new_node, new_destination). It is called after
+ visiting each item in nodes. Is used in the same was as the
+ visit_* methods: new_node will replace the node; if not None,
+ new_destination must be a list, and subsequent nodes will be placed
+ in this list instead of the list returned by visit_block.
+ Returns:
+ A list of AST node objects containing the transformed items fron nodes,
+ except those nodes that have been relocated using after_visit.
+ """
+ results = []
+ node_destination = results
+ for node in nodes:
+ if before_visit:
+ # TODO(mdan): We can modify node here too, if ever needed.
+ before_visit()
+
+ replacement = self.visit(node)
+
+ if after_visit and replacement:
+ replacement, new_destination = after_visit(replacement)
+ else:
+ new_destination = None
+
+ if replacement:
+ if isinstance(replacement, (list, tuple)):
+ node_destination.extend(replacement)
+ else:
+ node_destination.append(replacement)
+
+ # Allow the postprocessor to reroute the remaining nodes to a new list.
+ if new_destination is not None:
+ node_destination = new_destination
+ return results
+
+ # TODO(mdan): Remove.
+ def apply_to_single_assignments(self, targets, values, apply_fn):
+ """Applies a function to each individual assignment.
+
+ This function can process a possibly-unpacked (e.g. a, b = c, d) assignment.
+ It tries to break down the unpacking if possible. In effect, it has the same
+ effect as passing the assigned values in SSA form to apply_fn.
+
+ Examples:
+
+ The following will result in apply_fn(a, c), apply_fn(b, d):
+
+ a, b = c, d
+
+ The following will result in apply_fn(a, c[0]), apply_fn(b, c[1]):
+
+ a, b = c
+
+ The following will result in apply_fn(a, (b, c)):
+
+ a = b, c
+
+ It uses the visitor pattern to allow subclasses to process single
+ assignments individually.
+
+ Args:
+ targets: list, tuple of or individual AST node. Should be used with the
+ targets field of an ast.Assign node.
+ values: an AST node.
+ apply_fn: a function of a single argument, which will be called with the
+ respective nodes of each single assignment. The signature is
+ apply_fn(target, value), no return value.
+ """
+ if not isinstance(targets, (list, tuple)):
+ targets = (targets,)
+ for target in targets:
+ if isinstance(target, (gast.Tuple, gast.List)):
+ for i in range(len(target.elts)):
+ target_el = target.elts[i]
+ if isinstance(values, (gast.Tuple, gast.List)):
+ value_el = values.elts[i]
+ else:
+ value_el = gast.Subscript(values, gast.Index(i), ctx=gast.Store())
+ self.apply_to_single_assignments(target_el, value_el, apply_fn)
+ else:
+ # TODO(mdan): Look into allowing to rewrite the AST here.
+ apply_fn(target, values)
+
  def _get_source(self, node):
    """Renders node back to source code, for use in error messages."""
    try:
      source, _ = compiler.ast_to_source(node)
      return source
    # pylint: disable=broad-except
    # This function is used for error reporting. If an exception occurs here,
    # it should be suppressed, in favor of emitting as informative a message
    # about the original error as possible.
    except Exception:
      return '<could not convert AST to source>'
+
+ def visit(self, node):
+ if not isinstance(node, gast.AST):
+ # This is not that uncommon a mistake: various node bodies are lists, for
+ # example, posing a land mine for transformers that need to recursively
+ # call `visit`. The error needs to be raised before the exception handler
+ # below is installed, because said handler will mess up if `node` is not,
+ # in fact, a node.
+ msg = (
+ 'invalid value for "node": expected "ast.AST", got "{}"; to'
+ ' visit lists of nodes, use "visit_block" instead').format(type(node))
+ raise ValueError(msg)
+
+ source_code = self.entity_info.source_code
+ source_file = self.entity_info.source_file
+ did_enter_function = False
+ local_scope_size_at_entry = len(self._local_scope_state)
+ processing_expr_node = False
+
+ try:
+ if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
+ did_enter_function = True
+ elif isinstance(node, gast.Expr):
+ processing_expr_node = True
+
+ if did_enter_function:
+ self._enclosing_entities.append(node)
+
+ if source_code and hasattr(node, 'lineno'):
+ self._lineno = node.lineno
+ self._col_offset = node.col_offset
+
+ if processing_expr_node:
+ entry_expr_value = node.value
+
+ if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
+ result = super(Base, self).visit(node)
+
+ # Adjust for consistency: replacing the value of an Expr with
+ # an Assign node removes the need for the Expr node.
+ if processing_expr_node:
+ if isinstance(result, gast.Expr) and result.value != entry_expr_value:
+ # When the replacement is a list, it is assumed that the list came
+ # from a template that contained a number of statements, which
+ # themselves are standalone and don't require an enclosing Expr.
+ if isinstance(result.value,
+ (list, tuple, gast.Assign, gast.AugAssign)):
+ result = result.value
+
+ # On exception, the local scope integrity is not guaranteed.
+ if did_enter_function:
+ self._enclosing_entities.pop()
+
+ if local_scope_size_at_entry != len(self._local_scope_state):
+ raise AssertionError(
+ 'Inconsistent local scope stack. Before entering node %s, the'
+ ' stack had length %d, after exit it has length %d. This'
+ ' indicates enter_local_scope and exit_local_scope are not'
+ ' well paired.' % (
+ node,
+ local_scope_size_at_entry,
+ len(self._local_scope_state)
+ ))
+ return result
+
+ except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
+ msg = '%s: %s\nOffending source:\n%s\n\nOccurred at node:\n%s' % (
+ e.__class__.__name__, str(e), self._get_source(node),
+ pretty_printer.fmt(node, color=False))
+ if source_code:
+ line = source_code.splitlines()[self._lineno - 1]
+ else:
+ line = '<no source available>'
+ # TODO(mdan): Avoid the printing of the original exception.
+ # In other words, we need to find how to suppress the "During handling
+ # of the above exception, another exception occurred" message.
+ six.reraise(AutographParseError,
+ AutographParseError(
+ msg,
+ (source_file, self._lineno, self._col_offset + 1, line)),
+ sys.exc_info()[2])
diff --git a/tensorflow/python/autograph/pyct/transformer_test.py b/tensorflow/python/autograph/pyct/transformer_test.py
new file mode 100644
index 0000000000..23bf9a8e16
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/transformer_test.py
@@ -0,0 +1,369 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for transformer module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import transformer
+from tensorflow.python.platform import test
+
+
class TransformerTest(test.TestCase):
  """Tests for transformer.Base scope tracking, state and error reporting."""

  def _simple_source_info(self):
    # Minimal EntityInfo: no source attached, which exercises the
    # '<no source available>' path in error reporting.
    return transformer.EntityInfo(
        source_code=None,
        source_file=None,
        namespace=None,
        arg_values=None,
        arg_types=None,
        owner_type=None)

  def test_entity_scope_tracking(self):

    class TestTransformer(transformer.Base):

      # The choice of node to assign to is arbitrary. Using Assign because
      # it's easy to find in the tree.
      def visit_Assign(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

      # This will show up in the lambda function.
      def visit_BinOp(self, node):
        anno.setanno(node, 'enclosing_entities', self.enclosing_entities)
        return self.generic_visit(node)

    tr = TestTransformer(self._simple_source_info())

    def test_function():
      a = 0

      class TestClass(object):

        def test_method(self):
          b = 0
          def inner_function(x):
            c = 0
            d = lambda y: (x + y)
            return c, d
          return b, inner_function
      return a, TestClass

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    # Navigate to each statement of interest by position in the parsed tree.
    test_function_node = node.body[0]
    test_class = test_function_node.body[1]
    test_method = test_class.body[0]
    inner_function = test_method.body[1]
    lambda_node = inner_function.body[1].value

    a = test_function_node.body[0]
    b = test_method.body[0]
    c = inner_function.body[0]
    lambda_expr = lambda_node.body

    self.assertEqual(
        (test_function_node,), anno.getanno(a, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method),
                     anno.getanno(b, 'enclosing_entities'))
    self.assertEqual(
        (test_function_node, test_class, test_method, inner_function),
        anno.getanno(c, 'enclosing_entities'))
    self.assertEqual((test_function_node, test_class, test_method,
                      inner_function, lambda_node),
                     anno.getanno(lambda_expr, 'enclosing_entities'))

  def assertSameAnno(self, first, second, key):
    # Asserts that two nodes carry the very same annotation object.
    self.assertIs(anno.getanno(first, key), anno.getanno(second, key))

  def assertDifferentAnno(self, first, second, key):
    # Asserts that two nodes carry distinct annotation objects.
    self.assertIsNot(anno.getanno(first, key), anno.getanno(second, key))

  def test_state_tracking(self):

    class LoopState(object):
      pass

    class CondState(object):
      pass

    class TestTransformer(transformer.Base):

      def visit(self, node):
        anno.setanno(node, 'loop_state', self.state[LoopState].value)
        anno.setanno(node, 'cond_state', self.state[CondState].value)
        return super(TestTransformer, self).visit(node)

      def visit_While(self, node):
        self.state[LoopState].enter()
        node = self.generic_visit(node)
        self.state[LoopState].exit()
        return node

      def visit_If(self, node):
        self.state[CondState].enter()
        node = self.generic_visit(node)
        self.state[CondState].exit()
        return node

    tr = TestTransformer(self._simple_source_info())

    def test_function(a):
      a = 1
      while a:
        _ = 'a'
        if a > 2:
          _ = 'b'
          while True:
            raise '1'
        if a > 3:
          _ = 'c'
          while True:
            raise '1'

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    # Statements in the same scope share the annotation object; statements in
    # different scopes of the same hierarchy must not.
    fn_body = node.body[0].body
    outer_while_body = fn_body[1].body
    self.assertSameAnno(fn_body[0], outer_while_body[0], 'cond_state')
    self.assertDifferentAnno(fn_body[0], outer_while_body[0], 'loop_state')

    first_if_body = outer_while_body[1].body
    self.assertDifferentAnno(outer_while_body[0], first_if_body[0],
                             'cond_state')
    self.assertSameAnno(outer_while_body[0], first_if_body[0], 'loop_state')

    first_inner_while_body = first_if_body[1].body
    self.assertSameAnno(first_if_body[0], first_inner_while_body[0],
                        'cond_state')
    self.assertDifferentAnno(first_if_body[0], first_inner_while_body[0],
                             'loop_state')

    second_if_body = outer_while_body[2].body
    self.assertDifferentAnno(first_if_body[0], second_if_body[0], 'cond_state')
    self.assertSameAnno(first_if_body[0], second_if_body[0], 'loop_state')

    second_inner_while_body = second_if_body[1].body
    self.assertDifferentAnno(first_inner_while_body[0],
                             second_inner_while_body[0], 'cond_state')
    self.assertDifferentAnno(first_inner_while_body[0],
                             second_inner_while_body[0], 'loop_state')

  def test_local_scope_info_stack(self):

    class TestTransformer(transformer.Base):

      # Extract all string constants from the block.
      def visit_Str(self, node):
        self.set_local('string', self.get_local('string', default='') + node.s)
        return self.generic_visit(node)

      def _annotate_result(self, node):
        self.enter_local_scope()
        node = self.generic_visit(node)
        anno.setanno(node, 'test', self.get_local('string'))
        self.exit_local_scope()
        return node

      def visit_While(self, node):
        return self._annotate_result(node)

      def visit_For(self, node):
        return self._annotate_result(node)

    tr = TestTransformer(self._simple_source_info())

    def test_function(a):
      """Docstring."""
      assert a == 'This should not be counted'
      for i in range(3):
        _ = 'a'
        if i > 2:
          return 'b'
        else:
          _ = 'c'
          while True:
            raise '1'
      return 'nor this'

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)

    for_node = node.body[0].body[2]
    while_node = for_node.body[1].orelse[1]

    # Only strings inside the annotated loop's scope should be collected.
    self.assertFalse(anno.hasanno(for_node, 'string'))
    self.assertEqual('abc', anno.getanno(for_node, 'test'))
    self.assertFalse(anno.hasanno(while_node, 'string'))
    self.assertEqual('1', anno.getanno(while_node, 'test'))

  def test_local_scope_info_stack_checks_integrity(self):

    class TestTransformer(transformer.Base):

      # Deliberately unbalanced: enters a scope that is never exited.
      def visit_If(self, node):
        self.enter_local_scope()
        return self.generic_visit(node)

      # Deliberately unbalanced: exits a scope that was never entered.
      def visit_For(self, node):
        node = self.generic_visit(node)
        self.exit_local_scope()
        return node

    tr = TestTransformer(self._simple_source_info())

    def no_exit(a):
      if a > 0:
        print(a)
      return None

    node, _ = parser.parse_entity(no_exit)
    with self.assertRaises(AssertionError):
      tr.visit(node)

    def no_entry(a):
      for _ in a:
        print(a)

    node, _ = parser.parse_entity(no_entry)
    with self.assertRaises(AssertionError):
      tr.visit(node)

  def test_visit_block_postprocessing(self):

    class TestTransformer(transformer.Base):

      # Moves assignments of 'y' into a new If node; statements visited after
      # the move follow them into the If body.
      def _process_body_item(self, node):
        if isinstance(node, gast.Assign) and (node.value.id == 'y'):
          if_node = gast.If(gast.Name('x', gast.Load(), None), [node], [])
          return if_node, if_node.body
        return node, None

      def visit_FunctionDef(self, node):
        node.body = self.visit_block(
            node.body, after_visit=self._process_body_item)
        return node

    def test_function(x, y):
      z = x
      z = y
      return z

    tr = TestTransformer(self._simple_source_info())

    node, _ = parser.parse_entity(test_function)
    node = tr.visit(node)
    node = node.body[0]

    self.assertEqual(len(node.body), 2)
    self.assertTrue(isinstance(node.body[0], gast.Assign))
    self.assertTrue(isinstance(node.body[1], gast.If))
    self.assertTrue(isinstance(node.body[1].body[0], gast.Assign))
    self.assertTrue(isinstance(node.body[1].body[1], gast.Return))

  def test_robust_error_on_list_visit(self):

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        # This is broken because visit expects a single node, not a list, and
        # the body of an if is a list.
        # Importantly, the default error handling in visit also expects a single
        # node. Therefore, mistakes like this need to trigger a type error
        # before the visit called here installs its error handler.
        # That type error can then be caught by the enclosing call to visit,
        # and correctly blame the If node.
        self.visit(node.body)
        return node

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_source_info())

    node, _ = parser.parse_entity(test_function)
    with self.assertRaises(transformer.AutographParseError) as cm:
      node = tr.visit(node)
    obtained_message = str(cm.exception)
    expected_message = r'expected "ast.AST", got "\<(type|class) \'list\'\>"'
    self.assertRegexpMatches(obtained_message, expected_message)
    # The exception should point at the if statement, not any place else. Could
    # also check the stack trace.
    self.assertTrue(
        'Occurred at node:\nIf' in obtained_message, obtained_message)
    self.assertTrue(
        'Occurred at node:\nFunctionDef' not in obtained_message,
        obtained_message)
    self.assertTrue(
        'Occurred at node:\nReturn' not in obtained_message, obtained_message)

  def test_robust_error_on_ast_corruption(self):
    # A child class should not be able to be so broken that it causes the error
    # handling in `transformer.Base` to raise an exception. Why not? Because
    # then the original error location is dropped, and an error handler higher
    # up in the call stack gives misleading information.

    # Here we test that the error handling in `visit` completes, and blames the
    # correct original exception, even if the AST gets corrupted.

    class NotANode(object):
      pass

    class BrokenTransformer(transformer.Base):

      def visit_If(self, node):
        node.body = NotANode()
        raise ValueError('I blew up')

    def test_function(x):
      if x > 0:
        return x

    tr = BrokenTransformer(self._simple_source_info())

    node, _ = parser.parse_entity(test_function)
    with self.assertRaises(transformer.AutographParseError) as cm:
      node = tr.visit(node)
    obtained_message = str(cm.exception)
    # The message should reference the exception actually raised, not anything
    # from the exception handler.
    expected_substring = 'I blew up'
    self.assertTrue(expected_substring in obtained_message, obtained_message)
    # Expect the exception to have failed to parse the corrupted AST
    self.assertTrue(
        '<could not convert AST to source>' in obtained_message,
        obtained_message)
    # The exception should point at the if statement, not any place else. Could
    # also check the stack trace.
    self.assertTrue(
        'Occurred at node:\nIf' in obtained_message, obtained_message)
    self.assertTrue(
        'Occurred at node:\nFunctionDef' not in obtained_message,
        obtained_message)
    self.assertTrue(
        'Occurred at node:\nReturn' not in obtained_message, obtained_message)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/utils/BUILD b/tensorflow/python/autograph/utils/BUILD
new file mode 100644
index 0000000000..22451d4f3f
--- /dev/null
+++ b/tensorflow/python/autograph/utils/BUILD
@@ -0,0 +1,114 @@
licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "py_test")

# Catch-all filegroup, consumed by the parent package's build rules.
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

# Runtime utilities usable from autograph-generated code.
py_library(
    name = "utils",
    srcs = [
        "__init__.py",
        "context_managers.py",
        "misc.py",
        "multiple_dispatch.py",
        "py_func.py",
        "tensor_list.py",
        "tensors.py",
        "testing.py",
        "type_check.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        "//tensorflow/python:dtypes",
        "//tensorflow/python:list_ops",
        "//tensorflow/python:script_ops",
        "//tensorflow/python/autograph/pyct",
        "//tensorflow/python/data/ops:dataset_ops",
        "@six_archive//:six",
    ],
)

# One py_test target per source file in :utils.
py_test(
    name = "context_managers_test",
    srcs = ["context_managers_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "misc_test",
    srcs = ["misc_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "multiple_dispatch_test",
    srcs = ["multiple_dispatch_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "py_func_test",
    srcs = ["py_func_test.py"],
    srcs_version = "PY2AND3",
    tags = ["no_windows"],
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "type_check_test",
    srcs = ["type_check_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
    ],
)

py_test(
    name = "tensor_list_test",
    srcs = ["tensor_list_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:list_ops",
    ],
)

py_test(
    name = "tensors_test",
    srcs = ["tensors_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":utils",
        "//tensorflow/python:client_testlib",
    ],
)
diff --git a/tensorflow/python/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py
new file mode 100644
index 0000000000..e38c82a079
--- /dev/null
+++ b/tensorflow/python/autograph/utils/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility module that contains APIs usable in the generated code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
+from tensorflow.python.autograph.utils.misc import alias_tensors
+from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is
+from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is_not
+from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
+from tensorflow.python.autograph.utils.py_func import wrap_py_func
+from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
+from tensorflow.python.autograph.utils.testing import fake_tf
+from tensorflow.python.autograph.utils.type_check import is_tensor
diff --git a/tensorflow/python/autograph/utils/context_managers.py b/tensorflow/python/autograph/utils/context_managers.py
new file mode 100644
index 0000000000..3d150a9581
--- /dev/null
+++ b/tensorflow/python/autograph/utils/context_managers.py
@@ -0,0 +1,49 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Various context managers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import tensor_array_ops
+
+
def control_dependency_on_returns(return_value):
  """Create a TF control dependency on the return values of a function.

  If the function had no return value, a no-op context is returned.

  Args:
    return_value: The return value to set as control dependency.

  Returns:
    A context manager.
  """
  if return_value is None:
    # Nothing to depend on: produce a do-nothing context manager.
    return contextlib.contextmanager(lambda: (yield))()

  def as_handle(t):
    # A TensorArray is represented by its flow tensor in the dependency.
    if isinstance(t, tensor_array_ops.TensorArray):
      return t.flow
    return t

  # TODO(mdan): Filter to tensor objects.
  if not isinstance(return_value, (list, tuple)):
    return_value = (return_value,)
  return ops.control_dependencies(tuple(as_handle(t) for t in return_value))
diff --git a/tensorflow/python/autograph/utils/context_managers_test.py b/tensorflow/python/autograph/utils/context_managers_test.py
new file mode 100644
index 0000000000..7f0a15b076
--- /dev/null
+++ b/tensorflow/python/autograph/utils/context_managers_test.py
@@ -0,0 +1,47 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for context_managers module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils import context_managers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
class ContextManagersTest(test.TestCase):
  """Tests for the context_managers module."""

  def test_control_dependency_on_returns(self):
    # Smoke test only: each call must yield a usable context manager for
    # None, a single tensor, a TensorArray and a list of tensors.
    with context_managers.control_dependency_on_returns(None):
      pass
    with context_managers.control_dependency_on_returns(
        constant_op.constant(1)):
      pass
    with context_managers.control_dependency_on_returns(
        tensor_array_ops.TensorArray(dtypes.int32, size=1)):
      pass
    with context_managers.control_dependency_on_returns(
        [constant_op.constant(1),
         constant_op.constant(2)]):
      pass
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/utils/misc.py b/tensorflow/python/autograph/utils/misc.py
new file mode 100644
index 0000000000..1b06caf0bd
--- /dev/null
+++ b/tensorflow/python/autograph/utils/misc.py
@@ -0,0 +1,50 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Miscellaneous utilities that don't fit anywhere else."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+
+
def alias_tensors(*args):
  """Wrap any Tensor arguments with an identity op.

  Any other argument, including Variables, is returned unchanged.

  Args:
    *args: Any arguments. Must contain at least one element.

  Returns:
    Same as *args, with Tensor instances replaced as described.

  Raises:
    ValueError: If args doesn't meet the requirements.
  """
  if not args:
    raise ValueError('at least one argument required')

  def _maybe_alias(x):
    if isinstance(x, ops.Tensor):
      return array_ops.identity(x)
    return x

  # TODO(mdan): Recurse into containers?
  # TODO(mdan): Anything we can do about variables? Fake a scope reuse?
  if len(args) == 1:
    return _maybe_alias(args[0])
  return (_maybe_alias(x) for x in args)
diff --git a/tensorflow/python/autograph/utils/misc_test.py b/tensorflow/python/autograph/utils/misc_test.py
new file mode 100644
index 0000000000..8d2b0d6e13
--- /dev/null
+++ b/tensorflow/python/autograph/utils/misc_test.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for misc module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils.misc import alias_tensors
+from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.ops.variables import Variable
+from tensorflow.python.platform import test
+
+
class MiscTest(test.TestCase):
  """Tests for the misc module."""

  def test_alias_single_tensor(self):
    a = constant(1)

    new_a = alias_tensors(a)
    # The alias must be a distinct (identity) tensor with the same value.
    self.assertFalse(new_a is a)
    with self.cached_session() as sess:
      self.assertEqual(1, sess.run(new_a))

  def test_alias_tensors(self):
    a = constant(1)
    v = Variable(2)
    s = 'a'
    l = [1, 2, 3]

    new_a, new_v, new_s, new_l = alias_tensors(a, v, s, l)

    # Only plain Tensors are aliased; variables and non-tensor values pass
    # through unchanged.
    self.assertFalse(new_a is a)
    self.assertTrue(new_v is v)
    self.assertTrue(new_s is s)
    self.assertTrue(new_l is l)
    with self.cached_session() as sess:
      self.assertEqual(1, sess.run(new_a))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py
new file mode 100644
index 0000000000..33f521db2c
--- /dev/null
+++ b/tensorflow/python/autograph/utils/multiple_dispatch.py
@@ -0,0 +1,66 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for type-dependent behavior used in autograph-generated code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils.type_check import is_tensor
+from tensorflow.python.ops import control_flow_ops
+
+
def dynamic_is(left, right):
  """Functional form of Python's `is` identity test."""
  # TODO(alexbw) if we're sure we should leave 'is' in place,
  # then change the semantics in converters/logical_expressions.py
  same_object = left is right
  return same_object
+
+
def dynamic_is_not(left, right):
  """Functional form of Python's `is not` identity test."""
  return not (left is right)
+
+
def run_cond(condition, true_fn, false_fn):
  """Type-dependent functional conditional.

  Args:
    condition: A Tensor or Python bool.
    true_fn: A Python callable implementing the true branch of the conditional.
    false_fn: A Python callable implementing the false branch of the
      conditional.

  Returns:
    result: The result of calling the appropriate branch. If condition is a
      Tensor, tf.cond will be used. Otherwise, a standard Python if statement
      will be run.
  """
  # Plain Python conditions are resolved eagerly; only tensor conditions
  # build a graph-mode tf.cond op.
  if not is_tensor(condition):
    return py_cond(condition, true_fn, false_fn)
  return control_flow_ops.cond(condition, true_fn, false_fn)
+
+
def py_cond(condition, true_fn, false_fn):
  """Functional version of Python's conditional."""
  branch = true_fn if condition else false_fn
  results = branch()

  # The contract for the branch functions is to return tuples, but they should
  # be collapsed to a single element when there is only one output.
  return results[0] if len(results) == 1 else results
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
new file mode 100644
index 0000000000..ed20822529
--- /dev/null
+++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
@@ -0,0 +1,75 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for multiple_dispatch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.autograph.utils import multiple_dispatch
+from tensorflow.python.client.session import Session
+from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.platform import test
+
+
class MultipleDispatchTest(test.TestCase):
  """Exercises dynamic_is/dynamic_is_not/run_cond on Python and TF inputs."""

  def test_dynamic_is_python(self):
    """Identity checks on plain numpy values (no graph involved)."""
    a = np.eye(3)
    also_a = a
    not_actually_a = np.eye(3)  # equal in value, but a distinct object
    should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
    should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
    should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
    should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
    self.assertTrue(should_be_true1)
    self.assertTrue(should_be_true2)
    self.assertFalse(should_be_false1)
    self.assertFalse(should_be_false2)

  def test_dynamic_is_tf(self):
    """Identity (not value equality) semantics also hold for Tensor objects."""
    with Session().as_default():
      a = constant([2.0])
      also_a = a
      not_actually_a = constant([2.0])  # equal in value, but a distinct op
      should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
      should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
      should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
      should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
      self.assertTrue(should_be_true1)
      self.assertTrue(should_be_true2)
      self.assertFalse(should_be_false1)
      self.assertFalse(should_be_false2)

  def test_run_cond_python(self):
    """Python bool condition runs eagerly; singleton tuples are unwrapped."""
    true_fn = lambda: (2,)
    false_fn = lambda: (3,)
    self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2)
    self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3)

  def test_run_cond_tf(self):
    """Tensor condition dispatches to tf.cond; result evaluated in a session."""
    true_fn = lambda: (constant(2),)
    false_fn = lambda: (constant(3),)
    with Session() as sess:
      out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn)
      self.assertEqual(sess.run(out), 2)
      out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
      self.assertEqual(sess.run(out), 3)
+
+
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/utils/py_func.py b/tensorflow/python/autograph/utils/py_func.py
new file mode 100644
index 0000000000..11ebfb2e49
--- /dev/null
+++ b/tensorflow/python/autograph/utils/py_func.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Pyfunc creation utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import namedtuple
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import script_ops
+
+
class MatchDType(namedtuple('MatchDType', ('arg_number',))):
  """Allows matching the dtype of an argument.

  Used in conjunction with function calls. For example, MatchDType(0) will
  match the DType of the first argument.
  """
+
+
def wrap_py_func(f, return_dtypes, args, kwargs=None, use_dummy_return=False):
  """Helper that wraps a callable to py_func.

  The helper passes tensor arguments through the py_func interface. Non-tensor
  arguments are allowed, and will be passed to f directly. Note that non-tensor
  arguments that are captured by f will not update every time the wrapper is
  called (this is consistent with its argument list, which only includes
  the tensor arguments). In general, it's safest not to reuse this wrapper.

  Args:
    f: Callable
    return_dtypes: None, individual DType, or tuple/list of DType or
      MatchDType, the data type for each of f's return value(s). Set to None
      if f has no return values or use_dummy_return is True. Use MatchDType to
      define a dtype identical to that of `i`th argument (argument 0 is the
      first); an argument must be of Tensor type if it is to be used with
      MatchDType.
    args: Positional arguments for f, as list or tuple.
    kwargs: Keyword arguments for f, as dict with string keys. May be None.
    use_dummy_return: If True, the function will return a dummy value of 1
      and discard its actual return value.
  Returns:
    The return values of f converted to tensor.
  Raises:
    ValueError: if any of the arguments are incorrect.
  """

  if return_dtypes and use_dummy_return:
    raise ValueError('if use_dummy_return is True, return_dtypes must be empty')

  tensor_args = []
  tensor_args_idx = {}

  # Of the positional arguments, only grab the tensor ones to be passed through
  # the py_func.
  n_args = len(args)
  arg_is_tensor = tuple(map(tensor_util.is_tensor, args))
  for i in range(n_args):
    if arg_is_tensor[i]:
      tensor_args_idx[i] = len(tensor_args)
      tensor_args.append(args[i])

  # We essentially take the tensor kwargs, if any, and add them to the list of
  # positional arguments. The kwargs are then reconstructed inside the py_func.
  #
  # For example, if
  #
  #     args = [Tensor(1), 'foo']
  #     kwargs = {'a': Tensor(2), 'b': 'bar'}
  #
  # Then
  #
  #     tensor_args = (Tensor(1), Tensor(2))
  #     kwarg_keys = ('a', 'b')
  if kwargs:
    kwarg_keys = tuple(kwargs.keys())
    kwarg_is_tensor = {k: tensor_util.is_tensor(kwargs[k]) for k in kwarg_keys}
    for k in kwarg_keys:
      if kwarg_is_tensor[k]:
        tensor_args_idx[k] = len(tensor_args)
        tensor_args.append(kwargs[k])
  else:
    kwarg_keys = ()

  # Set up return dtypes.
  def match_arg_dtype(arg_number):
    """Returns the dtype of the tensor positional argument `arg_number`."""
    arg = args[arg_number]
    if not arg_is_tensor[arg_number]:
      raise ValueError(
          'argument %d was used with MatchDType and must be a tf.Tensor, but '
          'was %s instead' % (arg_number, type(arg)))
    return arg.dtype

  if return_dtypes:
    if isinstance(return_dtypes, MatchDType):
      return_dtypes = match_arg_dtype(return_dtypes.arg_number)
    elif isinstance(return_dtypes, (list, tuple)):
      return_dtypes = tuple(
          match_arg_dtype(a.arg_number) if isinstance(a, MatchDType) else a
          for a in return_dtypes)
    else:
      assert isinstance(return_dtypes, dtypes.DType)

  def f_wrapper(*tensors):
    """Rebuilds f's original args/kwargs from the py_func tensor inputs."""
    # Note: renamed from *tensor_args to avoid shadowing the outer list.
    f_args = tuple(tensors[tensor_args_idx[i]] if arg_is_tensor[i] else a
                   for i, a in enumerate(args))
    f_kwargs = {
        k: tensors[tensor_args_idx[k]] if kwarg_is_tensor[k] else kwargs[k]
        for k in kwarg_keys
    }
    retval = f(*f_args, **f_kwargs)
    # py_func requires a return value; substitute a constant when the caller
    # only cares about f's side effects.
    return 1 if use_dummy_return else retval

  return script_ops.py_func(f_wrapper, tensor_args,
                            dtypes.int64 if use_dummy_return else return_dtypes)
diff --git a/tensorflow/python/autograph/utils/py_func_test.py b/tensorflow/python/autograph/utils/py_func_test.py
new file mode 100644
index 0000000000..1c220d9492
--- /dev/null
+++ b/tensorflow/python/autograph/utils/py_func_test.py
@@ -0,0 +1,103 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for wrap_py_func module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
class PyFuncTest(test.TestCase):
  """Tests wrap_py_func over positional, keyword and dummy-return cases."""

  def test_wrap_py_func_simple(self):
    """Any mix of tensor and plain positional args produces the same sum."""

    def test_fn(a, b, c):
      return a + b + c

    with self.cached_session() as sess:
      result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                    (1, constant_op.constant(1), 1))
      self.assertEqual(3, sess.run(result))
      result = py_func.wrap_py_func(test_fn, dtypes.int64, (1, 1, 1))
      self.assertEqual(3, sess.run(result))
      result = py_func.wrap_py_func(
          test_fn, dtypes.int64,
          (constant_op.constant(1), 1, constant_op.constant(1)))
      self.assertEqual(3, sess.run(result))

  def test_wrap_py_func_complex_args(self):
    """Arbitrary non-tensor Python objects pass through to f untouched."""

    class TestClass(object):

      def __init__(self):
        self.foo = 5

    def test_fn(a, b):
      return a * b.foo

    with self.cached_session() as sess:
      result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
      self.assertEqual(35, sess.run(result))
      result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                    (constant_op.constant(7), TestClass()))
      self.assertEqual(35, sess.run(result))

  def test_wrap_py_func_kwargs(self):
    """Keyword args (tensor or not) are reconstructed inside the py_func."""

    class TestClass(object):

      def __init__(self, foo):
        self.foo = foo

    def test_fn(a, b, c, d):
      return a * b.foo + c * d.foo

    with self.cached_session() as sess:
      result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
          'c': 11,
          'd': TestClass(13)
      })
      self.assertEqual(178, sess.run(result))
      result = py_func.wrap_py_func(test_fn, dtypes.int64,
                                    (constant_op.constant(7), TestClass(5)), {
                                        'c': constant_op.constant(11),
                                        'd': TestClass(13)
                                    })
      self.assertEqual(178, sess.run(result))

  def test_wrap_py_func_dummy_return(self):
    """use_dummy_return=True yields constant 1 while still executing test_fn."""

    # Mutable cell observed from outside the wrapped function to prove that
    # test_fn actually ran, since its return value is discarded.
    side_counter = [0]

    def test_fn(_):
      side_counter[0] += 1

    with self.cached_session() as sess:
      result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
      self.assertEqual(1, sess.run(result))
      self.assertEqual([1], side_counter)
      result = py_func.wrap_py_func(
          test_fn, None, (constant_op.constant(5),), use_dummy_return=True)
      self.assertEqual(1, sess.run(result))
      self.assertEqual([2], side_counter)
+
+
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/utils/tensor_list.py b/tensorflow/python/autograph/utils/tensor_list.py
new file mode 100644
index 0000000000..2556f41289
--- /dev/null
+++ b/tensorflow/python/autograph/utils/tensor_list.py
@@ -0,0 +1,68 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A typed list in Python."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
def dynamic_list_append(target, element):
  """Converts a list append call inline.

  Dispatches on the runtime type of `target`:
    * TensorArray: returns a new TensorArray with `element` written at the
      current end (index `target.size()`).
    * Tensor: treated as a TensorList handle and pushed via list ops.
    * Anything else (plain Python list, TensorList wrapper): falls back to the
      object's own `append` and returns the same, mutated object.

  Args:
    target: The list-like object to append to.
    element: The value to append.

  Returns:
    The updated list-like object; a new object for TensorArray/Tensor inputs,
    the same object otherwise.
  """
  if isinstance(target, tensor_array_ops.TensorArray):
    return target.write(target.size(), element)
  # TODO(mdan): What's the right way to check this?
  # TODO(mdan): We may not need this branch.
  # It may be possible to use TensorList alone if the loop body will not
  # require wrapping it, although we'd have to think about an autoboxing
  # mechanism for lists received as parameter.
  if isinstance(target, ops.Tensor):
    return list_ops.tensor_list_push_back(target, element)

  # Python targets (including TensorList): fallback to their original append.
  target.append(element)
  return target
+
+
class TensorList(object):
  """Tensor list wrapper API-compatible with Python built-in list.

  Backed by TensorFlow list ops; each mutating method rebinds `self.list_` to
  the corresponding op's output, so state is threaded through the ops the
  wrapper creates.
  """

  def __init__(self, shape, dtype):
    # Element shape and dtype are fixed for the lifetime of the list.
    self.dtype = dtype
    self.shape = shape
    self.clear()

  def append(self, value):
    """Appends `value` at the end of the list."""
    self.list_ = list_ops.tensor_list_push_back(self.list_, value)

  def pop(self):
    """Removes and returns the last element."""
    self.list_, value = list_ops.tensor_list_pop_back(self.list_, self.dtype)
    return value

  def clear(self):
    """Resets the list to empty."""
    self.list_ = list_ops.empty_tensor_list(self.shape, self.dtype)

  def count(self):
    """Returns the list length (a Tensor)."""
    return list_ops.tensor_list_length(self.list_)

  def __getitem__(self, key):
    """Returns the element at index `key`."""
    return list_ops.tensor_list_get_item(self.list_, key, self.dtype)

  def __setitem__(self, key, value):
    """Replaces the element at index `key` with `value`."""
    self.list_ = list_ops.tensor_list_set_item(self.list_, key, value)
diff --git a/tensorflow/python/autograph/utils/tensor_list_test.py b/tensorflow/python/autograph/utils/tensor_list_test.py
new file mode 100644
index 0000000000..697c166eb1
--- /dev/null
+++ b/tensorflow/python/autograph/utils/tensor_list_test.py
@@ -0,0 +1,117 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Autograph lists."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils import tensor_list as tl
+from tensorflow.python.client.session import Session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
class TensorListTest(test.TestCase):
  """Tests dynamic_list_append and the TensorList wrapper."""

  def _shape(self, shape_tuple):
    # Shape passed as an int32 tensor, as expected by empty_tensor_list.
    return constant(shape_tuple, dtypes.int32)

  def test_dynamic_list_append(self):
    """All four supported targets accept an append and hold the element."""
    # Plain Python list: mutated in place.
    l = []
    l = tl.dynamic_list_append(l, 1)
    self.assertListEqual(l, [1])

    # Raw TensorList handle (variant tensor).
    l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
    l = tl.dynamic_list_append(l, 1)
    s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(s), [1])

    # TensorArray: append becomes a write at the current size.
    l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
    l = tl.dynamic_list_append(l, 1)
    s = l.stack()
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(s), [1])

    # TensorList wrapper: falls through to its own append.
    l = tl.TensorList(self._shape(()), dtypes.int32)
    l = tl.dynamic_list_append(l, 1)
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(l[0]), 1)

  def test_list_append_python(self):
    """append/pop bookkeeping in eager mode."""
    with context.eager_mode():
      a = constant(3.0)
      l = tl.TensorList(a.shape, a.dtype)
      l.append(a)
      self.assertEqual(l.count().numpy(), 1)
      l.append(a)
      self.assertEqual(l.count().numpy(), 2)
      _ = l.pop()
      self.assertEqual(l.count().numpy(), 1)
      a2 = l.pop()
      self.assertEqual(l.count().numpy(), 0)
      self.assertEqual(a.numpy(), a2.numpy())

  def test_list_index_python(self):
    """__getitem__/__setitem__ in eager mode."""
    with context.eager_mode():
      a = constant(3.0)
      b = constant(2.0)
      l = tl.TensorList(a.shape, a.dtype)
      l.append(a)
      self.assertEqual(l[0].numpy(), a.numpy())
      l[0] = ops.convert_to_tensor(b)
      self.assertEqual(l[0].numpy(), b.numpy())

  def test_list_append_tf(self):
    """append/pop bookkeeping in graph mode; counts captured between ops."""
    a = constant(3.0)
    l = tl.TensorList(a.shape, a.dtype)
    l.append(a)
    c1 = l.count()
    l.append(a)
    c2 = l.count()
    _ = l.pop()
    c3 = l.count()
    a2 = l.pop()
    c4 = l.count()
    with Session() as sess:
      c1, c2, c3, c4, a, a2 = sess.run([c1, c2, c3, c4, a, a2])
      self.assertEqual(c1, 1)
      self.assertEqual(c2, 2)
      self.assertEqual(c3, 1)
      self.assertEqual(c4, 0)
      self.assertEqual(a, a2)

  def test_list_index_tf(self):
    """__getitem__/__setitem__ in graph mode."""
    a = constant(3.0)
    b = constant(2.0)
    l = tl.TensorList(a.shape, a.dtype)
    l.append(a)
    l0 = l[0]
    l[0] = b
    l1 = l[0]
    with self.cached_session() as sess:
      l0, l1, a, b = sess.run([l0, l1, a, b])
      self.assertEqual(l0, a)
      self.assertEqual(l1, b)
+
+
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/utils/tensors.py b/tensorflow/python/autograph/utils/tensors.py
new file mode 100644
index 0000000000..fa5db81a71
--- /dev/null
+++ b/tensorflow/python/autograph/utils/tensors.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module defines tensor utilities not found in TensorFlow.
+
+The reason these utilities are not defined in TensorFlow is because they may
+not be not fully robust, although they work in the vast majority of cases. So
+we define them here in order for their behavior to be consistently verified.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import tensor_array_ops
+
+
def is_tensor_array(t):
  """Returns True if `t` is a tf.TensorArray instance."""
  return isinstance(t, tensor_array_ops.TensorArray)
+
+
def is_tensor_list(t):
  """Returns True if `t` looks like a TensorList handle (a scalar variant)."""
  # TODO(mdan): This is just a heuristic.
  # With TF lacking support for templated types, this is unfortunately the
  # closest we can get right now. A dedicated op ought to be possible to
  # construct.
  if not tensor_util.is_tensor(t):
    return False
  return t.dtype == dtypes.variant and not t.shape.ndims
diff --git a/tensorflow/python/autograph/utils/tensors_test.py b/tensorflow/python/autograph/utils/tensors_test.py
new file mode 100644
index 0000000000..1e7cfec9e1
--- /dev/null
+++ b/tensorflow/python/autograph/utils/tensors_test.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
class TensorsTest(test.TestCase):
  """Tests for the is_tensor_array / is_tensor_list heuristics."""

  def _simple_tensor_array(self):
    return tensor_array_ops.TensorArray(dtypes.int32, size=3)

  def _simple_tensor_list(self):
    return list_ops.empty_tensor_list(
        element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)

  def _simple_list_of_tensors(self):
    return [constant_op.constant(1), constant_op.constant(2)]

  def test_is_tensor_array(self):
    """Only genuine TensorArray instances are recognized."""
    self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array()))
    for non_array in (self._simple_tensor_list(), constant_op.constant(1),
                      self._simple_list_of_tensors(), None):
      self.assertFalse(tensors.is_tensor_array(non_array))

  def test_is_tensor_list(self):
    """Only variant tensors backing a TensorList are recognized."""
    self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list()))
    for non_list in (self._simple_tensor_array(), constant_op.constant(1),
                     self._simple_list_of_tensors(), None):
      self.assertFalse(tensors.is_tensor_list(non_list))
+
+
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  test.main()
diff --git a/tensorflow/python/autograph/utils/testing.py b/tensorflow/python/autograph/utils/testing.py
new file mode 100644
index 0000000000..cb4785d0dc
--- /dev/null
+++ b/tensorflow/python/autograph/utils/testing.py
@@ -0,0 +1,35 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Testing utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import imp
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
def fake_tf():
  """Creates a fake module that looks like TensorFlow, for testing."""
  fake_mod = imp.new_module('tensorflow')
  # Merge in precedence order: math_ops first, framework ops override on
  # collision, and the new module's own attributes (__name__ etc.) win last,
  # so the result still identifies itself as 'tensorflow'.
  contents = {}
  contents.update(math_ops.__dict__)
  contents.update(ops.__dict__)
  contents.update(fake_mod.__dict__)
  fake_mod.__dict__.update(contents)
  return fake_mod
diff --git a/tensorflow/python/autograph/utils/type_check.py b/tensorflow/python/autograph/utils/type_check.py
new file mode 100644
index 0000000000..8748abc47b
--- /dev/null
+++ b/tensorflow/python/autograph/utils/type_check.py
@@ -0,0 +1,33 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities used in autograph-generated code."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import tensor_util
+
+
def is_tensor(*args):
  """Check if any arguments are tensors.

  Args:
    *args: Python objects that may or may not be tensors.

  Returns:
    True if any *args are TensorFlow types, False if none are.
  """
  # Generator expression (rather than a materialized list) lets `any`
  # short-circuit on the first tensor found.
  return any(tensor_util.is_tensor(a) for a in args)
diff --git a/tensorflow/python/autograph/utils/type_check_test.py b/tensorflow/python/autograph/utils/type_check_test.py
new file mode 100644
index 0000000000..b3d1304e16
--- /dev/null
+++ b/tensorflow/python/autograph/utils/type_check_test.py
@@ -0,0 +1,43 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for type_check."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy
+
+from tensorflow.python.autograph.utils import type_check
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
class TypeCheckTest(test.TestCase):
  """Tests for the is_tensor type check."""

  def test_checks(self):
    """Tensor-like TF values are detected; plain Python/numpy values are not."""
    tensor_likes = (
        constant_op.constant([1, 2, 3]),
        test_util.variables.Variable([1, 2, 3]),
        test_util.array_ops.placeholder(test_util.dtypes.float32),
    )
    for t in tensor_likes:
      self.assertTrue(type_check.is_tensor(t))
    for not_a_tensor in (3, numpy.eye(3)):
      self.assertFalse(type_check.is_tensor(not_a_tensor))
+
+
# Standard TensorFlow test runner entry point.
if __name__ == '__main__':
  test.main()