From 93226f635c5c108b3b501d8bbcf27e64dec49fb9 Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Wed, 10 Oct 2018 07:38:42 -0700 Subject: Use overloaded operators for the assert statement. This should remove the reliance on importing tensorflow in the generated code. PiperOrigin-RevId: 216528047 --- .../python/autograph/operators/exceptions.py | 86 ++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 tensorflow/python/autograph/operators/exceptions.py (limited to 'tensorflow/python/autograph/operators/exceptions.py') diff --git a/tensorflow/python/autograph/operators/exceptions.py b/tensorflow/python/autograph/operators/exceptions.py new file mode 100644 index 0000000000..6078160f68 --- /dev/null +++ b/tensorflow/python/autograph/operators/exceptions.py @@ -0,0 +1,86 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Exception handling statements: assert, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.util import tf_inspect + + +def assert_stmt(expression1, expression2): + """Functional form of an assert statement. + + This follows the semantics of the Python assert statement, however the + concrete implementations may deviate from it. See the respective + implementation for details. + + In general, the assert statement should not be used for control flow. + Furthermore, it is encouraged that the assertion expressions should not have + side effects. + + Args: + expression1: Any + expression2: Callable[[], Any], returns the expression to include in the + error message when expression1 evaluates to False. When expression1 is + True, the result of expression2 will not be evaluated, however, + expression2 itself may be evaluated in some implementations. + + Returns: + Any, implementation-dependent. + + Raises: + ValueError: if any arguments are illegal. + """ + if not callable(expression2): + raise ValueError('{} must be a callable'.format(expression2)) + args, _, keywords, _ = tf_inspect.getargspec(expression2) + if args or keywords: + raise ValueError('{} may not have any arguments'.format(expression2)) + + if tensor_util.is_tensor(expression1): + return _tf_assert_stmt(expression1, expression2) + else: + return _py_assert_stmt(expression1, expression2) + + +def _tf_assert_stmt(expression1, expression2): + """Overload of assert_stmt that stages a TF Assert. + + This implementation deviates from Python semantics as follows: + (1) the assertion is verified regardless of the state of __debug__ + (2) on assertion failure, the graph execution will fail with + tensorflow.errors.ValueError, rather than AssertionError. + + Args: + expression1: tensorflow.Tensor, must evaluate to a tf.bool scalar + expression2: Callable[[], Union[tensorflow.Tensor, List[tensorflow.Tensor]]] + + Returns: + tensorflow.Operation + """ + expression2_tensors = expression2() + if not isinstance(expression2_tensors, list): + expression2_tensors = [expression2_tensors] + return control_flow_ops.Assert(expression1, expression2_tensors) + + +def _py_assert_stmt(expression1, expression2): + """Overload of assert_stmt that executes a Python assert statement.""" + assert expression1, expression2() + return None -- cgit v1.2.3