diff options
author | Alexandre Passos <apassos@google.com> | 2018-10-04 15:11:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 15:16:24 -0700 |
commit | 26d3617d2ab5f4874b73059be524e94b9535465b (patch) | |
tree | d4e003ef8d0a675f51c3fa9228ca8eb051fc34ac /tensorflow/python | |
parent | 2e2e89699c1186eef157911b57e4b062de376ce9 (diff) |
Avoid creating control edges on not-this-graph.
PiperOrigin-RevId: 215811680
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/function.py | 17 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 3 |
2 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index dd9f5e233c..2750461fb2 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -269,15 +269,6 @@ class FuncGraph(ops.Graph): def variables(self, var_list): self._weak_variables = [weakref.ref(v) for v in var_list] - def control_dependencies(self, control_inputs): - # Drop control dependencies to outside of the graph. TODO(b/117109273) - # unclear how to capture an op, not a tensor. - if not control_inputs: - return super(FuncGraph, self).control_dependencies(control_inputs) - return super(FuncGraph, self).control_dependencies( - [c for c in control_inputs - if getattr(c, "graph", None) is self]) - def create_op( self, op_type, @@ -503,6 +494,9 @@ class _EagerDefinedFunction(object): Returns: The outputs of the function call. + + Raises: + ValueError: if the number of arguments is incorrect. """ executing_eagerly = ctx.executing_eagerly() @@ -536,6 +530,10 @@ class _EagerDefinedFunction(object): # TODO(akshayka): Either remove this if the FunctionLibraryRuntime # creates `PartitionedCallOp` kernels by default, or remove the previous # branch if a TPU kernel is registered for `PartitionedCall`. + if len(args) != len(self.signature.input_arg): + raise ValueError( + "Arguments and signature arguments do not match: %s %s " % + (len(args), len(list(self.signature.input_arg)))) outputs = functional_ops.partitioned_call( args=args, f=self, @@ -756,7 +754,6 @@ class Function(object): BACKWARD_FUNCTION_ATTRIBUTE_NAME: self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access forward_function_attr.update(self._attrs) - self._forward_function = _EagerDefinedFunction( forward_function_name, self._func_graph, self._func_graph.inputs, self._func_graph.outputs + backwards_graph_captures, diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index f779c3d273..5bc217d355 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1333,6 +1333,9 @@ class ControlFlowState(object): """ if util.IsLoopSwitch(op): return None + if op.graph._building_function: # pylint: disable=protected-access + # The optimization here is tricky to apply to functions + return array_ops.zeros_like(op.outputs[index]) dead_branch = util.IsSwitch(op) forward_ctxt = _GetWhileContext(op) grad_state = self._map.get(forward_ctxt) |