about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/ops/gradients_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- tensorflow/python/ops/gradients_impl.py 52
1 files changed, 27 insertions, 25 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index bd8a5c86ac..52ce451238 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
@@ -45,6 +46,7 @@ from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import
+from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import tf_logging as logging
@@ -920,8 +922,6 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
Raises:
LookupError: if one of the operations between `xs` and `ys` does not
have a registered gradient function.
- ValueError: if the arguments are invalid or not supported. Currently,
- this function only supports one-dimensional `x` in `xs`.
"""
xs = _AsList(xs)
kwargs = {
@@ -929,28 +929,30 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
'gate_gradients': gate_gradients,
'aggregation_method': aggregation_method
}
- # Compute a hessian matrix for each x in xs
+ # Compute first-order derivatives and iterate for each x in xs.
hessians = []
- for i, x in enumerate(xs):
- # Check dimensions
- ndims = x.get_shape().ndims
- if ndims is None:
- raise ValueError('Cannot compute Hessian because the dimensionality of '
- 'element number %d of `xs` cannot be determined' % i)
- elif ndims != 1:
- raise ValueError('Computing hessians is currently only supported for '
- 'one-dimensional tensors. Element number %d of `xs` has '
- '%d dimensions.' % (i, ndims))
- with ops.name_scope(name + '_first_derivative'):
- # Compute the partial derivatives of the input with respect to all
- # elements of `x`
- _gradients = gradients(ys, x, **kwargs)[0]
- # Unpack the gradients into a list so we can take derivatives with
- # respect to each element
- _gradients = array_ops.unstack(_gradients)
- with ops.name_scope(name + '_second_derivative'):
- # Compute the partial derivatives with respect to each element of the list
- _hess = [gradients(_gradient, x, **kwargs)[0] for _gradient in _gradients]
- # Pack the list into a matrix and add to the list of hessians
- hessians.append(array_ops.stack(_hess, name=name))
+ _gradients = gradients(ys, xs, **kwargs)
+ for i, _gradient, x in zip(range(len(xs)), _gradients, xs):
+ # Ensure that x is a vector.
+ check_rank = check_ops.assert_rank(
+ x, 1, message='Cannot compute Hessian because element %d of `xs` does '
+ 'not have rank one.' % i
+ )
+ with ops.control_dependencies([check_rank]):
+ # Declare an iterator and tensor array loop variables for the gradients.
+ n = array_ops.size(x)
+ loop_vars = [
+ array_ops.constant(0, dtypes.int32),
+ tensor_array_ops.TensorArray(x.dtype, n)
+ ]
+ # Iterate over all elements of the gradient and compute second order
+ # derivatives.
+ _, hessian = control_flow_ops.while_loop(
+ lambda j, _: j < n,
+ lambda j, result: (j + 1,
+ result.write(j, gradients(_gradient[j], x)[0])),
+ loop_vars
+ )
+
+ hessians.append(hessian.stack())
return hessians