author    Akshay Agrawal <akshayka@google.com>  2018-09-14 16:43:25 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-14 16:47:28 -0700
commit    e179c17b96bcb855b2056f60851a24551b4189a6 (patch)
tree      3a7c01d701cad0d2dbe4c11916fc0492aa35e66a /tensorflow/python/eager
parent    1c2a300d483d9e5d5502cdd8131644f7647996c5 (diff)
Makes tf.Variable arguments (non-captured) DT_RESOURCE function inputs.
Previously, tf.Variable arguments to a defun-d Python function were made captured inputs. This change makes it possible to parameterize functions on DT_RESOURCE inputs.

PiperOrigin-RevId: 213064739
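To make the effect concrete, here is a minimal sketch of the behavior this change enables, written against the eager/defun API of this era; the entry points used (tf.enable_eager_execution, tf.contrib.eager.defun, tf.contrib.eager.Variable) are assumptions about the surrounding release, not part of this commit:

# Sketch only (assumes TF ~1.11 with eager execution): a tf.Variable passed
# as an argument now becomes a DT_RESOURCE input of the traced function
# rather than a captured input, so one traced function can be parameterized
# on different variables.
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()

@tfe.defun
def scale(v, factor):
  return v.read_value() * factor

a = tfe.Variable(2.0)
b = tfe.Variable(3.0)
print(scale(a, tf.constant(10.0)))  # 20.0
print(scale(b, tf.constant(10.0)))  # 30.0 -- same function, different variable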
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- tensorflow/python/eager/function.py      | 60
-rw-r--r-- tensorflow/python/eager/function_test.py | 37
2 files changed, 88 insertions(+), 9 deletions(-)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 962e334b27..f3fb48fd3b 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -65,7 +65,7 @@ gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-ac
 WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
 
 
-def _create_substitute_placeholder(value, name, dtype=None):
+def _create_substitute_placeholder(value, name=None, dtype=None):
   """Creates a placeholder for `value` and propagates shape info to it."""
   # Note: setting ops.control_dependencies(None) ensures we always put
   # capturing placeholders outside of any control flow context.
@@ -550,7 +550,19 @@ class Function(object):
           self._distributed_variables[component_variable.handle] = variable
 
   def __call__(self, *args):
-    """Executes the wrapped function."""
+    """Executes the wrapped function.
+
+    Args:
+      *args: a list of Tensors or Variables.
+
+    Returns:
+      The result of applying the TF function to `args`.
+
+    Raises:
+      ValueError: If the current device stack does not match the device stack
+        under which the function was created, or if `args` contains anything
+        other than Tensors or Variables.
+    """
     ctx = context.context()
     device_functions = _get_device_functions(ctx, ops.get_default_graph())
     if device_functions != self._device_functions:
@@ -566,7 +578,18 @@ class Function(object):
         tape.variable_accessed(v)
 
     captures = self._resolve_captured_inputs()
-    tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+    tensor_inputs = []
+    for i, arg in enumerate(nest.flatten(args)):
+      if isinstance(arg, resource_variable_ops.ResourceVariable):
+        if arg.trainable:
+          tape.variable_accessed(arg)
+        tensor_inputs.append(arg.handle)
+      elif isinstance(arg, ops.Tensor):
+        tensor_inputs.append(arg)
+      else:
+        raise ValueError("All inputs to `Function`s must be Tensors; "
+                         "on invocation of %s, the %d-th input (%s) was not a "
+                         "Tensor." % (self._func_graph.name, i, str(arg)))
 
     args = tensor_inputs + captures
     if tape.should_record(tensor_inputs) or tape.should_record(captures):
@@ -817,10 +840,6 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
     func_kwds = {}
 
     # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
-    func_graph.inputs.extend(
-        x for x in nest.flatten(func_args) + nest.flatten(func_kwds)
-        if isinstance(x, ops.Tensor))
-
     # Variables to help check whether mutation happens in calling the function
     # Copy the recursive list, tuple and map structure, but not base objects
     func_args_before = nest.pack_sequence_as(func_args, nest.flatten(func_args))
@@ -867,6 +886,26 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
     finally:
       tape.pop_tape(this_tape)
 
+    # Variables in `func_args`, `func_kwds` should be explicit inputs
+    # to the function, not captured inputs.
+    variables = set(this_tape.watched_variables())
+    inputs = []
+    for arg in nest.flatten(func_args) + nest.flatten(func_kwds):
+      if isinstance(arg, resource_variable_ops.ResourceVariable):
+        try:
+          resource_placeholder = func_graph.captures.pop(arg.handle)
+          variables.remove(arg)
+        except KeyError:
+          # This case occurs if a Variable among the inputs is not actually
+          # used by the function; we still add an explicit input for it
+          # because the user should presumably pass the Variable as an input
+          # to the corresponding graph function.
+          resource_placeholder = _create_substitute_placeholder(arg.handle)
+        inputs.append(resource_placeholder)
+      elif isinstance(arg, ops.Tensor):
+        inputs.append(arg)
+    func_graph.inputs = inputs + list(func_graph.captures.values())
+
     func_graph.structured_outputs = func_outputs
     # Returning a closed-over tensor does not trigger convert_to_tensor.
     func_graph.outputs.extend(
@@ -878,7 +917,7 @@ def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
     # Instead of storing non-distributed component variables, we
     # store their distributed containers so we can retrieve the correct
    # component variables at call-time.
-    variables = list(this_tape.watched_variables())
+    variables = list(variables)
     strategy = distribution_strategy_context.get_distribution_strategy()
     for i, variable in enumerate(variables):
       # If variable is not distributed value_container returns itself.
@@ -1201,7 +1240,10 @@ class PolymorphicFunction(object):
       self._variables.extend(
          [v for v in graph_function.variables if v not in self._variables])
       self._function_cache[cache_key] = graph_function
-    return graph_function, (args, kwds)
+    return graph_function, [
+        t for t in nest.flatten((args, kwds))
+        if isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
+    ]
 
 
 def register(func, *args, **kwargs):
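Net effect on the tracing side: a concrete function traced with a variable argument now exposes one DT_RESOURCE input and no captured inputs, so it can later be fed a different variable. A hedged sketch, using the in-tree modules the same way the tests below do (assumes eager execution is enabled, e.g. via tf.enable_eager_execution()):

# Sketch mirroring the tests that follow; assumes the tensorflow/python
# modules importable inside the source tree at this commit.
from tensorflow.python.eager import function
from tensorflow.python.ops import resource_variable_ops

@function.defun
def read(v):
  return v.read_value()

v = resource_variable_ops.ResourceVariable(1.0)
w = resource_variable_ops.ResourceVariable(5.0)
concrete = read.get_concrete_function(v)
assert len(concrete.inputs) == 1           # the DT_RESOURCE placeholder
assert len(concrete.captured_inputs) == 0  # variable is no longer captured
print(float(concrete(v)), float(concrete(w)))  # 1.0 5.0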
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index a0abefe666..c168b6060c 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1685,6 +1685,43 @@ class FunctionTest(test.TestCase):
     # pylint: disable=protected-access
     self.assertEqual(len(graph._functions), 1)
 
+  def testCallingFunctionWithDifferentVariables(self):
+
+    @function.defun
+    def foo(v):
+      v.assign_add(1.0)
+      return v.read_value()
+
+    v = resource_variable_ops.ResourceVariable(0.0)
+    graph_function = foo.get_concrete_function(v)
+    self.assertEqual(len(graph_function.inputs), 1)
+    self.assertEqual(len(graph_function.captured_inputs), 0)
+
+    self.assertEqual(float(graph_function(v)), 1.0)
+    self.assertEqual(float(graph_function(v)), 2.0)
+
+    w = resource_variable_ops.ResourceVariable(0.0)
+
+    @function.defun
+    def bar(v):
+      del v
+      return constant_op.constant(1.0)
+
+    graph_function = bar.get_concrete_function(v)
+    self.assertEqual(float(graph_function(v)), 1.0)
+    self.assertEqual(float(graph_function(w)), 1.0)
+
+  def testCallingFunctionWithNonTensorsFails(self):
+
+    @function.defun
+    def foo(x):
+      return x
+
+    graph_function = foo.get_concrete_function(constant_op.constant(1.0))
+    with self.assertRaisesRegexp(ValueError, 'All inputs to `Function`s must '
+                                 'be Tensors;.*'):
+      graph_function('Not a Tensor.')
+
 
 @test_util.with_c_shapes
 class AutomaticControlDependenciesTest(test.TestCase):
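For completeness, a hedged sketch of how the new rejection path surfaces to a caller, mirroring testCallingFunctionWithNonTensorsFails above (same assumed in-tree imports; eager execution enabled):

from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op

@function.defun
def identity(x):
  return x

concrete = identity.get_concrete_function(constant_op.constant(1.0))
try:
  concrete('Not a Tensor.')  # flattened input is neither Tensor nor Variable
except ValueError as e:
  print(e)  # "All inputs to `Function`s must be Tensors; ..."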