aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-09 09:06:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 09:10:30 -0800
commit056c3167b8f6f829ecc2663c7df2bf2c1419747b (patch)
treee385c88bc97a2603f48ec74a36248a17721df2ac
parent6478a30b84a6620b853b450761e12f7075b7a43f (diff)
Desugar IfExp nodes
PiperOrigin-RevId: 188491604
-rw-r--r--tensorflow/contrib/py2tf/converters/BUILD12
-rw-r--r--tensorflow/contrib/py2tf/converters/ifexp.py49
-rw-r--r--tensorflow/contrib/py2tf/converters/ifexp_test.py106
-rw-r--r--tensorflow/contrib/py2tf/impl/conversion.py2
4 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/py2tf/converters/BUILD
index c85ad9200e..f624c42686 100644
--- a/tensorflow/contrib/py2tf/converters/BUILD
+++ b/tensorflow/contrib/py2tf/converters/BUILD
@@ -25,6 +25,7 @@ py_library(
"control_flow.py",
"decorators.py",
"for_loops.py",
+ "ifexp.py",
"list_comprehension.py",
"lists.py",
"logical_expressions.py",
@@ -202,3 +203,14 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "ifexp_test",
+ srcs = ["ifexp_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":test_lib",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/py2tf/converters/ifexp.py b/tensorflow/contrib/py2tf/converters/ifexp.py
new file mode 100644
index 0000000000..5fd6f348af
--- /dev/null
+++ b/tensorflow/contrib/py2tf/converters/ifexp.py
@@ -0,0 +1,49 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Canonicalizes the ternary conditional operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.pyct import templates
+from tensorflow.contrib.py2tf.pyct import transformer
+
+
+class IfExp(transformer.Base):
+ """Canonicalizes all IfExp nodes into plain conditionals."""
+
+ def visit_IfExp(self, node):
+ template = """
+ py2tf_utils.run_cond(test, lambda: body, lambda: orelse)
+ """
+ desugared_ifexp = templates.replace_as_expression(
+ template, test=node.test, body=node.body, orelse=node.orelse)
+ return desugared_ifexp
+
+
+def transform(node, context):
+ """Desugar IfExp nodes into plain conditionals.
+
+ Args:
+ node: an AST node to transform
+ context: a context object
+
+ Returns:
+ new_node: an AST with no IfExp nodes, only conditionals.
+ """
+
+ node = IfExp(context).visit(node)
+ return node
diff --git a/tensorflow/contrib/py2tf/converters/ifexp_test.py b/tensorflow/contrib/py2tf/converters/ifexp_test.py
new file mode 100644
index 0000000000..9c357ef35b
--- /dev/null
+++ b/tensorflow/contrib/py2tf/converters/ifexp_test.py
@@ -0,0 +1,106 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ifexp module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf import utils
+from tensorflow.contrib.py2tf.converters import converter_test_base
+from tensorflow.contrib.py2tf.converters import ifexp
+from tensorflow.python.platform import test
+
+
+class IfExpTest(converter_test_base.TestCase):
+
+ def compiled_fn(self, test_fn, *args):
+ node = self.parse_and_analyze(test_fn, {})
+ node = ifexp.transform(node, self.ctx)
+ module = self.compiled(node, *args)
+ return module
+
+ def test_simple(self):
+
+ def test_fn(x):
+ return 1 if x else 0
+
+ with self.compiled_fn(test_fn) as result:
+ result.py2tf_util = utils
+ for x in [0, 1]:
+ self.assertEqual(test_fn(x), result.test_fn(x))
+
+ def test_fn(self):
+
+ def f(x):
+ return 3 * x
+
+ def test_fn(x):
+ y = f(x * x if x > 0 else x)
+ return y
+
+ with self.compiled_fn(test_fn) as result:
+ result.py2tf_util = utils
+ result.f = f
+ for x in [-2, 2]:
+ self.assertEqual(test_fn(x), result.test_fn(x))
+
+ def test_exp(self):
+
+ def test_fn(x):
+ return x * x if x > 0 else x
+
+ with self.compiled_fn(test_fn) as result:
+ result.py2tf_util = utils
+ for x in [-2, 2]:
+ self.assertEqual(test_fn(x), result.test_fn(x))
+
+ def test_nested(self):
+
+ def test_fn(x):
+ return x * x if x > 0 else x if x else 1
+
+ with self.compiled_fn(test_fn) as result:
+ result.py2tf_util = utils
+ for x in [-2, 0, 2]:
+ self.assertEqual(test_fn(x), result.test_fn(x))
+
+ def test_in_cond(self):
+
+ def test_fn(x):
+ if x > 0:
+ return x * x if x < 5 else x * x * x
+ return -x
+
+ with self.compiled_fn(test_fn) as result:
+ result.py2tf_util = utils
+ for x in [-2, 2, 5]:
+ self.assertEqual(test_fn(x), result.test_fn(x))
+
+ def test_assign_in_cond(self):
+
+ def test_fn(x):
+ if x > 0:
+ x = -x if x < 5 else x
+ return x
+
+ with self.compiled_fn(test_fn) as result:
+ result.py2tf_util = utils
+ for x in [-2, 2, 5]:
+ self.assertEqual(test_fn(x), result.test_fn(x))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py
index 8a3cf9cd0a..37b24ab55f 100644
--- a/tensorflow/contrib/py2tf/impl/conversion.py
+++ b/tensorflow/contrib/py2tf/impl/conversion.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.py2tf.converters import continue_statements
from tensorflow.contrib.py2tf.converters import control_flow
from tensorflow.contrib.py2tf.converters import decorators
from tensorflow.contrib.py2tf.converters import for_loops
+from tensorflow.contrib.py2tf.converters import ifexp
from tensorflow.contrib.py2tf.converters import lists
from tensorflow.contrib.py2tf.converters import logical_expressions
from tensorflow.contrib.py2tf.converters import name_scopes
@@ -307,6 +308,7 @@ def node_to_graph(node, ctx, nocompile_decorators):
# source.
# TODO(mdan): Is it feasible to reconstruct intermediate source code?
ctx.source_code = None
+ node = ifexp.transform(node, ctx)
node, deps = decorators.transform(node, nocompile_decorators)
node = break_statements.transform(node, ctx)
node = asserts.transform(node, ctx)