Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py  53
1 file changed, 4 insertions, 49 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ff138cad1e..f1a63adce1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -51,7 +51,6 @@ from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
@@ -202,6 +201,7 @@ class FuncGraph(ops.Graph):
     # from the default graph even in eager mode. Maybe it should be part of the
     # eager context?
     self._distribution_strategy_stack = graph._distribution_strategy_stack
+    self._variable_creator_stack = graph._variable_creator_stack
     # Inherit the graph key, since this is used for matching variables in
     # optimizers.
     self._graph_key = graph._graph_key
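
The added line above makes a FuncGraph inherit the outer graph's variable creator stack, so that creators installed with tf.variable_creator_scope in the outer graph also apply while a function is being traced. As a point of reference, a minimal sketch of a creator scope on its own; logging_creator and created_names are illustrative names, not part of this change:

    import tensorflow as tf

    created_names = []

    def logging_creator(next_creator, **kwargs):
      # Record each variable creation routed through this creator, then
      # delegate to the default creator unchanged.
      created_names.append(kwargs.get("name"))
      return next_creator(**kwargs)

    with tf.variable_creator_scope(logging_creator):
      v = tf.Variable(1.0, name="v")

    print(created_names)  # expected: ["v"]

The intent of the inherited stack is that tracing a function defined inside such a scope routes its internal variable creations through the same creator.
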
@@ -563,17 +563,6 @@ class Function(object):
         self._func_graph.inputs, self._func_graph.outputs, self._attrs)
     self._backward_graph_function = None
-    # Map holding distributed variables, keyed by resource handle tensors.
-    self._distributed_variables = {}
-    strategy = distribution_strategy_context.get_distribution_strategy()
-    for variable in self._func_graph.variables:
-      # If variable is not distributed, unwrap returns [variable].
-      component_variables = strategy.unwrap(variable)
-      # Only update the dictionary when the variable is actually distributed.
-      if (len(component_variables) > 1 or component_variables[0] != variable):
-        for component_variable in component_variables:
-          self._distributed_variables[component_variable.handle] = variable
-
   def __call__(self, *args):
     """Executes the wrapped function.
@@ -602,7 +591,6 @@ class Function(object):
       if v.trainable:
         tape.variable_accessed(v)
-    captures = self._resolve_captured_inputs()
     tensor_inputs = []
     for i, arg in enumerate(nest.flatten(args)):
       if isinstance(arg, resource_variable_ops.ResourceVariable):
@@ -615,9 +603,10 @@ class Function(object):
         raise ValueError("All inputs to `Function`s must be Tensors; "
                          "on invocation of %s, the %d-th input (%s) was not a "
                          "Tensor." % (self._func_graph.name, i, str(arg)))
-    args = tensor_inputs + captures
+    args = tensor_inputs + self._captured_inputs
-    if tape.should_record(tensor_inputs) or tape.should_record(captures):
+    if (tape.should_record(tensor_inputs) or
+        tape.should_record(self._captured_inputs)):
       return self._backprop_call(args)
     # Only need to override the gradient in graph mode and when we have outputs.
@@ -804,32 +793,6 @@ class Function(object):
         args, backward_function)
     return self._build_call_outputs(real_outputs)
-  def _resolve_captured_inputs(self):
-    """Resolve captured distributed variables to their current values.
-
-    Some inputs can be distributed variables. Such variables yield a different
-    component variable (i.e. an actual tf.Variable) depending on the context of
-    execution.
-
-    Returns:
-      a list of resolved captured input tensors.
-    """
-    if self._distributed_variables:
-      # Loop over each captured input and check if it corresponds to something
-      # distributed. If so, get its _distributed_container and fetch the
-      # component appropriate for the current execution context.
-      resolved_captured_inputs = self._captured_inputs[:]
-      for i, captured_input in enumerate(self._captured_inputs):
-        distributed_var = self._distributed_variables.get(captured_input, None)
-        if distributed_var is not None:
-          # distributed variables override __getattr__ and substitute the
-          # right component variable. In here, `distributed_var.handle`
-          # actually does the equivalent of
-          # distributed_var.get_current_component_var().handle.
-          resolved_captured_inputs[i] = distributed_var.handle
-      return resolved_captured_inputs
-    return self._captured_inputs
-
   def _build_call_outputs(self, result):
     """Maps the fdef output list to actual output structure.
@@ -1010,14 +973,6 @@ def func_graph_from_py_func(name,
       for x in _flatten(func_graph.structured_outputs)
       if x is not None)
-  # Some captured variables might be components of DistributedValues.
-  # Instead of storing non-distributed component variables, we
-  # store their distributed containers so we can retrieve the correct
-  # component variables at call-time.
-  strategy = distribution_strategy_context.get_distribution_strategy()
-  for i, variable in enumerate(variables):
-    # If variable is not distributed, value_container returns itself.
-    variables[i] = strategy.value_container(variable)
   func_graph.variables = variables
   # Register any other functions defined in the graph.
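
Both deleted blocks lean on two contrib-era DistributionStrategy methods that appear in the removed code: unwrap, which expands a distributed value into its per-device components and returns [variable] for an ordinary variable, and value_container, which maps a component back to its distributed container and returns the variable itself when it is not distributed. A sketch of the handle-to-container map the old constructor built from them; strategy and variables are assumed to be supplied by the caller:

    def build_distributed_variable_map(strategy, variables):
      # Mirror of the deleted constructor logic: map each component variable's
      # resource handle to its distributed container, but only for variables
      # the strategy actually distributes.
      handle_to_container = {}
      for variable in variables:
        components = strategy.unwrap(variable)  # [variable] when not distributed
        if len(components) > 1 or components[0] != variable:
          for component in components:
            handle_to_container[component.handle] = variable
      return handle_to_container
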