about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/ops/gradients_impl.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-27 14:28:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-27 14:31:07 -0700
commita4dbc33512adb3705345b093a0aafec151e7e32d (patch)
tree7922e1cf71084d592008299f99a0cf509a0714bc /tensorflow/python/ops/gradients_impl.py
parent6da711a50c3ef98aebacd6a909596a0f74b783e1 (diff)
If two identical functions are given different gradient functions,
they should be named differently; otherwise, tf.gradients gets confused. PiperOrigin-RevId: 194593519
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r--tensorflow/python/ops/gradients_impl.py32
1 file changed, 18 insertions, 14 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 581ba7de48..1448151fef 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -256,21 +256,21 @@ def _DefaultGradYs(grad_ys,
continue
if y.dtype.is_floating or y.dtype.is_integer:
if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
- raise TypeError("Gradient type %s generated for real or "
- "integer-valued tensor %s with type %s must be "
- "real or integer" %
- (dtypes.as_dtype(grad_y.dtype).name, y,
- dtypes.as_dtype(y.dtype).name))
+ raise TypeError(
+ "Gradient type %s generated for real or "
+ "integer-valued tensor %s with type %s must be "
+ "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
+ dtypes.as_dtype(y.dtype).name))
elif y.dtype.is_complex:
if not grad_y.dtype.is_complex:
- raise TypeError("Gradient type %s generated for complex-valued "
- "tensor %s with type %s must be real" %
- (dtypes.as_dtype(grad_y.dtype).name, y,
- dtypes.as_dtype(y.dtype).name))
+ raise TypeError(
+ "Gradient type %s generated for complex-valued "
+ "tensor %s with type %s must be real" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
else:
- raise TypeError("Tensor %s with type %s must be numeric "
- "to obtain a default gradient" %
- (y, dtypes.as_dtype(y.dtype).name))
+ raise TypeError(
+ "Tensor %s with type %s must be numeric "
+ "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
# Create a grad_y tensor in the name scope of the gradient.
# Required for TensorArrays to identify which gradient call a
# grad_y value is coming from.
@@ -605,15 +605,19 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
loop_state.ExitGradWhileContext(op, before=True)
grad_fn = None
- # pylint: disable=protected-access
func_call = None
+ # pylint: disable=protected-access
is_func_call = ops.get_default_graph()._is_function(op.type)
+ # pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op._id not in stop_ops):
if is_func_call:
func_call = ops.get_default_graph()._get_function(op.type)
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
grad_fn = func_call.python_grad_func
- # pylint: enable=protected-access
else:
# A grad_fn must be defined, either as a function or as None
# for ops that do not have gradients.