aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/side_effect_guards_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/side_effect_guards_test.py132
1 files changed, 65 insertions, 67 deletions
diff --git a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
index a7ad8efed4..de1874321e 100644
--- a/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
+++ b/tensorflow/contrib/autograph/converters/side_effect_guards_test.py
@@ -25,140 +25,138 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
+tf = None # Will be replaced by a mock.
+
+
class SideEffectGuardsTest(converter_testing.TestCase):
def test_side_effect_on_return_only_variable(self):
- tf = None
-
def test_fn(a):
tf.assign(a, a + 1)
return a
- node = self.parse_and_analyze(test_fn, {})
- node = side_effect_guards.transform(node, self.ctx)
+ node, ctx = self.prepare(test_fn, {})
+ node = side_effect_guards.transform(node, ctx)
- with self.compiled(node, state_ops.assign) as result:
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body[0].body), 1)
+
+ with self.compiled(node, {}, state_ops.assign) as result:
with self.test_session() as sess:
- v = variables.Variable(2)
+ v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
- # NOTE: We don't expect the assignment to execute in this case, because
- # variables cannot be reliably guarded.
- self.assertEqual(2, sess.run(result.test_fn(v)))
+ sess.run(result.test_fn(v))
+ # TODO(mdan): Add support for this use case.
+ # Right now the variable `a` is not conditioned on the `assign` because
+ # there's no way to add control dependencies to a variable object.
+ self.assertEqual(2, sess.run(v))
def test_side_effect_on_used_variable(self):
- tf = None
-
def test_fn(a):
tf.assign(a, a + 1)
return a + 1
- node = self.parse_and_analyze(test_fn, {})
- node = side_effect_guards.transform(node, self.ctx)
+ node, ctx = self.prepare(test_fn, {})
+ node = side_effect_guards.transform(node, ctx)
- with self.compiled(node, state_ops.assign) as result:
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body[0].body), 1)
+
+ with self.compiled(node, {}, state_ops.assign) as result:
with self.test_session() as sess:
- v = variables.Variable(2)
+ v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
- # NOTE: Unlike test_side_effect_on_return_only_variable, the variable
- # was used in the local scope and so we could catch the assign's side
- # effect.
- self.assertEqual(4, sess.run(result.test_fn(v)))
+ sess.run(result.test_fn(v))
+ # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
+ # Right now it's 3 or 4 based on whether the read is synchronized.
+ self.assertEqual(3, sess.run(v))
def test_side_effect_on_tensor(self):
- tf = None
-
def test_fn(a):
tf.Assert(a > 0, ['expected in throw'])
return a
- node = self.parse_and_analyze(test_fn, {})
- node = side_effect_guards.transform(node, self.ctx)
+ node, ctx = self.prepare(test_fn, {})
+ node = side_effect_guards.transform(node, ctx)
- with self.compiled(node, control_flow_ops.Assert) as result:
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body[0].body), 1)
+
+ with self.compiled(node, {}, control_flow_ops.Assert) as result:
with self.test_session() as sess:
- # NOTE: In this case we can also capture the side effect because the
- # argument is a tensor ans we can wrap it inside an identity.
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'expected in throw'):
sess.run(result.test_fn(constant_op.constant(-1)))
def test_multiline_block(self):
- tf = None
-
def test_fn(a):
- tf.assign(a, a + 1)
+ tf.assign_add(a, 1)
b = a + 1
- tf.assign(a, b + 1)
- c = b + 1
- d = c + 1
- return d
+ tf.assign_add(a, 1)
+ b += 1
+ return b
- node = self.parse_and_analyze(test_fn, {})
- node = side_effect_guards.transform(node, self.ctx)
+ node, ctx = self.prepare(test_fn, {})
+ node = side_effect_guards.transform(node, ctx)
- with self.compiled(node, state_ops.assign) as result:
- self.assertEqual(len(node.body[0].body), 1)
+ self.assertEqual(len(node.body[0].body), 1)
+
+ with self.compiled(node, {}, state_ops.assign_add) as result:
with self.test_session() as sess:
- v = variables.Variable(2)
+ v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
- self.assertEqual(6, sess.run(result.test_fn(v)))
+ sess.run(result.test_fn(v))
+ # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
+ self.assertEqual(4, sess.run(v))
def test_multiline_nested_block(self):
- tf = None
-
def test_fn(a):
with tf.name_scope('foo'):
tf.assign(a, a + 1)
b = a + 1
- c = b + 1
- d = c + 1
- return d
+ return b
- node = self.parse_and_analyze(test_fn, {})
- node = side_effect_guards.transform(node, self.ctx)
+ node, ctx = self.prepare(test_fn, {})
+ node = side_effect_guards.transform(node, ctx)
- with self.compiled(node, state_ops.assign, ops.name_scope) as result:
- self.assertEqual(len(node.body[0].body[0].body), 1)
+ self.assertEqual(len(node.body[0].body[0].body), 1)
+
+ with self.compiled(node, {}, state_ops.assign, ops.name_scope) as result:
with self.test_session() as sess:
- v = variables.Variable(2)
+ v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
- self.assertEqual(6, sess.run(result.test_fn(v)))
+ sess.run(result.test_fn(v))
+ # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
+ self.assertEqual(3, sess.run(v))
def test_multiline_block_unsafe(self):
- tf = None
-
def test_fn(a):
tf.assign(a, a + 1)
b = a + 1
- tf.assign(a, a + 1)
+ tf.assign_add(a, 1)
c = b + 1
- d = c + 1
- return d
+ return c
+
+ node, ctx = self.prepare(test_fn, {})
+ node = side_effect_guards.transform(node, ctx)
- node = self.parse_and_analyze(test_fn, {})
- node = side_effect_guards.transform(node, self.ctx)
+ self.assertEqual(len(node.body[0].body), 1)
- with self.compiled(node, state_ops.assign) as result:
- self.assertEqual(len(node.body[0].body), 1)
+ with self.compiled(node, {}, state_ops.assign,
+ state_ops.assign_add) as result:
with self.test_session() as sess:
- v = variables.Variable(2)
+ v = variable_scope.get_variable('test', initializer=2)
sess.run(v.initializer)
- # NOTE: This intentionally highlights the flakiness. The test should be
- # tightened down once that is solved.
- self.assertTrue(sess.run(result.test_fn(v)) in (6, 7))
+ sess.run(result.test_fn(v))
+ # TODO(mdan): Ensure the result of test_fn(v) is also deterministic.
+ self.assertEqual(4, sess.run(v))
if __name__ == '__main__':