Diffstat (limited to 'tensorflow/python/ops/cond_v2_impl.py')
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py  148
1 file changed, 52 insertions(+), 96 deletions(-)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index b3dacff6d6..c4e9c982b5 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -27,14 +27,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
 from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python import pywrap_tensorflow as c_api
-from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.util import compat
 
 
 # The following modules cannot be imported directly because they cause circular
@@ -57,46 +56,27 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
     name = "cond"
 
   with ops.name_scope(name) as scope:
-    # Identify if there is a caller device, & get the innermost if possible.
-    # pylint: disable=protected-access
-    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
-    caller_device = device_funcs[-1] if device_funcs else None
-
-    caller_colocation_stack = ops.get_default_graph()._colocation_stack
-    caller_container = ops.get_default_graph()._container
-    caller_collection_ref = ops.get_default_graph()._collections
-
     with ops.name_scope(None):
       # Find the outer most graph for uniquing function names.
       # TODO(jpienaar): Make this work in eager mode.
       graph = ops.get_default_graph()
-      while isinstance(graph, _function._FuncGraph):
-        graph = graph._outer_graph
+      while isinstance(graph, _function.FuncGraph):
+        graph = graph.outer_graph
       true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
       false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))
-    # pylint: enable=protected-access
+
     true_graph = _function.func_graph_from_py_func(
-        true_fn, [], [],
-        name=true_name,
-        device=caller_device,
-        colocation_stack=caller_colocation_stack,
-        collections_ref=caller_collection_ref,
-        container=caller_container)
+        true_name, true_fn, [], {})
     false_graph = _function.func_graph_from_py_func(
-        false_fn, [], [],
-        name=false_name,
-        device=caller_device,
-        colocation_stack=caller_colocation_stack,
-        collections_ref=caller_collection_ref,
-        container=caller_container)
+        false_name, false_fn, [], {})
     _check_same_outputs(true_graph, false_graph)
 
     # Add inputs to true_graph and false_graph to make them match. Note that
     # this modifies true_graph and false_graph.
     cond_inputs = _make_inputs_match(true_graph, false_graph,
-                                     true_graph.extra_inputs,
-                                     false_graph.extra_inputs)
+                                     true_graph.external_captures,
+                                     false_graph.external_captures)
 
     # Add all intermediate tensors as function outputs so they're available for
     # the gradient computation.
@@ -148,8 +128,8 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
   true_graph, false_graph = _get_func_graphs(op)
   # Note: op.graph != ops.get_default_graph() when we are computing the gradient
   # of a nested cond.
-  assert true_graph._outer_graph == op.graph
-  assert false_graph._outer_graph == op.graph
+  assert true_graph.outer_graph == op.graph
+  assert false_graph.outer_graph == op.graph
 
   # Create grad functions that compute the gradient of the true/false forward
   # graphs. These functions will capture tensors from the forward pass
@@ -164,14 +144,13 @@ def _IfGrad(op, *grads):  # pylint: disable=invalid-name
 
   # Resolve references to forward graph tensors in grad graphs and ensure
   # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
-  true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
-  false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
+  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
+  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
 
   # Make the inputs to true_grad_graph and false_grad_graph match. Note that
   # this modifies true_grad_graph and false_grad_graph.
   grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
-                                   true_grad_extra_inputs,
-                                   false_grad_extra_inputs)
+                                   true_grad_inputs, false_grad_inputs)
 
   # Add all intermediate tensors as function outputs so they're available for
   # higher-order gradient computations.
@@ -211,8 +190,8 @@ def _get_func_graphs(if_op):
   """
   def _get_func_graph_for_branch(branch_name):
     """Generates and returns a _FuncGraph for the given branch."""
-    extra_inputs = if_op.inputs[1:]  # First input is pred.
-    input_shapes = [t.shape for t in extra_inputs]
+    inputs = if_op.inputs[1:]  # First input is pred.
+    input_shapes = [t.shape for t in inputs]
     func_name = if_op.get_attr(branch_name).name
     fdef = if_op.graph._get_function(func_name).definition
     # `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
@@ -224,9 +203,8 @@ def _get_func_graphs(if_op):
     with if_op.graph.as_default():
       func_graph = _function_def_to_graph.function_def_to_graph(
           fdef, input_shapes)
-    func_graph.extra_inputs = extra_inputs
-    func_graph.extra_args = func_graph.inputs
-    func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+    func_graph.captures = collections.OrderedDict(zip(inputs,
+                                                      func_graph.inputs))
     # Set the if op so that the gradient code can use it.
     func_graph._if = if_op
     return func_graph
@@ -282,12 +260,12 @@ def _grad_fn(func_graph, grads):
 
 def _create_grad_func(func_graph, grads, name):
   """Returns the _FuncGraph representation of _grad_fn."""
-  return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads),
-                                           [], [], name)
+  return _function.func_graph_from_py_func(
+      name, lambda: _grad_fn(func_graph, grads), [], {})
 
 
 def _resolve_grad_inputs(cond_graph, grad_graph):
-  """Returns the tensors to pass as `extra_inputs` to `grad_graph`.
+  """Returns the tensors to pass as inputs to `grad_graph`.
 
   The `grad_graph` may have external references to
   1. Its outer graph containing the input gradients. These references are kept
@@ -305,10 +283,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
   Returns:
     A list of inputs tensors to be passed to grad_graph.
   """
-  new_extra_inputs = []
+  new_inputs = []
 
-  for t in grad_graph.extra_inputs:
-    if t.graph != grad_graph._outer_graph:
+  for t in grad_graph.external_captures:
+    if t.graph != grad_graph.outer_graph:
       # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
       # tensor to the least common ancestor of the `cond_graph` and
       # `grad_graph` so that it is "in-scope" for `grad_graph`.
@@ -316,19 +294,19 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
      # common ancestor once and re-use.
       assert _is_ancestor(cond_graph, t.graph)
       while not _is_ancestor(grad_graph, t.graph):
-        assert isinstance(t.graph, _function._FuncGraph)
-        if t in t.graph.extra_args:
-          # TODO(srbs): Consider building a map of extra_args -> extra_inputs
-          # instead of searching for `t` twice.
-          t = t.graph.extra_inputs[t.graph.extra_args.index(t)]
+        assert isinstance(t.graph, _function.FuncGraph)
+        if t in t.graph.internal_captures:
+          # TODO(srbs): Consider building a map of internal_captures ->
+          # external_captures instead of searching for `t` twice.
+          t = t.graph.external_captures[t.graph.internal_captures.index(t)]
         else:
           # Note: All intermediate tensors are output by the If op.
           # TODO(srbs): .index() calls may be expensive. Optimize.
           t = t.graph._if.outputs[t.graph.outputs.index(t)]
       assert _is_ancestor(grad_graph, t.graph)
-      new_extra_inputs.append(t)
+      new_inputs.append(t)
 
-  return new_extra_inputs
+  return new_inputs
 
 
 def _create_new_tf_function(func_graph):
@@ -340,26 +318,9 @@ def _create_new_tf_function(func_graph):
   Returns:
     The name of the new TF_Function.
   """
-  c_func = c_api.TF_GraphToFunction_wrapper(
-      func_graph._c_graph,
-      compat.as_str(func_graph.name),
-      False,  # append_hash_to_fn_name
-      None,  # opers
-      [t._as_tf_output() for t in func_graph.inputs],
-      [t._as_tf_output() for t in func_graph.outputs],
-      [],
-      None,  # opts
-      None)  # description
-  _ = c_api_util.ScopedTFFunction(c_func)
-
-  # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
-  # deserializing it into a Python FunctionDef, then reserializing it to create
-  # a new TF_Function that we add to the graph.
-  fdef = _function.function_def_from_tf_function(c_func)
-  defined_func = _function._from_definition(fdef)
-  defined_func._sub_functions = func_graph._functions
-  defined_func.add_to_graph(func_graph._outer_graph)
-
+  func = _function._EagerDefinedFunction(
+      func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
+  func.add_to_graph(func_graph.outer_graph)
   return func_graph.name
 
 
@@ -421,21 +382,20 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
   return new_true_params, new_false_inputs
 
 
-def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
-                       false_extra_inputs):
+def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
   """Modifies true_graph and false_graph so they have the same input signature.
 
   This method reorders and/or adds parameters to true_graph and false_graph so
-  they have the same input signature, and updates the 'inputs', 'extra_inputs',
-  and '_captured' fields of both graphs accordingly. It uses the input tensors
-  from the outer graph to avoid duplicating shared arguments.
+  they have the same input signature, and updates the 'inputs' and 'captured'
+  fields of both graphs accordingly. It uses the input tensors from the outer
+  graph to avoid duplicating shared arguments.
 
   Args:
     true_graph: function._FuncGraph
     false_graph: function._FuncGraph
-    true_extra_inputs: a list of Tensors in the outer graph. The inputs for
+    true_inputs: a list of Tensors in the outer graph. The inputs for
       true_graph.
-    false_extra_inputs: a list of Tensors in the outer graph. The inputs for
+    false_inputs: a list of Tensors in the outer graph. The inputs for
       false_graph.
 
   Returns:
@@ -444,12 +404,12 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
     false_inputs.
   """
   shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
-      true_extra_inputs, false_extra_inputs)
+      true_inputs, false_inputs)
 
   new_inputs = shared_inputs + true_only_inputs + false_only_inputs
 
-  true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs))
-  false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs))
+  true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
+  false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
 
   true_graph.inputs = (
       [true_input_to_param[t] for t in shared_inputs] +
@@ -462,14 +422,10 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
       [false_input_to_param[t] for t in false_only_inputs])
 
   # Rewrite the _FuncGraphs' state to reflect the new inputs.
-  true_graph.extra_inputs = new_inputs
-  false_graph.extra_inputs = new_inputs
-
-  true_graph.extra_args = true_graph.inputs
-  false_graph.extra_args = false_graph.inputs
-
-  true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
-  false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
+  true_graph.captures = collections.OrderedDict(zip(new_inputs,
+                                                    true_graph.inputs))
+  false_graph.captures = collections.OrderedDict(zip(new_inputs,
+                                                     false_graph.inputs))
 
   return new_inputs
 
@@ -506,10 +462,10 @@ def _get_grad_fn_name(func_graph):
   counter = 1
   has_conflict = True
   while has_conflict:
-    curr_graph = func_graph._outer_graph
+    curr_graph = func_graph.outer_graph
     has_conflict = curr_graph._is_function(name)
-    while not has_conflict and isinstance(curr_graph, _function._FuncGraph):
-      curr_graph = curr_graph._outer_graph
+    while not has_conflict and isinstance(curr_graph, _function.FuncGraph):
+      curr_graph = curr_graph.outer_graph
       has_conflict = curr_graph._is_function(name)
     if has_conflict:
       name = "%s_%s" % (base_name, counter)
@@ -534,6 +490,6 @@ def _check_same_outputs(true_graph, false_graph):
 def _is_ancestor(graph, maybe_ancestor):
   if maybe_ancestor == graph:
     return True
-  if isinstance(graph, _function._FuncGraph):
-    return _is_ancestor(graph._outer_graph, maybe_ancestor)
+  if isinstance(graph, _function.FuncGraph):
+    return _is_ancestor(graph.outer_graph, maybe_ancestor)
   return False