Diffstat (limited to 'tensorflow/python/eager/function.py')
-rw-r--r--  tensorflow/python/eager/function.py  |  53
1 file changed, 4 insertions, 49 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index ff138cad1e..f1a63adce1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -51,7 +51,6 @@ from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_decorator
@@ -202,6 +201,7 @@ class FuncGraph(ops.Graph):
     # from the default graph even in eager mode. Maybe it should be part of the
     # eager context?
     self._distribution_strategy_stack = graph._distribution_strategy_stack
+    self._variable_creator_stack = graph._variable_creator_stack
     # Inherit the graph key, since this is used for matching variables in
     # optimizers.
     self._graph_key = graph._graph_key
@@ -563,17 +563,6 @@ class Function(object):
         self._func_graph.inputs, self._func_graph.outputs, self._attrs)
     self._backward_graph_function = None
 
-    # Map holding distributed variables, keyed by resource handle tensors.
-    self._distributed_variables = {}
-    strategy = distribution_strategy_context.get_distribution_strategy()
-    for variable in self._func_graph.variables:
-      # If variable is not distributed, unwrap returns [variable].
-      component_variables = strategy.unwrap(variable)
-      # Only update the dictionary when the variable is actually distributed.
-      if (len(component_variables) > 1 or component_variables[0] != variable):
-        for component_variable in component_variables:
-          self._distributed_variables[component_variable.handle] = variable
-
   def __call__(self, *args):
     """Executes the wrapped function.
 
@@ -602,7 +591,6 @@ class Function(object):
      if v.trainable:
        tape.variable_accessed(v)
 
-    captures = self._resolve_captured_inputs()
     tensor_inputs = []
     for i, arg in enumerate(nest.flatten(args)):
       if isinstance(arg, resource_variable_ops.ResourceVariable):
@@ -615,9 +603,10 @@
        raise ValueError("All inputs to `Function`s must be Tensors; "
                         "on invocation of %s, the %d-th input (%s) was not a "
                         "Tensor." % (self._func_graph.name, i, str(arg)))
-    args = tensor_inputs + captures
+    args = tensor_inputs + self._captured_inputs
 
-    if tape.should_record(tensor_inputs) or tape.should_record(captures):
+    if (tape.should_record(tensor_inputs) or
+        tape.should_record(self._captured_inputs)):
      return self._backprop_call(args)
 
     # Only need to override the gradient in graph mode and when we have outputs.
@@ -804,32 +793,6 @@ class Function(object):
        args, backward_function)
     return self._build_call_outputs(real_outputs)
 
-  def _resolve_captured_inputs(self):
-    """Resolve captured distributed variables to their current values.
-
-    Some inputs can be distributed variables. Such variables yield a different
-    component (i.e. actual tf.Variable) variables depending on the context of
-    execution.
-
-    Returns:
-      a list of resolved captured input tensors.
-    """
-    if self._distributed_variables:
-      # Loop over each captured input and check if it corresponds to something
-      # distributed. If so, get its _distributed_container and fetch the
-      # component appropriate for the current execution context.
-      resolved_captured_inputs = self._captured_inputs[:]
-      for i, captured_input in enumerate(self._captured_inputs):
-        distributed_var = self._distributed_variables.get(captured_input, None)
-        if distributed_var is not None:
-          # distributed variables override __getattr__ and substitute the
-          # right component variable. In here, `distributed_var.handle`
-          # actually does the equivalent of
-          # distributed_var.get_current_component_var().handle.
-          resolved_captured_inputs[i] = distributed_var.handle
-      return resolved_captured_inputs
-    return self._captured_inputs
-
   def _build_call_outputs(self, result):
     """Maps the fdef output list to actual output structure.
 
@@ -1010,14 +973,6 @@ def func_graph_from_py_func(name,
                     for x in _flatten(func_graph.structured_outputs)
                     if x is not None)
 
-  # Some captured variables might be components of DistributedValues.
-  # Instead of storing non-distributed component variables, we
-  # store their distributed containers so we can retrieve the correct
-  # component variables at call-time.
-  strategy = distribution_strategy_context.get_distribution_strategy()
-  for i, variable in enumerate(variables):
-    # If variable is not distributed value_container returns itself.
-    variables[i] = strategy.value_container(variable)
   func_graph.variables = variables
 
   # Register any other functions defined in the graph.
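
For context on the `_variable_creator_stack` line added to `FuncGraph.__init__`: copying the outer graph's creator stack means that variable-creator scopes active when the function is built (for instance, one installed by a distribution strategy) also apply to variables created inside the function body. Below is a minimal, self-contained sketch of that idea; `Graph`, `FuncGraph`, `variable_creator_scope`, `create_variable` and `tag_as_mirrored` are simplified stand-ins, not the actual TensorFlow implementations.

import contextlib


class Graph(object):
  def __init__(self):
    # Innermost creator last, mirroring a stack of nested scopes.
    self._variable_creator_stack = []

  @contextlib.contextmanager
  def variable_creator_scope(self, creator):
    self._variable_creator_stack.append(creator)
    try:
      yield
    finally:
      self._variable_creator_stack.pop()

  def create_variable(self, name):
    value = {"name": name}
    # Apply creators outermost-first so each can wrap or replace the variable.
    for creator in self._variable_creator_stack:
      value = creator(value)
    return value


class FuncGraph(Graph):
  def __init__(self, outer_graph):
    super(FuncGraph, self).__init__()
    # The one-line change in the diff above: inherit the outer graph's creator
    # stack so creators that were active at tracing time still wrap variables
    # created inside the function body.
    self._variable_creator_stack = list(outer_graph._variable_creator_stack)


def tag_as_mirrored(var):
  var["mirrored"] = True
  return var


outer = Graph()
with outer.variable_creator_scope(tag_as_mirrored):
  func_graph = FuncGraph(outer)
  v = func_graph.create_variable("kernel")
print(v)  # {'name': 'kernel', 'mirrored': True}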
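
The bulk of the deletions removes the call-time remapping of captured distributed variables: `__call__` now passes `self._captured_inputs` straight through instead of calling `_resolve_captured_inputs()`. The sketch below approximates the behaviour that was removed, using hypothetical stand-ins (`DistributedVariable`, string handles, an explicit `device` argument) rather than TensorFlow APIs: each captured handle belonging to a distributed variable used to be swapped for the handle of the component variable on the current device.

class DistributedVariable(object):
  """Stand-in container holding one component variable per device."""

  def __init__(self, components):
    self._components = components  # e.g. {"gpu:0": "kernel_0", "gpu:1": "kernel_1"}

  def handle(self, device):
    # Stand-in for `distributed_var.handle`, which resolved to the component
    # variable of the current execution context.
    return self._components[device]


def resolve_captured_inputs(captured_inputs, distributed_variables, device):
  """Old behaviour: substitute handles of distributed variables at call time."""
  resolved = list(captured_inputs)
  for i, captured in enumerate(captured_inputs):
    distributed_var = distributed_variables.get(captured)
    if distributed_var is not None:
      resolved[i] = distributed_var.handle(device)
  return resolved


mirrored = DistributedVariable({"gpu:0": "kernel_0", "gpu:1": "kernel_1"})
captured = ["kernel_0", "bias"]   # handles captured while tracing
table = {"kernel_0": mirrored}    # keyed by component handle, as in the old map
print(resolve_captured_inputs(captured, table, "gpu:1"))
# ['kernel_1', 'bias'] -- the gpu:1 component was substituted at call time

After this change no such substitution happens when the function is called; the handles captured at tracing time are used as-is.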