aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-04 15:11:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 15:16:24 -0700
commit26d3617d2ab5f4874b73059be524e94b9535465b (patch)
treed4e003ef8d0a675f51c3fa9228ca8eb051fc34ac /tensorflow/python
parent2e2e89699c1186eef157911b57e4b062de376ce9 (diff)
Avoid creating control edges on not-this-graph.
PiperOrigin-RevId: 215811680
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/function.py17
-rw-r--r--tensorflow/python/ops/control_flow_ops.py3
2 files changed, 10 insertions, 10 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dd9f5e233c..2750461fb2 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -269,15 +269,6 @@ class FuncGraph(ops.Graph):
def variables(self, var_list):
self._weak_variables = [weakref.ref(v) for v in var_list]
- def control_dependencies(self, control_inputs):
- # Drop control dependencies to outside of the graph. TODO(b/117109273)
- # unclear how to capture an op, not a tensor.
- if not control_inputs:
- return super(FuncGraph, self).control_dependencies(control_inputs)
- return super(FuncGraph, self).control_dependencies(
- [c for c in control_inputs
- if getattr(c, "graph", None) is self])
-
def create_op(
self,
op_type,
@@ -503,6 +494,9 @@ class _EagerDefinedFunction(object):
Returns:
The outputs of the function call.
+
+ Raises:
+ ValueError: if the number of arguments is incorrect.
"""
executing_eagerly = ctx.executing_eagerly()
@@ -536,6 +530,10 @@ class _EagerDefinedFunction(object):
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
# branch if a TPU kernel is registered for `PartitionedCall`.
+ if len(args) != len(self.signature.input_arg):
+ raise ValueError(
+ "Arguments and signature arguments do not match: %s %s " %
+ (len(args), len(list(self.signature.input_arg))))
outputs = functional_ops.partitioned_call(
args=args,
f=self,
@@ -756,7 +754,6 @@ class Function(object):
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)
-
self._forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs + backwards_graph_captures,
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index f779c3d273..5bc217d355 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1333,6 +1333,9 @@ class ControlFlowState(object):
"""
if util.IsLoopSwitch(op):
return None
+ if op.graph._building_function: # pylint: disable=protected-access
+ # The optimization here is tricky to apply to functions
+ return array_ops.zeros_like(op.outputs[index])
dead_branch = util.IsSwitch(op)
forward_ctxt = _GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)