diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-09-14 16:43:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-14 16:47:28 -0700 |
commit | e179c17b96bcb855b2056f60851a24551b4189a6 (patch) | |
tree | 3a7c01d701cad0d2dbe4c11916fc0492aa35e66a /tensorflow/python/eager | |
parent | 1c2a300d483d9e5d5502cdd8131644f7647996c5 (diff) |
Makes tf.Variable arguments (non-captured) DT_RESOURCE function inputs.
Previously, tf.Variable arguments to a defun-d Python function were made captured inputs. This change makes it possible to parameterize functions on DT_RESOURCE inputs.
PiperOrigin-RevId: 213064739
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 60 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 37 |
2 files changed, 88 insertions, 9 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 962e334b27..f3fb48fd3b 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -65,7 +65,7 @@ gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-ac WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_" -def _create_substitute_placeholder(value, name, dtype=None): +def _create_substitute_placeholder(value, name=None, dtype=None): """Creates a placeholder for `value` and propagates shape info to it.""" # Note: setting ops.control_dependencies(None) ensures we always put # capturing placeholders outside of any control flow context. @@ -550,7 +550,19 @@ class Function(object): self._distributed_variables[component_variable.handle] = variable def __call__(self, *args): - """Executes the wrapped function.""" + """Executes the wrapped function. + + Args: + *args: a list of Tensors or Variables. + + Returns: + The result of applying the TF function to `args`. + + Raises: + ValueError: If the current device stack does not match the device stack + under which the function was created, or if `args` contains anything + other than Tensors or Variables. + """ ctx = context.context() device_functions = _get_device_functions(ctx, ops.get_default_graph()) if device_functions != self._device_functions: @@ -566,7 +578,18 @@ class Function(object): tape.variable_accessed(v) captures = self._resolve_captured_inputs() - tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] + tensor_inputs = [] + for i, arg in enumerate(nest.flatten(args)): + if isinstance(arg, resource_variable_ops.ResourceVariable): + if arg.trainable: + tape.variable_accessed(arg) + tensor_inputs.append(arg.handle) + elif isinstance(arg, ops.Tensor): + tensor_inputs.append(arg) + else: + raise ValueError("All inputs to `Function`s must be Tensors; " + "on invocation of %s, the %d-th input (%s) was not a " + "Tensor." % (self._func_graph.name, i, str(arg))) args = tensor_inputs + captures if tape.should_record(tensor_inputs) or tape.should_record(captures): @@ -817,10 +840,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): func_kwds = {} # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`. - func_graph.inputs.extend( - x for x in nest.flatten(func_args) + nest.flatten(func_kwds) - if isinstance(x, ops.Tensor)) - # Variables to help check whether mutation happens in calling the function # Copy the recursive list, tuple and map structure, but not base objects func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args)) @@ -867,6 +886,26 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): finally: tape.pop_tape(this_tape) + # Variables in `func_args`, `func_kwds` should be explicit inputs + # to the function, not captured inputs. + variables = set(this_tape.watched_variables()) + inputs = [] + for arg in nest.flatten(func_args) + nest.flatten(func_kwds): + if isinstance(arg, resource_variable_ops.ResourceVariable): + try: + resource_placeholder = func_graph.captures.pop(arg.handle) + variables.remove(arg) + except KeyError: + # This case occurs if a Variable among the inputs is not actually + # used by the function; we still add an explicit input for it + # because the user should presumably pass the Variable as an input + # to the corresponding graph function. + resource_placeholder = _create_substitute_placeholder(arg.handle) + inputs.append(resource_placeholder) + elif isinstance(arg, ops.Tensor): + inputs.append(arg) + func_graph.inputs = inputs + list(func_graph.captures.values()) + func_graph.structured_outputs = func_outputs # Returning a closed-over tensor does not trigger convert_to_tensor. func_graph.outputs.extend( @@ -878,7 +917,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None): # Instead of storing non-distributed component variables, we # store their distributed containers so we can retrieve the correct # component variables at call-time. - variables = list(this_tape.watched_variables()) + variables = list(variables) strategy = distribution_strategy_context.get_distribution_strategy() for i, variable in enumerate(variables): # If variable is not distributed value_container returns itself. @@ -1201,7 +1240,10 @@ class PolymorphicFunction(object): self._variables.extend( [v for v in graph_function.variables if v not in self._variables]) self._function_cache[cache_key] = graph_function - return graph_function, (args, kwds) + return graph_function, [ + t for t in nest.flatten((args, kwds)) + if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable)) + ] def register(func, *args, **kwargs): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index a0abefe666..c168b6060c 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1685,6 +1685,43 @@ class FunctionTest(test.TestCase): # pylint: disable=protected-access self.assertEqual(len(graph._functions), 1) + def testCallingFunctionWithDifferentVariables(self): + + @function.defun + def foo(v): + v.assign_add(1.0) + return v.read_value() + + v = resource_variable_ops.ResourceVariable(0.0) + graph_function = foo.get_concrete_function(v) + self.assertEqual(len(graph_function.inputs), 1) + self.assertEqual(len(graph_function.captured_inputs), 0) + + self.assertEqual(float(graph_function(v)), 1.0) + self.assertEqual(float(graph_function(v)), 2.0) + + w = resource_variable_ops.ResourceVariable(0.0) + + @function.defun + def bar(v): + del v + return constant_op.constant(1.0) + + graph_function = bar.get_concrete_function(v) + self.assertEqual(float(graph_function(v)), 1.0) + self.assertEqual(float(graph_function(w)), 1.0) + + def testCallingFunctionWithNonTensorsFails(self): + + @function.defun + def foo(x): + return x + + graph_function = foo.get_concrete_function(constant_op.constant(1.0)) + with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must ' + 'be Tensors;.*'): + graph_function('Not a Tensor.') + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): |