diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-27 14:28:12 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-27 14:31:07 -0700 |
commit | a4dbc33512adb3705345b093a0aafec151e7e32d (patch) | |
tree | 7922e1cf71084d592008299f99a0cf509a0714bc /tensorflow/python/ops/gradients_impl.py | |
parent | 6da711a50c3ef98aebacd6a909596a0f74b783e1 (diff) |
If two identical functions are given different grad func,
they should be named differently. Otherwise, tf.gradients
gets confused.
PiperOrigin-RevId: 194593519
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 581ba7de48..1448151fef 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -256,21 +256,21 @@ def _DefaultGradYs(grad_ys, continue if y.dtype.is_floating or y.dtype.is_integer: if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer: - raise TypeError("Gradient type %s generated for real or " - "integer-valued tensor %s with type %s must be " - "real or integer" % - (dtypes.as_dtype(grad_y.dtype).name, y, - dtypes.as_dtype(y.dtype).name)) + raise TypeError( + "Gradient type %s generated for real or " + "integer-valued tensor %s with type %s must be " + "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y, + dtypes.as_dtype(y.dtype).name)) elif y.dtype.is_complex: if not grad_y.dtype.is_complex: - raise TypeError("Gradient type %s generated for complex-valued " - "tensor %s with type %s must be real" % - (dtypes.as_dtype(grad_y.dtype).name, y, - dtypes.as_dtype(y.dtype).name)) + raise TypeError( + "Gradient type %s generated for complex-valued " + "tensor %s with type %s must be real" % (dtypes.as_dtype( + grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name)) else: - raise TypeError("Tensor %s with type %s must be numeric " - "to obtain a default gradient" % - (y, dtypes.as_dtype(y.dtype).name)) + raise TypeError( + "Tensor %s with type %s must be numeric " + "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name)) # Create a grad_y tensor in the name scope of the gradient. # Required for TensorArrays to identify which gradient call a # grad_y value is coming from. @@ -605,15 +605,19 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, loop_state.ExitGradWhileContext(op, before=True) grad_fn = None - # pylint: disable=protected-access func_call = None + # pylint: disable=protected-access is_func_call = ops.get_default_graph()._is_function(op.type) + # pylint: enable=protected-access has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads) if has_out_grads and (op._id not in stop_ops): if is_func_call: func_call = ops.get_default_graph()._get_function(op.type) + # Note that __defun is not set if the graph is + # imported. If it's set, we prefer to access the original + # defun. + func_call = getattr(op, "__defun", func_call) grad_fn = func_call.python_grad_func - # pylint: enable=protected-access else: # A grad_fn must be defined, either as a function or as None # for ops that do not have gradients. |