Diffstat (limited to 'tensorflow/python/ops/cond_v2_impl.py')
 tensorflow/python/ops/cond_v2_impl.py | 148
 1 file changed, 52 insertions(+), 96 deletions(-)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index b3dacff6d6..c4e9c982b5 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -27,14 +27,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python import pywrap_tensorflow as c_api
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.util import compat
# The following modules cannot be imported directly because they cause circular
@@ -57,46 +56,27 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
name = "cond"
with ops.name_scope(name) as scope:
- # Identify if there is a caller device, & get the innermost if possible.
- # pylint: disable=protected-access
- device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
- caller_device = device_funcs[-1] if device_funcs else None
-
- caller_colocation_stack = ops.get_default_graph()._colocation_stack
- caller_container = ops.get_default_graph()._container
- caller_collection_ref = ops.get_default_graph()._collections
-
with ops.name_scope(None):
# Find the outermost graph for uniquing function names.
# TODO(jpienaar): Make this work in eager mode.
graph = ops.get_default_graph()
- while isinstance(graph, _function._FuncGraph):
- graph = graph._outer_graph
+ while isinstance(graph, _function.FuncGraph):
+ graph = graph.outer_graph
true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))
- # pylint: enable=protected-access
+
true_graph = _function.func_graph_from_py_func(
- true_fn, [], [],
- name=true_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ true_name, true_fn, [], {})
false_graph = _function.func_graph_from_py_func(
- false_fn, [], [],
- name=false_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ false_name, false_fn, [], {})
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that
# this modifies true_graph and false_graph.
cond_inputs = _make_inputs_match(true_graph, false_graph,
- true_graph.extra_inputs,
- false_graph.extra_inputs)
+ true_graph.external_captures,
+ false_graph.external_captures)
# Add all intermediate tensors as function outputs so they're available for
# the gradient computation.
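
As a plain-Python sketch of what this hunk arranges (no TensorFlow here; the function and tensor names are hypothetical): each branch is traced into its own FuncGraph, and the branches' external captures are later unioned by _make_inputs_match (see below) into one shared argument list for the If op.

    # Hypothetical stand-ins: strings play the role of outer-graph tensors.
    def union_captures(true_caps, false_caps):
        shared = [t for t in true_caps if t in false_caps]
        true_only = [t for t in true_caps if t not in false_caps]
        false_only = [t for t in false_caps if t not in true_caps]
        return shared + true_only + false_only

    assert union_captures(["x", "y"], ["y", "z"]) == ["y", "x", "z"]
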
@@ -148,8 +128,8 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
true_graph, false_graph = _get_func_graphs(op)
# Note: op.graph != ops.get_default_graph() when we are computing the gradient
# of a nested cond.
- assert true_graph._outer_graph == op.graph
- assert false_graph._outer_graph == op.graph
+ assert true_graph.outer_graph == op.graph
+ assert false_graph.outer_graph == op.graph
# Create grad functions that compute the gradient of the true/false forward
# graphs. These functions will capture tensors from the forward pass
@@ -164,14 +144,13 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Resolve references to forward graph tensors in grad graphs and ensure
# they are in-scope, i.e., belong to one of outer graphs of the grad graph.
- true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
- false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
+ true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
+ false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
# Make the inputs to true_grad_graph and false_grad_graph match. Note that
# this modifies true_grad_graph and false_grad_graph.
grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
- true_grad_extra_inputs,
- false_grad_extra_inputs)
+ true_grad_inputs, false_grad_inputs)
# Add all intermediate tensors as function outputs so they're available for
# higher-order gradient computations.
@@ -211,8 +190,8 @@ def _get_func_graphs(if_op):
"""
def _get_func_graph_for_branch(branch_name):
"""Generates and returns a _FuncGraph for the given branch."""
- extra_inputs = if_op.inputs[1:] # First input is pred.
- input_shapes = [t.shape for t in extra_inputs]
+ inputs = if_op.inputs[1:] # First input is pred.
+ input_shapes = [t.shape for t in inputs]
func_name = if_op.get_attr(branch_name).name
fdef = if_op.graph._get_function(func_name).definition
# `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
@@ -224,9 +203,8 @@ def _get_func_graphs(if_op):
with if_op.graph.as_default():
func_graph = _function_def_to_graph.function_def_to_graph(
fdef, input_shapes)
- func_graph.extra_inputs = extra_inputs
- func_graph.extra_args = func_graph.inputs
- func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ func_graph.captures = collections.OrderedDict(zip(inputs,
+ func_graph.inputs))
# Set the if op so that the gradient code can use it.
func_graph._if = if_op
return func_graph
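
The OrderedDict built here pairs each outer-graph tensor with the placeholder that represents it inside the branch function; in the new FuncGraph vocabulary the keys are the external captures and the values the internal ones. A tiny illustration with hypothetical names:

    import collections

    outer = ["outer/x:0", "outer/y:0"]          # tensors the branch captured
    inner = ["branch/arg0:0", "branch/arg1:0"]  # matching placeholders inside
    captures = collections.OrderedDict(zip(outer, inner))
    external_captures = list(captures.keys())    # fed by the If op
    internal_captures = list(captures.values())  # read by the branch body
    assert external_captures == outer and internal_captures == inner
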
@@ -282,12 +260,12 @@ def _grad_fn(func_graph, grads):
def _create_grad_func(func_graph, grads, name):
"""Returns the _FuncGraph representation of _grad_fn."""
- return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads),
- [], [], name)
+ return _function.func_graph_from_py_func(
+ name, lambda: _grad_fn(func_graph, grads), [], {})
def _resolve_grad_inputs(cond_graph, grad_graph):
- """Returns the tensors to pass as `extra_inputs` to `grad_graph`.
+ """Returns the tensors to pass as inputs to `grad_graph`.
The `grad_graph` may have external references to
1. Its outer graph containing the input gradients. These references are kept
@@ -305,10 +283,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
Returns:
A list of input tensors to be passed to grad_graph.
"""
- new_extra_inputs = []
+ new_inputs = []
- for t in grad_graph.extra_inputs:
- if t.graph != grad_graph._outer_graph:
+ for t in grad_graph.external_captures:
+ if t.graph != grad_graph.outer_graph:
# `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
# tensor to the least common ancestor of the `cond_graph` and
# `grad_graph` so that it is "in-scope" for `grad_graph`.
@@ -316,19 +294,19 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
# common ancestor once and re-use.
assert _is_ancestor(cond_graph, t.graph)
while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function._FuncGraph)
- if t in t.graph.extra_args:
- # TODO(srbs): Consider building a map of extra_args -> extra_inputs.
- # instead of searching for `t` twice.
- t = t.graph.extra_inputs[t.graph.extra_args.index(t)]
+ assert isinstance(t.graph, _function.FuncGraph)
+ if t in t.graph.internal_captures:
+ # TODO(srbs): Consider building a map of internal_captures ->
+ # external_captures instead of searching for `t` twice.
+ t = t.graph.external_captures[t.graph.internal_captures.index(t)]
else:
# Note: All intermediate tensors are output by the If op.
# TODO(srbs): .index() calls may be expensive. Optimize.
t = t.graph._if.outputs[t.graph.outputs.index(t)]
assert _is_ancestor(grad_graph, t.graph)
- new_extra_inputs.append(t)
+ new_inputs.append(t)
- return new_extra_inputs
+ return new_inputs
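
The while loop above walks a captured tensor outward one FuncGraph at a time, translating an internal capture to its external counterpart (or to the matching If-op output) until the tensor is in scope for the gradient graph. Stripped of TensorFlow, the shape of the loop is:

    # Plain-Python sketch; internal_to_external stands in for the per-graph
    # internal_captures -> external_captures pairing, and in_scope for the
    # _is_ancestor check. All names are hypothetical.
    def bubble_up(t, internal_to_external, in_scope):
        while not in_scope(t):
            t = internal_to_external[t]   # hop one graph level outward
        return t

    mapping = {"inner:0": "mid:0", "mid:0": "outer:0"}
    assert bubble_up("inner:0", mapping, lambda t: t == "outer:0") == "outer:0"
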
def _create_new_tf_function(func_graph):
@@ -340,26 +318,9 @@ def _create_new_tf_function(func_graph):
Returns:
The name of the new TF_Function.
"""
- c_func = c_api.TF_GraphToFunction_wrapper(
- func_graph._c_graph,
- compat.as_str(func_graph.name),
- False, # append_hash_to_fn_name
- None, # opers
- [t._as_tf_output() for t in func_graph.inputs],
- [t._as_tf_output() for t in func_graph.outputs],
- [],
- None, # opts
- None) # description
- _ = c_api_util.ScopedTFFunction(c_func)
-
- # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
- # deserializing it into a Python FunctionDef, then reserializing it to create
- # a new TF_Function that we add to the graph.
- fdef = _function.function_def_from_tf_function(c_func)
- defined_func = _function._from_definition(fdef)
- defined_func._sub_functions = func_graph._functions
- defined_func.add_to_graph(func_graph._outer_graph)
-
+ func = _function._EagerDefinedFunction(
+ func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
+ func.add_to_graph(func_graph.outer_graph)
return func_graph.name
@@ -421,21 +382,20 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
return new_true_params, new_false_inputs
-def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
- false_extra_inputs):
+def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
"""Modifies true_graph and false_graph so they have the same input signature.
This method reorders and/or adds parameters to true_graph and false_graph so
- they have the same input signature, and updates the 'inputs', 'extra_inputs',
- and '_captured' fields of both graphs accordingly. It uses the input tensors
- from the outer graph to avoid duplicating shared arguments.
+ they have the same input signature, and updates the 'inputs' and 'captured'
+ fields of both graphs accordingly. It uses the input tensors from the outer
+ graph to avoid duplicating shared arguments.
Args:
true_graph: function._FuncGraph
false_graph: function._FuncGraph
- true_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ true_inputs: a list of Tensors in the outer graph. The inputs for
true_graph.
- false_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ false_inputs: a list of Tensors in the outer graph. The inputs for
false_graph.
Returns:
@@ -444,12 +404,12 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
false_inputs.
"""
shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
- true_extra_inputs, false_extra_inputs)
+ true_inputs, false_inputs)
new_inputs = shared_inputs + true_only_inputs + false_only_inputs
- true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs))
- false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs))
+ true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
+ false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
true_graph.inputs = (
[true_input_to_param[t] for t in shared_inputs] +
@@ -462,14 +422,10 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
[false_input_to_param[t] for t in false_only_inputs])
# Rewrite the _FuncGraphs' state to reflect the new inputs.
- true_graph.extra_inputs = new_inputs
- false_graph.extra_inputs = new_inputs
-
- true_graph.extra_args = true_graph.inputs
- false_graph.extra_args = false_graph.inputs
-
- true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
- false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
+ true_graph.captures = collections.OrderedDict(zip(new_inputs,
+ true_graph.inputs))
+ false_graph.captures = collections.OrderedDict(zip(new_inputs,
+ false_graph.inputs))
return new_inputs
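
A runnable miniature of the reordering above, assuming the union from _separate_unique_inputs is already computed (plain Python; the data and the "pad" placeholder name are hypothetical, mirroring the dicts built in this hunk):

    import collections

    # Hypothetical data: outer tensors and the true branch's placeholders,
    # including a padded placeholder for the false-only capture "z".
    new_inputs = ["y", "x", "z"]   # shared + true-only + false-only
    true_input_to_param = {"y": "true/arg0", "x": "true/arg1", "z": "true/pad_z"}

    # Reorder the branch's parameter list to the shared signature, then rebuild
    # the captures map so iteration order matches the If op's argument order.
    true_params = [true_input_to_param[t] for t in new_inputs]
    captures = collections.OrderedDict(zip(new_inputs, true_params))
    assert list(captures) == new_inputs
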
@@ -506,10 +462,10 @@ def _get_grad_fn_name(func_graph):
counter = 1
has_conflict = True
while has_conflict:
- curr_graph = func_graph._outer_graph
+ curr_graph = func_graph.outer_graph
has_conflict = curr_graph._is_function(name)
- while not has_conflict and isinstance(curr_graph, _function._FuncGraph):
- curr_graph = curr_graph._outer_graph
+ while not has_conflict and isinstance(curr_graph, _function.FuncGraph):
+ curr_graph = curr_graph.outer_graph
has_conflict = curr_graph._is_function(name)
if has_conflict:
name = "%s_%s" % (base_name, counter)
@@ -534,6 +490,6 @@ def _check_same_outputs(true_graph, false_graph):
def _is_ancestor(graph, maybe_ancestor):
if maybe_ancestor == graph:
return True
- if isinstance(graph, _function._FuncGraph):
- return _is_ancestor(graph._outer_graph, maybe_ancestor)
+ if isinstance(graph, _function.FuncGraph):
+ return _is_ancestor(graph.outer_graph, maybe_ancestor)
return False