aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-16 12:09:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-16 12:15:08 -0700
commitd3fb437da12fc326d8229bdb955580c63eaccb5f (patch)
treeb874a5c332f42e262c8d0a9bb3eebecc54dd090e
parent345ccea1ea751e426a2d2d8e8d44455c43336d8c (diff)
Copy the if statement handlers over to the operators module. They will enabled in a follow-up CL.
PiperOrigin-RevId: 193078348
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py32
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow_test.py29
2 files changed, 55 insertions, 6 deletions
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 81ae64f110..d9d8b0d593 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -25,6 +25,9 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
+# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature.
+# TODO(mdan): Rename arguments to match the AST names.
+
def for_loop(iterated, extra_cond, loop_body, init_state):
"""Functional form of a for statement.
@@ -182,3 +185,32 @@ def _py_while_loop(loop_cond, loop_body, init_state, opts):
while loop_cond(*state):
state = loop_body(*state)
return state
+
+
+def if_stmt(cond, body, orelse):
+ """Functional form of an if statement.
+
+ Args:
+ cond: Boolean.
+ body: Callable with no arguments, and outputs of the positive (if) branch
+ as return type.
+ orelse: Callable with no arguments, and outputs of the negative (else)
+ branch as return type.
+
+ Returns:
+ Tuple containing the statement outputs.
+ """
+ if tensor_util.is_tensor(cond):
+ return _tf_if_stmt(cond, body, orelse)
+ else:
+ return _py_if_stmt(cond, body, orelse)
+
+
+def _tf_if_stmt(cond, body, orelse):
+ """Overload of if_stmt that stages a TF cond."""
+ return control_flow_ops.cond(cond, body, orelse)
+
+
+def _py_if_stmt(cond, body, orelse):
+ """Overload of if_stmt that executes a Python if statement."""
+ return body() if cond else orelse()
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
index 9112b1627f..a0cd0bfa82 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph import operators
+from tensorflow.contrib.autograph.operators import control_flow
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class ForLoopTest(test.TestCase):
def test_tensor(self):
- s = operators.for_loop(
+ s = control_flow.for_loop(
constant_op.constant([1, 2, 3, 4]),
extra_cond=lambda s: True,
loop_body=lambda i, s: (s + i,),
@@ -38,7 +38,7 @@ class ForLoopTest(test.TestCase):
self.assertEqual((10,), sess.run(s))
def test_python(self):
- s = operators.for_loop(
+ s = control_flow.for_loop(
range(5),
extra_cond=lambda s: True,
loop_body=lambda i, s: (s + i,),
@@ -47,7 +47,7 @@ class ForLoopTest(test.TestCase):
def test_dataset(self):
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
- s = operators.for_loop(
+ s = control_flow.for_loop(
dataset_ops.Dataset.range(5).map(to_int32),
extra_cond=lambda s: True,
loop_body=lambda i, s: (s + i,),
@@ -60,7 +60,7 @@ class WhileLoopTest(test.TestCase):
def test_tensor(self):
n = constant_op.constant(5)
- results = operators.while_loop(
+ results = control_flow.while_loop(
loop_cond=lambda i, s: i < n,
loop_body=lambda i, s: (i + 1, s + i,),
init_state=(0, 0),
@@ -70,7 +70,7 @@ class WhileLoopTest(test.TestCase):
def test_python(self):
n = 5
- results = operators.while_loop(
+ results = control_flow.while_loop(
loop_cond=lambda i, s: i < n,
loop_body=lambda i, s: (i + 1, s + i),
init_state=(0, 0),
@@ -78,5 +78,22 @@ class WhileLoopTest(test.TestCase):
self.assertEqual((5, 10), results)
+class IfStmtTest(test.TestCase):
+
+ def test_tensor(self):
+ def test_if_stmt(cond):
+ return control_flow.if_stmt(
+ cond=cond,
+ body=lambda: 1,
+ orelse=lambda: -1)
+ with self.test_session() as sess:
+ self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True))))
+ self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False))))
+
+ def test_python(self):
+ self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1))
+ self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1))
+
+
if __name__ == '__main__':
test.main()