diff options
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 52 |
1 files changed, 27 insertions, 25 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index bd8a5c86ac..52ce451238 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_grad # pylint: disable=unused-import from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops @@ -45,6 +46,7 @@ from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import spectral_grad # pylint: disable=unused-import +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import tf_logging as logging @@ -920,8 +922,6 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, Raises: LookupError: if one of the operations between `xs` and `ys` does not have a registered gradient function. - ValueError: if the arguments are invalid or not supported. Currently, - this function only supports one-dimensional `x` in `xs`. """ xs = _AsList(xs) kwargs = { @@ -929,28 +929,30 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, 'gate_gradients': gate_gradients, 'aggregation_method': aggregation_method } - # Compute a hessian matrix for each x in xs + # Compute first-order derivatives and iterate for each x in xs. hessians = [] - for i, x in enumerate(xs): - # Check dimensions - ndims = x.get_shape().ndims - if ndims is None: - raise ValueError('Cannot compute Hessian because the dimensionality of ' - 'element number %d of `xs` cannot be determined' % i) - elif ndims != 1: - raise ValueError('Computing hessians is currently only supported for ' - 'one-dimensional tensors. Element number %d of `xs` has ' - '%d dimensions.' % (i, ndims)) - with ops.name_scope(name + '_first_derivative'): - # Compute the partial derivatives of the input with respect to all - # elements of `x` - _gradients = gradients(ys, x, **kwargs)[0] - # Unpack the gradients into a list so we can take derivatives with - # respect to each element - _gradients = array_ops.unstack(_gradients) - with ops.name_scope(name + '_second_derivative'): - # Compute the partial derivatives with respect to each element of the list - _hess = [gradients(_gradient, x, **kwargs)[0] for _gradient in _gradients] - # Pack the list into a matrix and add to the list of hessians - hessians.append(array_ops.stack(_hess, name=name)) + _gradients = gradients(ys, xs, **kwargs) + for i, _gradient, x in zip(range(len(xs)), _gradients, xs): + # Ensure that x is a vector. + check_rank = check_ops.assert_rank( + x, 1, message='Cannot compute Hessian because element %d of `xs` does ' + 'not have rank one.' % i + ) + with ops.control_dependencies([check_rank]): + # Declare an iterator and tensor array loop variables for the gradients. + n = array_ops.size(x) + loop_vars = [ + array_ops.constant(0, dtypes.int32), + tensor_array_ops.TensorArray(x.dtype, n) + ] + # Iterate over all elements of the gradient and compute second order + # derivatives. + _, hessian = control_flow_ops.while_loop( + lambda j, _: j < n, + lambda j, result: (j + 1, + result.write(j, gradients(_gradient[j], x)[0])), + loop_vars + ) + + hessians.append(hessian.stack()) return hessians |