aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/control_flow_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/control_flow_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow_test.py212
1 files changed, 78 insertions, 134 deletions
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 735eb92a0d..ade3501426 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -20,16 +20,23 @@ from __future__ import print_function
from tensorflow.contrib.autograph.converters import control_flow
from tensorflow.contrib.autograph.core import converter_testing
+from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
class ControlFlowTest(converter_testing.TestCase):
- def test_simple_while(self):
+ def assertTransformedResult(self, test_fn, inputs, expected):
+ if not isinstance(inputs, tuple):
+ inputs = (inputs,)
+ with self.converted(test_fn, control_flow, {},
+ constant_op.constant) as result:
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(result.test_fn(*inputs)), expected)
+
+ def test_while_basic(self):
def test_fn(n):
i = 0
@@ -39,29 +46,18 @@ class ControlFlowTest(converter_testing.TestCase):
i += 1
return s, i, n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
-
- with self.compiled(node) as result:
- with self.test_session() as sess:
- self.assertEqual((10, 5, 5),
- sess.run(result.test_fn(constant_op.constant(5))))
+ self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
- def test_while_single_var(self):
+ def test_while_single_output(self):
def test_fn(n):
while n > 0:
n -= 1
return n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
+ self.assertTransformedResult(test_fn, constant_op.constant(5), 0)
- with self.compiled(node) as result:
- with self.test_session() as sess:
- self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
-
- def test_simple_if(self):
+ def test_if_basic(self):
def test_fn(n):
a = 0
@@ -72,114 +68,85 @@ class ControlFlowTest(converter_testing.TestCase):
b = 2 * n
return a, b
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
+ self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
+ self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))
+
+ def test_if_complex_outputs(self):
+
+ class TestClass(object):
- with self.compiled(node) as result:
+ def __init__(self, a, b):
+ self.a = a
+ self.b = b
+
+ def test_fn(n, obj):
+ obj.a = 0
+ obj.b = 0
+ if n > 0:
+ obj.a = -n
+ else:
+ obj.b = 2 * n
+ return obj
+
+ with self.converted(test_fn, control_flow, {}) as result:
with self.test_session() as sess:
- self.assertEqual((-1, 0),
- sess.run(result.test_fn(constant_op.constant(1))))
- self.assertEqual((0, -2),
- sess.run(result.test_fn(constant_op.constant(-1))))
+ res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
+ self.assertEqual(sess.run((res_obj.a, res_obj.b)), (-1, 0))
+ res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
+ self.assertEqual(sess.run((res_obj.a, res_obj.b)), (0, -2))
- def test_if_single_var(self):
+ def test_if_single_output(self):
def test_fn(n):
if n > 0:
n = -n
return n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
+ self.assertTransformedResult(test_fn, constant_op.constant(1), -1)
- with self.compiled(node) as result:
- with self.test_session() as sess:
- self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
-
- def test_imbalanced_aliasing(self):
+ def test_if_semi(self):
def test_fn(n):
if n > 0:
n = 3
return n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
-
- with self.compiled(node, control_flow_ops.cond) as result:
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(2))))
- self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+ self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
+ self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)
- def test_ignore_unread_variable(self):
+ def test_if_local_var(self):
def test_fn(n):
- b = 3 # pylint: disable=unused-variable
if n > 0:
b = 4
+ n = b + 1
return n
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
+ self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
+ self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
- with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(result.test_fn(constant_op.constant(3))))
- self.assertEqual(-3, sess.run(result.test_fn(constant_op.constant(-3))))
+ def test_if_no_outputs(self):
- def test_handle_temp_variable(self):
+ def test_fn(n):
+ if n > 0:
+ b = 4 # pylint:disable=unused-variable
+ return n
- def test_fn_using_temp(x, y, w):
- if x < y:
- z = x + y
- else:
- w = 2
- tmp = w
- z = x - tmp
- return z, w
+ # Without side effect guards, the if statement will stage a cond,
+ # but that will be pruned at execution.
+ self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
+ self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)
- node = self.parse_and_analyze(test_fn_using_temp, {})
- node = control_flow.transform(node, self.ctx)
+ def test_if_imbalanced_outputs(self):
- with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
- with self.test_session() as sess:
- z, w = sess.run(
- result.test_fn_using_temp(
- constant_op.constant(-3), constant_op.constant(3),
- constant_op.constant(3)))
- self.assertEqual(0, z)
- self.assertEqual(3, w)
- z, w = sess.run(
- result.test_fn_using_temp(
- constant_op.constant(3), constant_op.constant(-3),
- constant_op.constant(3)))
- self.assertEqual(1, z)
- self.assertEqual(2, w)
-
- def test_fn_ignoring_temp(x, y, w):
- if x < y:
- z = x + y
- else:
- w = 2
- tmp = w
- z = x - tmp
- return z
+ def test_fn(n):
+ if n > 0:
+ b = 4
+ return b
- node = self.parse_and_analyze(test_fn_ignoring_temp, {})
- node = control_flow.transform(node, self.ctx)
-
- with self.compiled(node, control_flow_ops.cond, array_ops.ones) as result:
- with self.test_session() as sess:
- z = sess.run(
- result.test_fn_ignoring_temp(
- constant_op.constant(-3), constant_op.constant(3),
- constant_op.constant(3)))
- self.assertEqual(0, z)
- z = sess.run(
- result.test_fn_ignoring_temp(
- constant_op.constant(3), constant_op.constant(-3),
- constant_op.constant(3)))
- self.assertEqual(1, z)
+ node, ctx = self.prepare(test_fn, {})
+ with self.assertRaises(transformer.AutographParseError):
+ control_flow.transform(node, ctx)
def test_simple_for(self):
@@ -191,22 +158,11 @@ class ControlFlowTest(converter_testing.TestCase):
s2 += e * e
return s1, s2
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
+ self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10))
+ empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
+ self.assertTransformedResult(test_fn, empty_vector, (0, 0))
- with self.compiled(node) as result:
- with self.test_session() as sess:
- l = [1, 2, 3]
- self.assertEqual(
- test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
- l = []
- self.assertEqual(
- test_fn(l),
- sess.run(
- result.test_fn(
- constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
-
- def test_for_single_var(self):
+ def test_for_single_output(self):
def test_fn(l):
s = 0
@@ -214,22 +170,11 @@ class ControlFlowTest(converter_testing.TestCase):
s += e
return s
- node = self.parse_and_analyze(test_fn, {})
- node = control_flow.transform(node, self.ctx)
+ self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4)
+ empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
+ self.assertTransformedResult(test_fn, empty_vector, 0)
- with self.compiled(node) as result:
- with self.test_session() as sess:
- l = [1, 2, 3]
- self.assertEqual(
- test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
- l = []
- self.assertEqual(
- test_fn(l),
- sess.run(
- result.test_fn(
- constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
-
- def test_for_with_iterated_expression(self):
+ def test_for_iterated_expression(self):
eval_count = [0]
@@ -243,14 +188,13 @@ class ControlFlowTest(converter_testing.TestCase):
s += e
return s
- node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
- node = control_flow.transform(node, self.ctx)
+ ns = {'count_evals': count_evals}
+ node, ctx = self.prepare(test_fn, ns)
+ node = control_flow.transform(node, ctx)
- with self.compiled(node) as result:
- result.count_evals = count_evals
- self.assertEqual(test_fn(5), result.test_fn(5))
- # count_evals ran twice, once for test_fn and another for result.test_fn
- self.assertEqual(eval_count[0], 2)
+ with self.compiled(node, ns) as result:
+ self.assertEqual(result.test_fn(5), 10)
+ self.assertEqual(eval_count[0], 1)
if __name__ == '__main__':