diff options
author | 2018-08-02 11:19:07 -0700 | |
---|---|---|
committer | 2018-08-02 11:23:12 -0700 | |
commit | cf89b81401ea651dccf72ff322c1159618d0c8aa (patch) | |
tree | 52d3f88eaf4707a2e19042d91fc5db0985bae41d | |
parent | 4238e0f19cbbaea96a2a37f432f97fda584189ba (diff) |
Throw ValueError with clear error message when function to be traced by defun
has operations that modify the input arguments in place.
PiperOrigin-RevId: 207133095
-rw-r--r-- | tensorflow/python/eager/function.py | 29 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 169 |
2 files changed, 196 insertions, 2 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 99129c2537..de55f999d7 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -680,6 +680,11 @@ def _trace_and_define_function(name, func, compiled, args, kwds): func_args = _get_defun_inputs(args) func_kwds = _get_defun_inputs(kwds) + # Variables to help check whether mutation happens in calling the function + # Copy the recursive list, tuple and map structure, but not base objects + func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args)) + func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds)) + def convert(x): if x is None: return None @@ -691,6 +696,25 @@ def _trace_and_define_function(name, func, compiled, args, kwds): try: func_outputs = func(*func_args, **func_kwds) func_outputs = nest.map_structure(convert, func_outputs) + + def check_mutation(n1, n2): + """Check if two list of arguments are exactly the same.""" + errmsg = ("Function to be traced should not modify structure of input " + "arguments. Check if your function has list and dictionary " + "operations that alter input arguments, " + "such as `list.pop`, `list.append`") + try: + nest.assert_same_structure(n1, n2) + except ValueError: + raise ValueError(errmsg) + + for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)): + if arg1 is not arg2: + raise ValueError(errmsg) + + check_mutation(func_args_before, func_args) + check_mutation(func_kwds_before, func_kwds) + finally: tape.pop_tape(this_tape) variables = this_tape.watched_variables() @@ -894,8 +918,9 @@ def defun(func=None, compiled=False): `defun`-generated graphs. For a Python function to be compatible with `defun`, all of its arguments must - be hashable Python objects or lists thereof. Additionally, it must return zero - or more @{tf.Tensor} objects. + be hashable Python objects or lists thereof. The function itself may not + modify the list/map structure of its arguments. Additionally, it must return + zero or more @{tf.Tensor} objects. Executing a graph generated by `defun` respects device annotations (i.e., all `with tf.device` directives present in a Python function will also be diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 2e86563a7d..afd4bbf4f3 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import collections +import sys from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import iterator_ops @@ -1212,6 +1213,174 @@ class AutomaticControlDependenciesTest(test.TestCase): train() self.assertEqual(v.numpy(), -1.0) + def testFunctionModifiesInputList(self): + # Tests on `list` methods that do in place modification, except `list.sort` + # since it cannot even be "defunned" in the first place + + def get_list(): + return [constant_op.constant(0.), constant_op.constant(1.)] + + expected_msg = ( + 'Function to be traced should not modify structure of input ' + 'arguments. Check if your function has list and dictionary ' + 'operations that alter input arguments, ' + 'such as `list.pop`, `list.append`') + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def append(l): + l.append(constant_op.constant(0.)) + + append(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def extend(l): + l.extend([constant_op.constant(0.)]) + + extend(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def insert(l): + l.insert(0, constant_op.constant(0.)) + + insert(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def pop(l): + l.pop() + + pop(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def reverse(l): + l.reverse() + + reverse(get_list()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def remove(l): + l.remove(l[0]) + + remove(get_list()) + + # `list.clear` is a method that is in Py3 but not Py2 + if sys.version.startswith('3'): + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def clear(l): + l.clear() + + clear(get_list()) + + # One last test for keyword arguments + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def kwdappend(**kwargs): + l = kwargs['l'] + l.append(constant_op.constant(0.)) + + kwdappend(l=get_list()) + + def testFunctionModifiesInputDict(self): + + def get_dict(): + return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)} + + expected_msg = ( + 'Function to be traced should not modify structure of input ' + 'arguments. Check if your function has list and dictionary ' + 'operations that alter input arguments, ' + 'such as `list.pop`, `list.append`') + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def clear(m): + m.clear() + + clear(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def pop(m): + m.pop('t1') + + pop(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def popitem(m): + m.popitem() + + popitem(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def update(m): + m.update({'t1': constant_op.constant(3.)}) + + update(get_dict()) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def setdefault(m): + m.setdefault('t3', constant_op.constant(3.)) + + setdefault(get_dict()) + + def testFunctionModifiesInputNest(self): + # Test on functions that modify structure of nested input arguments + expected_msg = ( + 'Function to be traced should not modify structure of input ' + 'arguments. Check if your function has list and dictionary ' + 'operations that alter input arguments, ' + 'such as `list.pop`, `list.append`') + + with self.assertRaisesRegexp(ValueError, expected_msg): + + @function.defun + def modify(n): + n[0]['t1'].append(constant_op.constant(1.)) + + nested_input = [{ + 't1': [constant_op.constant(0.), + constant_op.constant(1.)], + }, + constant_op.constant(2.)] + + modify(nested_input) + + with self.assertRaisesRegexp(ValueError, expected_msg): + + # The flat list doesn't change whereas the true structure changes + @function.defun + def modify_same_flat(n): + n[0].append(n[1].pop(0)) + + nested_input = [[constant_op.constant(0.)], + [constant_op.constant(1.), + constant_op.constant(2.)]] + + modify_same_flat(nested_input) + if __name__ == '__main__': ops.enable_eager_execution( |