From 93226f635c5c108b3b501d8bbcf27e64dec49fb9 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 10 Oct 2018 07:38:42 -0700 Subject: Use overloaded operators for the assert statement. This should remove the reliance on importing tensorflow in the generated code. PiperOrigin-RevId: 216528047 --- tensorflow/python/autograph/converters/asserts.py | 2 +- .../python/autograph/converters/asserts_test.py | 24 +++--- tensorflow/python/autograph/operators/BUILD | 11 +++ tensorflow/python/autograph/operators/__init__.py | 1 + .../python/autograph/operators/exceptions.py | 86 +++++++++++++++++++++ .../python/autograph/operators/exceptions_test.py | 87 ++++++++++++++++++++++ 6 files changed, 201 insertions(+), 10 deletions(-) create mode 100644 tensorflow/python/autograph/operators/exceptions.py create mode 100644 tensorflow/python/autograph/operators/exceptions_test.py diff --git a/tensorflow/python/autograph/converters/asserts.py b/tensorflow/python/autograph/converters/asserts.py index 56a97534c4..4ba827c35f 100644 --- a/tensorflow/python/autograph/converters/asserts.py +++ b/tensorflow/python/autograph/converters/asserts.py @@ -33,7 +33,7 @@ class AssertTransformer(converter.Base): # Note: The lone tf.Assert call will be wrapped with control_dependencies # by side_effect_guards. template = """ - tf.Assert(test, (msg,)) + ag__.assert_stmt(test, lambda: msg) """ if node.msg is None: diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py index 01282f9e62..eef628aeb6 100644 --- a/tensorflow/python/autograph/converters/asserts_test.py +++ b/tensorflow/python/autograph/converters/asserts_test.py @@ -18,24 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gast - from tensorflow.python.autograph.converters import asserts +from tensorflow.python.autograph.converters import side_effect_guards from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.platform import test class AssertsTest(converter_testing.TestCase): - def test_transform(self): + def test_basic(self): def test_fn(a): - assert a > 0 - - node, ctx = self.prepare(test_fn, {}) - node = asserts.transform(node, ctx) - - self.assertTrue(isinstance(node.body[0].value, gast.Call)) + assert a, 'test message' + return tf.no_op() # pylint:disable=undefined-variable + + with self.converted(test_fn, (asserts, side_effect_guards), {}, + gen_control_flow_ops.no_op) as result: + with self.cached_session() as sess: + op = result.test_fn(constant_op.constant(False)) + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'test message'): + sess.run(op) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD index a116611b64..f422911377 100644 --- a/tensorflow/python/autograph/operators/BUILD +++ b/tensorflow/python/autograph/operators/BUILD @@ -22,6 +22,7 @@ py_library( "__init__.py", "control_flow.py", "data_structures.py", + "exceptions.py", "py_builtins.py", "slices.py", ], @@ -62,6 +63,16 @@ py_test( ], ) +py_test( + name = "exceptions_test", + srcs = ["exceptions_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "py_builtins_test", srcs = ["py_builtins_test.py"], diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py index 0d3b44b6c4..53f4b0ddc8 100644 --- a/tensorflow/python/autograph/operators/__init__.py +++ b/tensorflow/python/autograph/operators/__init__.py @@ -45,6 +45,7 @@ from tensorflow.python.autograph.operators.data_structures import list_stack from tensorflow.python.autograph.operators.data_structures import ListPopOpts from tensorflow.python.autograph.operators.data_structures import ListStackOpts from tensorflow.python.autograph.operators.data_structures import new_list +from tensorflow.python.autograph.operators.exceptions import assert_stmt from tensorflow.python.autograph.operators.py_builtins import float_ from tensorflow.python.autograph.operators.py_builtins import int_ from tensorflow.python.autograph.operators.py_builtins import len_ diff --git a/tensorflow/python/autograph/operators/exceptions.py b/tensorflow/python/autograph/operators/exceptions.py new file mode 100644 index 0000000000..6078160f68 --- /dev/null +++ b/tensorflow/python/autograph/operators/exceptions.py @@ -0,0 +1,86 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Exception handling statements: assert, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import tf_inspect + + +def assert_stmt(expression1, expression2): + """Functional form of an assert statement. + + This follows the semantics of the Python assert statement, however the + concrete implementations may deviate from it. See the respective + implementation for details. + + In general, the assert statement should not be used for control flow. + Furthermore, it is encouraged that the assertion expressions should not have + side effects. + + Args: + expression1: Any + expression2: Callable[[], Any], returns the expression to include in the + error message when expression1 evaluates to False. When expression1 is + True, the result of expression2 will not be evaluated, however, + expression2 itself may be evaluated in some implementations. + + Returns: + Any, implementation-dependent. + + Raises: + ValueError: if any arguments are illegal. + """ + if not callable(expression2): + raise ValueError('{} must be a callable'.format(expression2)) + args, _, keywords, _ = tf_inspect.getargspec(expression2) + if args or keywords: + raise ValueError('{} may not have any arguments'.format(expression2)) + + if tensor_util.is_tensor(expression1): + return _tf_assert_stmt(expression1, expression2) + else: + return _py_assert_stmt(expression1, expression2) + + +def _tf_assert_stmt(expression1, expression2): + """Overload of assert_stmt that stages a TF Assert. + + This implementation deviates from Python semantics as follows: + (1) the assertion is verified regardless of the state of __debug__ + (2) on assertion failure, the graph execution will fail with + tensorflow.errors.ValueError, rather than AssertionError. + + Args: + expression1: tensorflow.Tensor, must evaluate to a tf.bool scalar + expression2: Callable[[], Union[tensorflow.Tensor, List[tensorflow.Tensor]]] + + Returns: + tensorflow.Operation + """ + expression2_tensors = expression2() + if not isinstance(expression2_tensors, list): + expression2_tensors = [expression2_tensors] + return control_flow_ops.Assert(expression1, expression2_tensors) + + +def _py_assert_stmt(expression1, expression2): + """Overload of assert_stmt that executes a Python assert statement.""" + assert expression1, expression2() + return None diff --git a/tensorflow/python/autograph/operators/exceptions_test.py b/tensorflow/python/autograph/operators/exceptions_test.py new file mode 100644 index 0000000000..186535d05b --- /dev/null +++ b/tensorflow/python/autograph/operators/exceptions_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for exceptions module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.autograph.operators import exceptions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.platform import test + + +class ExceptionsTest(test.TestCase): + + def test_assert_tf_untriggered(self): + with self.cached_session() as sess: + t = exceptions.assert_stmt( + constant_op.constant(True), lambda: constant_op.constant('ignored')) + sess.run(t) + + def test_assert_tf_triggered(self): + with self.cached_session() as sess: + t = exceptions.assert_stmt( + constant_op.constant(False), + lambda: constant_op.constant('test message')) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'test message'): + sess.run(t) + + def test_assert_tf_multiple_printed_values(self): + two_tensors = [ + constant_op.constant('test message'), + constant_op.constant('another message') + ] + with self.cached_session() as sess: + t = exceptions.assert_stmt( + constant_op.constant(False), lambda: two_tensors) + + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + 'test message.*another message'): + sess.run(t) + + def test_assert_python_untriggered(self): + side_effect_trace = [] + + def expression_with_side_effects(): + side_effect_trace.append(object()) + return 'test message' + + exceptions.assert_stmt(True, expression_with_side_effects) + + self.assertListEqual(side_effect_trace, []) + + def test_assert_python_triggered(self): + if not __debug__: + # Python assertions only be tested when in debug mode. + return + + side_effect_trace = [] + tracer = object() + + def expression_with_side_effects(): + side_effect_trace.append(tracer) + return 'test message' + + with self.assertRaisesRegexp(AssertionError, 'test message'): + exceptions.assert_stmt(False, expression_with_side_effects) + self.assertListEqual(side_effect_trace, [tracer]) + + +if __name__ == '__main__': + test.main() -- cgit v1.2.3