diff options
author | Alexandre Passos <apassos@google.com> | 2018-10-04 15:11:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 15:16:24 -0700 |
commit | 26d3617d2ab5f4874b73059be524e94b9535465b (patch) | |
tree | d4e003ef8d0a675f51c3fa9228ca8eb051fc34ac /tensorflow/python/eager | |
parent | 2e2e89699c1186eef157911b57e4b062de376ce9 (diff) |
Avoid creating control edges on not-this-graph.
PiperOrigin-RevId: 215811680
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index dd9f5e233c..2750461fb2 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -269,15 +269,6 @@ class FuncGraph(ops.Graph): def variables(self, var_list): self._weak_variables = [weakref.ref(v) for v in var_list] - def control_dependencies(self, control_inputs): - # Drop control dependencies to outside of the graph. TODO(b/117109273) - # unclear how to capture an op, not a tensor. - if not control_inputs: - return super(FuncGraph, self).control_dependencies(control_inputs) - return super(FuncGraph, self).control_dependencies( - [c for c in control_inputs - if getattr(c, "graph", None) is self]) - def create_op( self, op_type, @@ -503,6 +494,9 @@ class _EagerDefinedFunction(object): Returns: The outputs of the function call. + + Raises: + ValueError: if the number of arguments is incorrect. """ executing_eagerly = ctx.executing_eagerly() @@ -536,6 +530,10 @@ class _EagerDefinedFunction(object): # TODO(akshayka): Either remove this if the FunctionLibraryRuntime # creates `PartitionedCallOp` kernels by default, or remove the previous # branch if a TPU kernel is registered for `PartitionedCall`. + if len(args) != len(self.signature.input_arg): + raise ValueError( + "Arguments and signature arguments do not match: %s %s " % + (len(args), len(list(self.signature.input_arg)))) outputs = functional_ops.partitioned_call( args=args, f=self, @@ -756,7 +754,6 @@ class Function(object): BACKWARD_FUNCTION_ATTRIBUTE_NAME: self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access forward_function_attr.update(self._attrs) - self._forward_function = _EagerDefinedFunction( forward_function_name, self._func_graph, self._func_graph.inputs, self._func_graph.outputs + backwards_graph_captures, |