diff options
Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- | tensorflow/python/ops/gradients_impl.py | 49 |
1 files changed, 24 insertions, 25 deletions
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index f5fdb12b2c..20c7a9fd66 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -977,9 +977,7 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, `hessians()` adds ops to the graph to output the Hessian matrix of `ys` with respect to `xs`. It returns a list of `Tensor` of length `len(xs)` - where each tensor is the Hessian of `sum(ys)`. This function currently - only supports evaluating the Hessian with respect to (a list of) one- - dimensional tensors. + where each tensor is the Hessian of `sum(ys)`. The Hessian is a matrix of second-order partial derivatives of a scalar tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details). @@ -1005,31 +1003,32 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, 'colocate_gradients_with_ops': colocate_gradients_with_ops, 'gate_gradients': gate_gradients, 'aggregation_method': aggregation_method - } + } # Compute first-order derivatives and iterate for each x in xs. hessians = [] _gradients = gradients(ys, xs, **kwargs) - for i, _gradient, x in zip(range(len(xs)), _gradients, xs): - # Ensure that x is a vector. - check_rank = check_ops.assert_rank( - x, 1, message='Cannot compute Hessian because element %d of `xs` does ' - 'not have rank one.' % i - ) - with ops.control_dependencies([check_rank]): - # Declare an iterator and tensor array loop variables for the gradients. - n = array_ops.size(x) - loop_vars = [ + for gradient, x in zip(_gradients, xs): + # change shape to one-dimension without graph branching + gradient = array_ops.reshape(gradient, [-1]) + + # Declare an iterator and tensor array loop variables for the gradients. + n = array_ops.size(x) + loop_vars = [ array_ops.constant(0, dtypes.int32), tensor_array_ops.TensorArray(x.dtype, n) - ] - # Iterate over all elements of the gradient and compute second order - # derivatives. - _, hessian = control_flow_ops.while_loop( - lambda j, _: j < n, - lambda j, result: (j + 1, - result.write(j, gradients(_gradient[j], x)[0])), - loop_vars - ) - - hessians.append(hessian.stack()) + ] + # Iterate over all elements of the gradient and compute second order + # derivatives. + _, hessian = control_flow_ops.while_loop( + lambda j, _: j < n, + lambda j, result: (j + 1, + result.write(j, gradients(gradient[j], x)[0])), + loop_vars + ) + + _shape = array_ops.shape(x) + _reshaped_hessian = array_ops.reshape( + hessian.stack(), array_ops.concat((_shape, _shape), 0) + ) + hessians.append(_reshaped_hessian) return hessians |