aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-10 07:38:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-10 07:42:44 -0700
commit93226f635c5c108b3b501d8bbcf27e64dec49fb9 (patch)
tree0a703b3f99168dc3852c9961c874797827064e10
parente851764c24e5ac5f527a7ce2ce12050edddeb209 (diff)
Use overloaded operators for the assert statement. This should remove the reliance on importing tensorflow in the generated code.
PiperOrigin-RevId: 216528047
-rw-r--r--tensorflow/python/autograph/converters/asserts.py2
-rw-r--r--tensorflow/python/autograph/converters/asserts_test.py24
-rw-r--r--tensorflow/python/autograph/operators/BUILD11
-rw-r--r--tensorflow/python/autograph/operators/__init__.py1
-rw-r--r--tensorflow/python/autograph/operators/exceptions.py86
-rw-r--r--tensorflow/python/autograph/operators/exceptions_test.py87
6 files changed, 201 insertions, 10 deletions
diff --git a/tensorflow/python/autograph/converters/asserts.py b/tensorflow/python/autograph/converters/asserts.py
index 56a97534c4..4ba827c35f 100644
--- a/tensorflow/python/autograph/converters/asserts.py
+++ b/tensorflow/python/autograph/converters/asserts.py
@@ -33,7 +33,7 @@ class AssertTransformer(converter.Base):
# Note: The lone tf.Assert call will be wrapped with control_dependencies
# by side_effect_guards.
template = """
- tf.Assert(test, (msg,))
+ ag__.assert_stmt(test, lambda: msg)
"""
if node.msg is None:
diff --git a/tensorflow/python/autograph/converters/asserts_test.py b/tensorflow/python/autograph/converters/asserts_test.py
index 01282f9e62..eef628aeb6 100644
--- a/tensorflow/python/autograph/converters/asserts_test.py
+++ b/tensorflow/python/autograph/converters/asserts_test.py
@@ -18,24 +18,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import gast
-
from tensorflow.python.autograph.converters import asserts
+from tensorflow.python.autograph.converters import side_effect_guards
from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.platform import test
class AssertsTest(converter_testing.TestCase):
- def test_transform(self):
+ def test_basic(self):
def test_fn(a):
- assert a > 0
-
- node, ctx = self.prepare(test_fn, {})
- node = asserts.transform(node, ctx)
-
- self.assertTrue(isinstance(node.body[0].value, gast.Call))
+ assert a, 'test message'
+ return tf.no_op() # pylint:disable=undefined-variable
+
+ with self.converted(test_fn, (asserts, side_effect_guards), {},
+ gen_control_flow_ops.no_op) as result:
+ with self.cached_session() as sess:
+ op = result.test_fn(constant_op.constant(False))
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ 'test message'):
+ sess.run(op)
if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/operators/BUILD b/tensorflow/python/autograph/operators/BUILD
index a116611b64..f422911377 100644
--- a/tensorflow/python/autograph/operators/BUILD
+++ b/tensorflow/python/autograph/operators/BUILD
@@ -22,6 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
+ "exceptions.py",
"py_builtins.py",
"slices.py",
],
@@ -63,6 +64,16 @@ py_test(
)
py_test(
+ name = "exceptions_test",
+ srcs = ["exceptions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "py_builtins_test",
srcs = ["py_builtins_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/autograph/operators/__init__.py b/tensorflow/python/autograph/operators/__init__.py
index 0d3b44b6c4..53f4b0ddc8 100644
--- a/tensorflow/python/autograph/operators/__init__.py
+++ b/tensorflow/python/autograph/operators/__init__.py
@@ -45,6 +45,7 @@ from tensorflow.python.autograph.operators.data_structures import list_stack
from tensorflow.python.autograph.operators.data_structures import ListPopOpts
from tensorflow.python.autograph.operators.data_structures import ListStackOpts
from tensorflow.python.autograph.operators.data_structures import new_list
+from tensorflow.python.autograph.operators.exceptions import assert_stmt
from tensorflow.python.autograph.operators.py_builtins import float_
from tensorflow.python.autograph.operators.py_builtins import int_
from tensorflow.python.autograph.operators.py_builtins import len_
diff --git a/tensorflow/python/autograph/operators/exceptions.py b/tensorflow/python/autograph/operators/exceptions.py
new file mode 100644
index 0000000000..6078160f68
--- /dev/null
+++ b/tensorflow/python/autograph/operators/exceptions.py
@@ -0,0 +1,86 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Exception handling statements: assert, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.util import tf_inspect
+
+
+def assert_stmt(expression1, expression2):
+ """Functional form of an assert statement.
+
+ This follows the semantics of the Python assert statement, however the
+ concrete implementations may deviate from it. See the respective
+ implementation for details.
+
+ In general, the assert statement should not be used for control flow.
+ Furthermore, it is encouraged that the assertion expressions should not have
+ side effects.
+
+ Args:
+ expression1: Any
+ expression2: Callable[[], Any], returns the expression to include in the
+ error message when expression1 evaluates to False. When expression1 is
+ True, the result of expression2 will not be evaluated, however,
+ expression2 itself may be evaluated in some implementations.
+
+ Returns:
+ Any, implementation-dependent.
+
+ Raises:
+ ValueError: if any arguments are illegal.
+ """
+ if not callable(expression2):
+ raise ValueError('{} must be a callable'.format(expression2))
+ args, _, keywords, _ = tf_inspect.getargspec(expression2)
+ if args or keywords:
+ raise ValueError('{} may not have any arguments'.format(expression2))
+
+ if tensor_util.is_tensor(expression1):
+ return _tf_assert_stmt(expression1, expression2)
+ else:
+ return _py_assert_stmt(expression1, expression2)
+
+
+def _tf_assert_stmt(expression1, expression2):
+ """Overload of assert_stmt that stages a TF Assert.
+
+ This implementation deviates from Python semantics as follows:
+ (1) the assertion is verified regardless of the state of __debug__
+ (2) on assertion failure, the graph execution will fail with
+ tensorflow.errors.ValueError, rather than AssertionError.
+
+ Args:
+ expression1: tensorflow.Tensor, must evaluate to a tf.bool scalar
+ expression2: Callable[[], Union[tensorflow.Tensor, List[tensorflow.Tensor]]]
+
+ Returns:
+ tensorflow.Operation
+ """
+ expression2_tensors = expression2()
+ if not isinstance(expression2_tensors, list):
+ expression2_tensors = [expression2_tensors]
+ return control_flow_ops.Assert(expression1, expression2_tensors)
+
+
+def _py_assert_stmt(expression1, expression2):
+ """Overload of assert_stmt that executes a Python assert statement."""
+ assert expression1, expression2()
+ return None
diff --git a/tensorflow/python/autograph/operators/exceptions_test.py b/tensorflow/python/autograph/operators/exceptions_test.py
new file mode 100644
index 0000000000..186535d05b
--- /dev/null
+++ b/tensorflow/python/autograph/operators/exceptions_test.py
@@ -0,0 +1,87 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for exceptions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.operators import exceptions
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.platform import test
+
+
+class ExceptionsTest(test.TestCase):
+
+ def test_assert_tf_untriggered(self):
+ with self.cached_session() as sess:
+ t = exceptions.assert_stmt(
+ constant_op.constant(True), lambda: constant_op.constant('ignored'))
+ sess.run(t)
+
+ def test_assert_tf_triggered(self):
+ with self.cached_session() as sess:
+ t = exceptions.assert_stmt(
+ constant_op.constant(False),
+ lambda: constant_op.constant('test message'))
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ 'test message'):
+ sess.run(t)
+
+ def test_assert_tf_multiple_printed_values(self):
+ two_tensors = [
+ constant_op.constant('test message'),
+ constant_op.constant('another message')
+ ]
+ with self.cached_session() as sess:
+ t = exceptions.assert_stmt(
+ constant_op.constant(False), lambda: two_tensors)
+
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ 'test message.*another message'):
+ sess.run(t)
+
+ def test_assert_python_untriggered(self):
+ side_effect_trace = []
+
+ def expression_with_side_effects():
+ side_effect_trace.append(object())
+ return 'test message'
+
+ exceptions.assert_stmt(True, expression_with_side_effects)
+
+ self.assertListEqual(side_effect_trace, [])
+
+ def test_assert_python_triggered(self):
+ if not __debug__:
+ # Python assertions only be tested when in debug mode.
+ return
+
+ side_effect_trace = []
+ tracer = object()
+
+ def expression_with_side_effects():
+ side_effect_trace.append(tracer)
+ return 'test message'
+
+ with self.assertRaisesRegexp(AssertionError, 'test message'):
+ exceptions.assert_stmt(False, expression_with_side_effects)
+ self.assertListEqual(side_effect_trace, [tracer])
+
+
+if __name__ == '__main__':
+ test.main()