diff options
author | Dan Moldovan <mdan@google.com> | 2018-10-10 07:38:42 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-10 07:42:44 -0700 |
commit | 93226f635c5c108b3b501d8bbcf27e64dec49fb9 (patch) | |
tree | 0a703b3f99168dc3852c9961c874797827064e10 /tensorflow/python/autograph/converters/asserts_test.py | |
parent | e851764c24e5ac5f527a7ce2ce12050edddeb209 (diff) |
Use overloaded operators for the assert statement. This should remove the reliance on importing tensorflow in the generated code.
PiperOrigin-RevId: 216528047
Diffstat (limited to 'tensorflow/python/autograph/converters/asserts_test.py')
-rw-r--r-- | tensorflow/python/autograph/converters/asserts_test.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py index 01282f9e62..eef628aeb6 100644 --- a/tensorflow/python/autograph/converters/asserts_test.py +++ b/tensorflow/python/autograph/converters/asserts_test.py @@ -18,24 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.python.autograph.converters import asserts +from tensorflow.python.autograph.converters import side_effect_guards from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test class AssertsTest(converter_testing.TestCase): - def test_transform(self): + def test_basic(self): def test_fn(a): - assert a > 0 - - node, ctx = self.prepare(test_fn, {}) - node = asserts.transform(node, ctx) - - self.assertTrue(isinstance(node.body[0].value, gast.Call)) + assert a, 'test message' + return tf.no_op() # pylint:disable=undefined-variable + + with self.converted(test_fn, (asserts, side_effect_guards), {}, + gen_control_flow_ops.no_op) as result: + with self.cached_session() as sess: + op = result.test_fn(constant_op.constant(False)) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'test message'): + sess.run(op) if __name__ == '__main__': |