author    Michael Case <mikecase@google.com>  2018-02-07 14:36:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-07 14:39:49 -0800
commit    d90054e7c0f41f4bab81df0548577a73b939a87a (patch)
tree      a15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/contrib/solvers
parent    8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/contrib/solvers')
-rw-r--r--  tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py | 63
-rw-r--r--  tensorflow/contrib/solvers/python/kernel_tests/util_test.py              | 37
-rw-r--r--  tensorflow/contrib/solvers/python/ops/linear_equations.py                | 52
-rw-r--r--  tensorflow/contrib/solvers/python/ops/util.py                             | 17
4 files changed, 142 insertions(+), 27 deletions(-)
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
index 930df2414b..a1282847be 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/linear_equations_test.py
@@ -45,32 +45,67 @@ def _get_linear_equations_tests(dtype_, use_static_shape_, shape_):
low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
# Make a_np self-adjoint and positive definite.
a_np = np.dot(a_np.T, a_np)
+ # Jacobi preconditioner: diagonal matrix with entries 1 / diag(a_np).
+ jacobi_np = np.zeros_like(a_np)
+ jacobi_np[range(a_np.shape[0]), range(a_np.shape[1])] = (
+ 1.0 / a_np.diagonal())
rhs_np = np.random.uniform(
low=-1.0, high=1.0, size=shape_[0]).astype(dtype_)
+ x_np = np.zeros_like(rhs_np)
tol = 1e-6 if dtype_ == np.float64 else 1e-3
max_iter = 20
with self.test_session() as sess:
if use_static_shape_:
a = constant_op.constant(a_np)
rhs = constant_op.constant(rhs_np)
+ x = constant_op.constant(x_np)
+ jacobi = constant_op.constant(jacobi_np)
else:
a = array_ops.placeholder(dtype_)
rhs = array_ops.placeholder(dtype_)
+ x = array_ops.placeholder(dtype_)
+ jacobi = array_ops.placeholder(dtype_)
operator = util.create_operator(a)
- cg_graph = linear_equations.conjugate_gradient(
- operator, rhs, tol=tol, max_iter=max_iter)
- if use_static_shape_:
- cg_val = sess.run(cg_graph)
- else:
- cg_val = sess.run(cg_graph, feed_dict={a: a_np, rhs: rhs_np})
- norm_r0 = np.linalg.norm(rhs_np)
- norm_r = np.sqrt(cg_val.gamma)
- self.assertLessEqual(norm_r, tol * norm_r0)
- # Validate that we get an equally small residual norm with numpy
- # using the computed solution.
- r_np = rhs_np - np.dot(a_np, cg_val.x)
- norm_r_np = np.linalg.norm(r_np)
- self.assertLessEqual(norm_r_np, tol * norm_r0)
+ preconditioners = [
+ None, util.identity_operator(a),
+ util.create_operator(jacobi)
+ ]
+ cg_results = []
+ for preconditioner in preconditioners:
+ cg_graph = linear_equations.conjugate_gradient(
+ operator,
+ rhs,
+ preconditioner=preconditioner,
+ x=x,
+ tol=tol,
+ max_iter=max_iter)
+ if use_static_shape_:
+ cg_val = sess.run(cg_graph)
+ else:
+ cg_val = sess.run(
+ cg_graph,
+ feed_dict={
+ a: a_np,
+ rhs: rhs_np,
+ x: x_np,
+ jacobi: jacobi_np
+ })
+ norm_r0 = np.linalg.norm(rhs_np)
+ norm_r = np.linalg.norm(cg_val.r)
+ self.assertLessEqual(norm_r, tol * norm_r0)
+ # Validate that we get an equally small residual norm with numpy
+ # using the computed solution.
+ r_np = rhs_np - np.dot(a_np, cg_val.x)
+ norm_r_np = np.linalg.norm(r_np)
+ self.assertLessEqual(norm_r_np, tol * norm_r0)
+ cg_results.append(cg_val)
+ # Validate that we get the same results with the identity
+ # preconditioner as with None.
+ self.assertEqual(cg_results[0].i, cg_results[1].i)
+ self.assertAlmostEqual(cg_results[0].gamma, cg_results[1].gamma)
+ self.assertAllClose(cg_results[0].r, cg_results[1].r, rtol=tol)
+ self.assertAllClose(cg_results[0].x, cg_results[1].x, rtol=tol)
+ self.assertAllClose(cg_results[0].p, cg_results[1].p, rtol=tol)
return [test_conjugate_gradient]
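
In outline, the calling convention the updated test exercises looks like this (a minimal sketch against the contrib API in this commit, not part of the diff; a_np and rhs_np are assumed to be dense NumPy arrays with a_np symmetric positive definite):

    import numpy as np
    from tensorflow.contrib.solvers.python.ops import linear_equations
    from tensorflow.contrib.solvers.python.ops import util
    from tensorflow.python.framework import constant_op

    a = constant_op.constant(a_np)      # SPD system matrix (assumed given)
    rhs = constant_op.constant(rhs_np)  # right-hand side (assumed given)
    operator = util.create_operator(a)

    # Jacobi preconditioner: a diagonal matrix holding 1 / diag(A), as the
    # test above builds.
    jacobi_np = np.diag(1.0 / a_np.diagonal()).astype(a_np.dtype)
    preconditioner = util.create_operator(constant_op.constant(jacobi_np))

    cg = linear_equations.conjugate_gradient(
        operator, rhs, preconditioner=preconditioner, tol=1e-6, max_iter=20)
    # cg.x: solution, cg.r: residual, cg.gamma: r . (M r)
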
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 1566984b27..5d7534657b 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -63,6 +63,43 @@ class UtilTest(test.TestCase):
def testCreateOperatorUnknownShape(self):
self._testCreateOperator(False)
+ def _testIdentityOperator(self, use_static_shape_):
+ for dtype in np.float32, np.float64:
+ a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
+ x_np = np.array([[2.], [-3.]], dtype=dtype)
+ y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
+ with self.test_session() as sess:
+ if use_static_shape_:
+ a = constant_op.constant(a_np, dtype=dtype)
+ x = constant_op.constant(x_np, dtype=dtype)
+ y = constant_op.constant(y_np, dtype=dtype)
+ else:
+ a = array_ops.placeholder(dtype)
+ x = array_ops.placeholder(dtype)
+ y = array_ops.placeholder(dtype)
+ id_op = util.identity_operator(a)
+ ax = id_op.apply(x)
+ aty = id_op.apply_adjoint(y)
+ op_shape = ops.convert_to_tensor(id_op.shape)
+ if use_static_shape_:
+ op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty])
+ else:
+ op_shape_val, ax_val, aty_val = sess.run(
+ [op_shape, ax, aty], feed_dict={
+ a: a_np,
+ x: x_np,
+ y: y_np
+ })
+ self.assertAllEqual(op_shape_val, [3, 2])
+ self.assertAllClose(ax_val, x_np)
+ self.assertAllClose(aty_val, y_np)
+
+ def testIdentityOperator(self):
+ self._testIdentityOperator(True)
+
+ def testIdentityOperatorUnknownShape(self):
+ self._testIdentityOperator(False)
+
def testL2Norm(self):
with self.test_session():
x_np = np.array([[2], [-3.], [5.]])
diff --git a/tensorflow/contrib/solvers/python/ops/linear_equations.py b/tensorflow/contrib/solvers/python/ops/linear_equations.py
index 8cba56eba6..2395707257 100644
--- a/tensorflow/contrib/solvers/python/ops/linear_equations.py
+++ b/tensorflow/contrib/solvers/python/ops/linear_equations.py
@@ -26,11 +26,14 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
def conjugate_gradient(operator,
rhs,
+ preconditioner=None,
+ x=None,
tol=1e-4,
max_iter=20,
name="conjugate_gradient"):
@@ -55,6 +58,15 @@ def conjugate_gradient(operator,
vector with the result of applying the operator to `x`, i.e. if
`operator` represents matrix `A`, `apply` should return `A * x`.
rhs: A rank-1 `Tensor` of shape `[N]` containing the right-hand side vector.
+ preconditioner: An object representing a linear operator, see `operator`
+ for details. The preconditioner should approximate the inverse of `A`.
+ An efficient preconditioner can dramatically improve the rate of
+ convergence. If `preconditioner` represents matrix `M` (where `M`
+ approximates `A^{-1}`), the algorithm uses `preconditioner.apply(x)` to
+ estimate `A^{-1}x`. For this to be useful, the cost of applying `M`
+ should be much lower than that of computing `A^{-1}` directly.
+ x: A rank-1 `Tensor` of shape `[N]` containing the initial guess for the
+ solution.
tol: A float scalar convergence tolerance.
max_iter: An integer giving the maximum number of iterations.
name: A name scope for the operation.
@@ -65,35 +77,49 @@ def conjugate_gradient(operator,
- x: A rank-1 `Tensor` of shape `[N]` containing the computed solution.
- r: A rank-1 `Tensor` of shape `[M]` containing the residual vector.
- p: A rank-1 `Tensor` of shape `[N]`. `A`-conjugate basis vector.
- - gamma: \\(||r||_2^2\\)
+ - gamma: \\(r^T M r\\), equivalent to \\(||r||_2^2\\) when
+ `preconditioner=None`.
"""
# Ephemeral namedtuple holding CG state.
cg_state = collections.namedtuple("CGState", ["i", "x", "r", "p", "gamma"])
def stopping_criterion(i, state):
- return math_ops.logical_and(i < max_iter, state.gamma > tol)
+ return math_ops.logical_and(i < max_iter, linalg_ops.norm(state.r) > tol)
- # TODO(rmlarsen): add preconditioning
- def cg_step(i, state):
+ def cg_step(i, state): # pylint: disable=missing-docstring
z = operator.apply(state.p)
alpha = state.gamma / util.dot(state.p, z)
x = state.x + alpha * state.p
r = state.r - alpha * z
- gamma = util.l2norm_squared(r)
- beta = gamma / state.gamma
- p = r + beta * state.p
+ if preconditioner is None:
+ gamma = util.dot(r, r)
+ beta = gamma / state.gamma
+ p = r + beta * state.p
+ else:
+ q = preconditioner.apply(r)
+ gamma = util.dot(r, q)
+ beta = gamma / state.gamma
+ p = q + beta * state.p
return i + 1, cg_state(i + 1, x, r, p, gamma)
with ops.name_scope(name):
n = operator.shape[1:]
rhs = array_ops.expand_dims(rhs, -1)
- gamma0 = util.l2norm_squared(rhs)
- tol = tol * tol * gamma0
- x = array_ops.expand_dims(
- array_ops.zeros(
- n, dtype=rhs.dtype.base_dtype), -1)
+ if x is None:
+ x = array_ops.expand_dims(
+ array_ops.zeros(n, dtype=rhs.dtype.base_dtype), -1)
+ r0 = rhs
+ else:
+ x = array_ops.expand_dims(x, -1)
+ r0 = rhs - operator.apply(x)
+ if preconditioner is None:
+ p0 = r0
+ else:
+ p0 = preconditioner.apply(r0)
+ gamma0 = util.dot(r0, p0)
+ tol *= linalg_ops.norm(r0)
i = constant_op.constant(0, dtype=dtypes.int32)
- state = cg_state(i=i, x=x, r=rhs, p=rhs, gamma=gamma0)
+ state = cg_state(i=i, x=x, r=r0, p=p0, gamma=gamma0)
_, state = control_flow_ops.while_loop(stopping_criterion, cg_step,
[i, state])
return cg_state(
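
The update in cg_step is standard preconditioned CG, and gamma generalizes from \\(||r||_2^2\\) to \\(r^T M r\\). For reference, a minimal self-contained NumPy sketch of the same recurrence (dense matrices standing in for the operator abstraction; not part of the diff):

    import numpy as np

    def pcg(a, rhs, m=None, x=None, tol=1e-6, max_iter=20):
      """Preconditioned CG on a dense SPD matrix a; m approximates inv(a)."""
      x = np.zeros_like(rhs) if x is None else x
      r = rhs - a @ x                      # initial residual r0
      p = r if m is None else m @ r        # initial search direction p0
      gamma = r @ p                        # r . (M r); ||r||^2 if m is None
      threshold = tol * np.linalg.norm(r)  # mirrors tol *= norm(r0) above
      for _ in range(max_iter):
        if np.linalg.norm(r) <= threshold:
          break
        z = a @ p
        alpha = gamma / (p @ z)
        x = x + alpha * p
        r = r - alpha * z
        q = r if m is None else m @ r
        gamma_new = r @ q
        beta = gamma_new / gamma
        p = q + beta * p
        gamma = gamma_new
      return x
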
diff --git a/tensorflow/contrib/solvers/python/ops/util.py b/tensorflow/contrib/solvers/python/ops/util.py
index 777e0c185d..96947e8eea 100644
--- a/tensorflow/contrib/solvers/python/ops/util.py
+++ b/tensorflow/contrib/solvers/python/ops/util.py
@@ -45,6 +45,23 @@ def create_operator(matrix):
apply_adjoint=lambda v: math_ops.matmul(matrix, v, adjoint_a=True))
+def identity_operator(matrix):
+ """Creates a linear operator from a rank-2 identity tensor."""
+
+ linear_operator = collections.namedtuple(
+ "LinearOperator", ["shape", "dtype", "apply", "apply_adjoint"])
+ shape = matrix.get_shape()
+ if shape.is_fully_defined():
+ shape = shape.as_list()
+ else:
+ shape = array_ops.shape(matrix)
+ return linear_operator(
+ shape=shape,
+ dtype=matrix.dtype,
+ apply=lambda v: v,
+ apply_adjoint=lambda v: v)
+
+
# TODO(rmlarsen): Measure if we should just call matmul.
def dot(x, y):
return math_ops.reduce_sum(math_ops.conj(x) * y)
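
Since identity_operator's apply is a no-op, passing it as the preconditioner reproduces unpreconditioned CG exactly, which is what the test above asserts. As a sanity check on dot (a sketch, not part of the diff): reduce_sum(conj(x) * y) matches NumPy's np.vdot, which also conjugates its first argument:

    import numpy as np

    x = np.array([1.0 + 2.0j, 3.0 - 1.0j])
    y = np.array([4.0 + 0.5j, 5.0])
    # sum(conj(x) * y) is exactly the reduction util.dot computes.
    assert np.allclose(np.sum(np.conj(x) * y), np.vdot(x, y))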