diff options
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 30 |
1 files changed, 14 insertions, 16 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index aac95037dc..6909fcaed5 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -800,23 +800,21 @@ def _GradientsHelper(ys, # pylint: enable=protected-access has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) if has_out_grads and (op not in stop_ops): - if is_func_call: - if is_partitioned_call: - func_call = src_graph._get_function( # pylint: disable=protected-access - compat.as_bytes(op.get_attr("f").name)) + try: + grad_fn = ops.get_gradient_function(op) + except LookupError: + if is_func_call: + if is_partitioned_call: + func_call = src_graph._get_function( # pylint: disable=protected-access + compat.as_bytes(op.get_attr("f").name)) + else: + func_call = src_graph._get_function(op.type) # pylint: disable=protected-access + # Note that __defun is not set if the graph is + # imported. If it's set, we prefer to access the original + # defun. + func_call = getattr(op, "__defun", func_call) + grad_fn = func_call.python_grad_func else: - func_call = src_graph._get_function(op.type) # pylint: disable=protected-access - # Note that __defun is not set if the graph is - # imported. If it's set, we prefer to access the original - # defun. - func_call = getattr(op, "__defun", func_call) - grad_fn = func_call.python_grad_func - else: - # A grad_fn must be defined, either as a function or as None - # for ops that do not have gradients. - try: - grad_fn = ops.get_gradient_function(op) - except LookupError: raise LookupError( "No gradient defined for operation '%s' (op type: %s)" % (op.name, op.type)) |