Diffstat (limited to 'tensorflow/python/ops/gradients_impl.py')
-rw-r--r-- tensorflow/python/ops/gradients_impl.py | 49
1 file changed, 24 insertions(+), 25 deletions(-)
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index f5fdb12b2c..20c7a9fd66 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -977,9 +977,7 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
`hessians()` adds ops to the graph to output the Hessian matrix of `ys`
with respect to `xs`. It returns a list of `Tensor` of length `len(xs)`
- where each tensor is the Hessian of `sum(ys)`. This function currently
- only supports evaluating the Hessian with respect to (a list of) one-
- dimensional tensors.
+ where each tensor is the Hessian of `sum(ys)`.
The Hessian is a matrix of second-order partial derivatives of a scalar
tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).
@@ -1005,31 +1003,32 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
'colocate_gradients_with_ops': colocate_gradients_with_ops,
'gate_gradients': gate_gradients,
'aggregation_method': aggregation_method
- }
+ }
# Compute first-order derivatives and iterate for each x in xs.
hessians = []
_gradients = gradients(ys, xs, **kwargs)
- for i, _gradient, x in zip(range(len(xs)), _gradients, xs):
- # Ensure that x is a vector.
- check_rank = check_ops.assert_rank(
- x, 1, message='Cannot compute Hessian because element %d of `xs` does '
- 'not have rank one.' % i
- )
- with ops.control_dependencies([check_rank]):
- # Declare an iterator and tensor array loop variables for the gradients.
- n = array_ops.size(x)
- loop_vars = [
+ for gradient, x in zip(_gradients, xs):
+    # Flatten the gradient to one dimension so it can be indexed without graph branching.
+ gradient = array_ops.reshape(gradient, [-1])
+
+ # Declare an iterator and tensor array loop variables for the gradients.
+ n = array_ops.size(x)
+ loop_vars = [
array_ops.constant(0, dtypes.int32),
tensor_array_ops.TensorArray(x.dtype, n)
- ]
- # Iterate over all elements of the gradient and compute second order
- # derivatives.
- _, hessian = control_flow_ops.while_loop(
- lambda j, _: j < n,
- lambda j, result: (j + 1,
- result.write(j, gradients(_gradient[j], x)[0])),
- loop_vars
- )
-
- hessians.append(hessian.stack())
+ ]
+ # Iterate over all elements of the gradient and compute second order
+ # derivatives.
+ _, hessian = control_flow_ops.while_loop(
+ lambda j, _: j < n,
+ lambda j, result: (j + 1,
+ result.write(j, gradients(gradient[j], x)[0])),
+ loop_vars
+ )
+
+ _shape = array_ops.shape(x)
+ _reshaped_hessian = array_ops.reshape(
+ hessian.stack(), array_ops.concat((_shape, _shape), 0)
+ )
+ hessians.append(_reshaped_hessian)
return hessians
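
Below is a minimal usage sketch (not part of the diff) illustrating what this change enables: computing a Hessian with respect to a tensor of rank greater than one. It assumes the TensorFlow 1.x graph-mode API; the variable names are illustrative only.

import numpy as np
import tensorflow as tf

# A rank-2 variable; before this change, hessians() only accepted vectors.
x = tf.Variable(np.ones((2, 3), dtype=np.float32))
y = tf.reduce_sum(x * x)          # scalar objective
hess = tf.hessians(y, x)[0]       # reshaped to (2, 3, 2, 3) by the code above

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(hess).shape)   # (2, 3, 2, 3)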