aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Xuechen Li <lxuechen@google.com>2018-08-02 11:19:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 11:23:12 -0700
commitcf89b81401ea651dccf72ff322c1159618d0c8aa (patch)
tree52d3f88eaf4707a2e19042d91fc5db0985bae41d
parent4238e0f19cbbaea96a2a37f432f97fda584189ba (diff)
Throw ValueError with clear error message when function to be traced by defun
has operations that modify the input arguments in place. PiperOrigin-RevId: 207133095
-rw-r--r--tensorflow/python/eager/function.py29
-rw-r--r--tensorflow/python/eager/function_test.py169
2 files changed, 196 insertions, 2 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 99129c2537..de55f999d7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -680,6 +680,11 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
func_args = _get_defun_inputs(args)
func_kwds = _get_defun_inputs(kwds)
+ # Variables to help check whether mutation happens in calling the function
+ # Copy the recursive list, tuple and map structure, but not base objects
+ func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
+ func_kwds_before = nest.pack_sequence_as(func_kwds, nest.flatten(func_kwds))
+
def convert(x):
if x is None:
return None
@@ -691,6 +696,25 @@ def _trace_and_define_function(name, func, compiled, args, kwds):
try:
func_outputs = func(*func_args, **func_kwds)
func_outputs = nest.map_structure(convert, func_outputs)
+
+ def check_mutation(n1, n2):
+ """Check if two list of arguments are exactly the same."""
+ errmsg = ("Function to be traced should not modify structure of input "
+ "arguments. Check if your function has list and dictionary "
+ "operations that alter input arguments, "
+ "such as `list.pop`, `list.append`")
+ try:
+ nest.assert_same_structure(n1, n2)
+ except ValueError:
+ raise ValueError(errmsg)
+
+ for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
+ if arg1 is not arg2:
+ raise ValueError(errmsg)
+
+ check_mutation(func_args_before, func_args)
+ check_mutation(func_kwds_before, func_kwds)
+
finally:
tape.pop_tape(this_tape)
variables = this_tape.watched_variables()
@@ -894,8 +918,9 @@ def defun(func=None, compiled=False):
`defun`-generated graphs.
For a Python function to be compatible with `defun`, all of its arguments must
- be hashable Python objects or lists thereof. Additionally, it must return zero
- or more @{tf.Tensor} objects.
+ be hashable Python objects or lists thereof. The function itself may not
+ modify the list/map structure of its arguments. Additionally, it must return
+ zero or more @{tf.Tensor} objects.
Executing a graph generated by `defun` respects device annotations (i.e.,
all `with tf.device` directives present in a Python function will also be
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 2e86563a7d..afd4bbf4f3 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import iterator_ops
@@ -1212,6 +1213,174 @@ class AutomaticControlDependenciesTest(test.TestCase):
train()
self.assertEqual(v.numpy(), -1.0)
+ def testFunctionModifiesInputList(self):
+ # Tests on `list` methods that do in place modification, except `list.sort`
+ # since it cannot even be "defunned" in the first place
+
+ def get_list():
+ return [constant_op.constant(0.), constant_op.constant(1.)]
+
+ expected_msg = (
+ 'Function to be traced should not modify structure of input '
+ 'arguments. Check if your function has list and dictionary '
+ 'operations that alter input arguments, '
+ 'such as `list.pop`, `list.append`')
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def append(l):
+ l.append(constant_op.constant(0.))
+
+ append(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def extend(l):
+ l.extend([constant_op.constant(0.)])
+
+ extend(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def insert(l):
+ l.insert(0, constant_op.constant(0.))
+
+ insert(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def pop(l):
+ l.pop()
+
+ pop(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def reverse(l):
+ l.reverse()
+
+ reverse(get_list())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def remove(l):
+ l.remove(l[0])
+
+ remove(get_list())
+
+ # `list.clear` is a method that is in Py3 but not Py2
+ if sys.version.startswith('3'):
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def clear(l):
+ l.clear()
+
+ clear(get_list())
+
+ # One last test for keyword arguments
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def kwdappend(**kwargs):
+ l = kwargs['l']
+ l.append(constant_op.constant(0.))
+
+ kwdappend(l=get_list())
+
+ def testFunctionModifiesInputDict(self):
+
+ def get_dict():
+ return {'t1': constant_op.constant(0.), 't2': constant_op.constant(1.)}
+
+ expected_msg = (
+ 'Function to be traced should not modify structure of input '
+ 'arguments. Check if your function has list and dictionary '
+ 'operations that alter input arguments, '
+ 'such as `list.pop`, `list.append`')
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def clear(m):
+ m.clear()
+
+ clear(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def pop(m):
+ m.pop('t1')
+
+ pop(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def popitem(m):
+ m.popitem()
+
+ popitem(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def update(m):
+ m.update({'t1': constant_op.constant(3.)})
+
+ update(get_dict())
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def setdefault(m):
+ m.setdefault('t3', constant_op.constant(3.))
+
+ setdefault(get_dict())
+
+ def testFunctionModifiesInputNest(self):
+ # Test on functions that modify structure of nested input arguments
+ expected_msg = (
+ 'Function to be traced should not modify structure of input '
+ 'arguments. Check if your function has list and dictionary '
+ 'operations that alter input arguments, '
+ 'such as `list.pop`, `list.append`')
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ @function.defun
+ def modify(n):
+ n[0]['t1'].append(constant_op.constant(1.))
+
+ nested_input = [{
+ 't1': [constant_op.constant(0.),
+ constant_op.constant(1.)],
+ },
+ constant_op.constant(2.)]
+
+ modify(nested_input)
+
+ with self.assertRaisesRegexp(ValueError, expected_msg):
+
+ # The flat list doesn't change whereas the true structure changes
+ @function.defun
+ def modify_same_flat(n):
+ n[0].append(n[1].pop(0))
+
+ nested_input = [[constant_op.constant(0.)],
+ [constant_op.constant(1.),
+ constant_op.constant(2.)]]
+
+ modify_same_flat(nested_input)
+
if __name__ == '__main__':
ops.enable_eager_execution(