aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/gradients_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r--tensorflow/python/ops/gradients_impl.py30
1 files changed, 14 insertions, 16 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index aac95037dc..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -800,23 +800,21 @@ def _GradientsHelper(ys,
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- # A grad_fn must be defined, either as a function or as None
- # for ops that do not have gradients.
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))