aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-04 15:11:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 15:16:24 -0700
commit26d3617d2ab5f4874b73059be524e94b9535465b (patch)
treed4e003ef8d0a675f51c3fa9228ca8eb051fc34ac /tensorflow/python/eager
parent2e2e89699c1186eef157911b57e4b062de376ce9 (diff)
Avoid creating control edges on not-this-graph.
PiperOrigin-RevId: 215811680
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dd9f5e233c..2750461fb2 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -269,15 +269,6 @@ class FuncGraph(ops.Graph):
def variables(self, var_list):
self._weak_variables = [weakref.ref(v) for v in var_list]
- def control_dependencies(self, control_inputs):
- # Drop control dependencies to outside of the graph. TODO(b/117109273)
- # unclear how to capture an op, not a tensor.
- if not control_inputs:
- return super(FuncGraph, self).control_dependencies(control_inputs)
- return super(FuncGraph, self).control_dependencies(
- [c for c in control_inputs
- if getattr(c, "graph", None) is self])
-
def create_op(
self,
op_type,
@@ -503,6 +494,9 @@ class _EagerDefinedFunction(object):
Returns:
The outputs of the function call.
+
+ Raises:
+ ValueError: if the number of arguments is incorrect.
"""
executing_eagerly = ctx.executing_eagerly()
@@ -536,6 +530,10 @@ class _EagerDefinedFunction(object):
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
# branch if a TPU kernel is registered for `PartitionedCall`.
+ if len(args) != len(self.signature.input_arg):
+ raise ValueError(
+ "Arguments and signature arguments do not match: %s %s " %
+ (len(args), len(list(self.signature.input_arg))))
outputs = functional_ops.partitioned_call(
args=args,
f=self,
@@ -756,7 +754,6 @@ class Function(object):
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)
-
self._forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs + backwards_graph_captures,