aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2017-03-01 07:31:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-01 07:54:51 -0800
commitc0b379385bb23ad86c7233458f42c62aa7538788 (patch)
tree78eecae912a965868245270caba0d3a247f120f5
parentf424ca38712a87aeaf614af454d96b5d155592ca (diff)
LinearOperatorUDVHUpdate argument name change: diag --> update_diag
Change: 148886147
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py4
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_udvh_update_test.py68
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py82
3 files changed, 79 insertions, 75 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 9806839106..ac0836736e 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -240,8 +240,8 @@ class MultivariateNormalDiagPlusLowRank(
scale = linalg.LinearOperatorUDVHUpdate(
scale,
u=scale_perturb_factor,
- diag=scale_perturb_diag,
- is_diag_positive=scale_perturb_diag is None,
+ diag_update=scale_perturb_diag,
+ is_diag_update_positive=scale_perturb_diag is None,
is_non_singular=True, # Implied by is_positive_definite=True.
is_self_adjoint=True,
is_positive_definite=True,
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_udvh_update_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_udvh_update_test.py
index 019f73312a..7abe12f1a4 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_udvh_update_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_udvh_update_test.py
@@ -39,13 +39,13 @@ class BaseLinearOperatorUDVHUpdatetest(object):
# If True, A = L + UDV^H
# If False, A = L + UV^H or A = L + UU^H, depending on _use_v.
- _use_diag_perturbation = None
+ _use_diag_update = None
# If True, diag is > 0, which means D is symmetric positive definite.
- _is_diag_positive = None
+ _is_diag_update_positive = None
# If True, A = L + UDV^H
- # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_perturbation
+ # If False, A = L + UDU^H or A = L + UU^H, depending on _use_diag_update
_use_v = None
@property
@@ -67,7 +67,7 @@ class BaseLinearOperatorUDVHUpdatetest(object):
diag_shape = shape[:-1]
k = shape[-2] // 2 + 1
u_perturbation_shape = shape[:-1] + [k]
- diag_perturbation_shape = shape[:-2] + [k]
+ diag_update_shape = shape[:-2] + [k]
# base_operator L will be a symmetric positive definite diagonal linear
# operator, with condition number as high as 1e4.
@@ -86,13 +86,13 @@ class BaseLinearOperatorUDVHUpdatetest(object):
v_ph = array_ops.placeholder(dtype=dtype)
# D
- if self._is_diag_positive:
- diag_perturbation = linear_operator_test_util.random_uniform(
- diag_perturbation_shape, minval=1e-4, maxval=1., dtype=dtype)
+ if self._is_diag_update_positive:
+ diag_update = linear_operator_test_util.random_uniform(
+ diag_update_shape, minval=1e-4, maxval=1., dtype=dtype)
else:
- diag_perturbation = linear_operator_test_util.random_normal(
- diag_perturbation_shape, stddev=1e-4, dtype=dtype)
- diag_perturbation_ph = array_ops.placeholder(dtype=dtype)
+ diag_update = linear_operator_test_util.random_normal(
+ diag_update_shape, stddev=1e-4, dtype=dtype)
+ diag_update_ph = array_ops.placeholder(dtype=dtype)
if use_placeholder:
# Evaluate here because (i) you cannot feed a tensor, and (ii)
@@ -101,7 +101,7 @@ class BaseLinearOperatorUDVHUpdatetest(object):
base_diag = base_diag.eval()
u = u.eval()
v = v.eval()
- diag_perturbation = diag_perturbation.eval()
+ diag_update = diag_update.eval()
# In all cases, set base_operator to be positive definite.
base_operator = linalg.LinearOperatorDiag(
@@ -111,13 +111,13 @@ class BaseLinearOperatorUDVHUpdatetest(object):
base_operator,
u=u_ph,
v=v_ph if self._use_v else None,
- diag=diag_perturbation_ph if self._use_diag_perturbation else None,
- is_diag_positive=self._is_diag_positive)
+ diag_update=diag_update_ph if self._use_diag_update else None,
+ is_diag_update_positive=self._is_diag_update_positive)
feed_dict = {
base_diag_ph: base_diag,
u_ph: u,
v_ph: v,
- diag_perturbation_ph: diag_perturbation}
+ diag_update_ph: diag_update}
else:
base_operator = linalg.LinearOperatorDiag(
base_diag, is_positive_definite=True)
@@ -125,31 +125,31 @@ class BaseLinearOperatorUDVHUpdatetest(object):
base_operator,
u,
v=v if self._use_v else None,
- diag=diag_perturbation if self._use_diag_perturbation else None,
- is_diag_positive=self._is_diag_positive)
+ diag_update=diag_update if self._use_diag_update else None,
+ is_diag_update_positive=self._is_diag_update_positive)
feed_dict = None
# The matrix representing L
base_diag_mat = array_ops.matrix_diag(base_diag)
# The matrix representing D
- diag_perturbation_mat = array_ops.matrix_diag(diag_perturbation)
+ diag_update_mat = array_ops.matrix_diag(diag_update)
# Set up mat as some variant of A = L + UDV^H
- if self._use_v and self._use_diag_perturbation:
+ if self._use_v and self._use_diag_update:
# In this case, we have L + UDV^H and it isn't symmetric.
expect_use_cholesky = False
mat = base_diag_mat + math_ops.matmul(
- u, math_ops.matmul(diag_perturbation_mat, v, adjoint_b=True))
+ u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
elif self._use_v:
# In this case, we have L + UDV^H and it isn't symmetric.
expect_use_cholesky = False
mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
- elif self._use_diag_perturbation:
+ elif self._use_diag_update:
# In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
- expect_use_cholesky = self._is_diag_positive
+ expect_use_cholesky = self._is_diag_update_positive
mat = base_diag_mat + math_ops.matmul(
- u, math_ops.matmul(diag_perturbation_mat, u, adjoint_b=True))
+ u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
else:
# In this case, we have L + UU^H, which is PD since L > 0.
expect_use_cholesky = True
@@ -168,8 +168,8 @@ class LinearOperatorUDVHUpdatetestWithDiagUseCholesky(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""
- _use_diag_perturbation = True
- _is_diag_positive = True
+ _use_diag_update = True
+ _is_diag_update_positive = True
_use_v = False
def setUp(self):
@@ -186,8 +186,8 @@ class LinearOperatorUDVHUpdatetestWithDiagCannotUseCholesky(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""A = L + UDU^H, D !> 0, L > 0 ==> A !> 0 and we cannot use a Cholesky."""
- _use_diag_perturbation = True
- _is_diag_positive = False
+ _use_diag_update = True
+ _is_diag_update_positive = False
_use_v = False
def setUp(self):
@@ -205,8 +205,8 @@ class LinearOperatorUDVHUpdatetestNoDiagUseCholesky(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""A = L + UU^H, L > 0 ==> A > 0 and we can use a Cholesky."""
- _use_diag_perturbation = False
- _is_diag_positive = None
+ _use_diag_update = False
+ _is_diag_update_positive = None
_use_v = False
def setUp(self):
@@ -223,8 +223,8 @@ class LinearOperatorUDVHUpdatetestNoDiagCannotUseCholesky(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""A = L + UV^H, L > 0 ==> A is not symmetric and we cannot use a Cholesky."""
- _use_diag_perturbation = False
- _is_diag_positive = None
+ _use_diag_update = False
+ _is_diag_update_positive = None
_use_v = True
def setUp(self):
@@ -242,8 +242,8 @@ class LinearOperatorUDVHUpdatetestWithDiagNotSquare(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
"""A = L + UDU^H, D > 0, L > 0 ==> A > 0 and we can use a Cholesky."""
- _use_diag_perturbation = True
- _is_diag_positive = True
+ _use_diag_update = True
+ _is_diag_update_positive = True
_use_v = True
@@ -309,14 +309,14 @@ class LinearOpearatorUDVHUpdateBroadcastsShape(test.TestCase):
u = rng.rand(5, 3, 2)
diag = rng.rand(5, 4) # Last dimension should be 2
with self.assertRaisesRegexp(ValueError, "not compatible"):
- linalg.LinearOperatorUDVHUpdate(base_operator, u=u, diag=diag)
+ linalg.LinearOperatorUDVHUpdate(base_operator, u=u, diag_update=diag)
def test_diag_incompatible_batch_shape_raises(self):
base_operator = linalg.LinearOperatorIdentity(num_rows=3, dtype=np.float64)
u = rng.rand(5, 3, 2)
diag = rng.rand(4, 2) # First dimension should be 5
with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
- linalg.LinearOperatorUDVHUpdate(base_operator, u=u, diag=diag)
+ linalg.LinearOperatorUDVHUpdate(base_operator, u=u, diag_update=diag)
if __name__ == "__main__":
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py
index bce27dcd6f..7c7776e624 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py
@@ -61,14 +61,14 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
```python
# Create a 3 x 3 diagonal linear operator.
diag_operator = LinearOperatorDiag(
- diag=[1., 2., 3.], is_non_singular=True, is_self_adjoint=True,
+ diag_update=[1., 2., 3.], is_non_singular=True, is_self_adjoint=True,
is_positive_definite=True)
# Perturb with a rank 2 perturbation
operator = LinearOperatorUDVHUpdate(
operator=diag_operator,
u=[[1., 2.], [-1., 3.], [0., 0.]],
- diag=[11., 12.],
+ diag_update=[11., 12.],
v=[[1., 2.], [-1., 3.], [10., 10.]])
operator.shape
@@ -112,7 +112,8 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
#### Matrix property hints
This `LinearOperator` is initialized with boolean flags of the form `is_X`,
- for `X = non_singular, self_adjoint, positive_definite, diag_positive, square`
+ for `X = non_singular, self_adjoint, positive_definite, diag_update_positive`
+ and `square`
These have the following meaning
* If `is_X == True`, callers should expect the operator to have the
property `X`. This is a promise that should be fulfilled, but is *not* a
@@ -126,9 +127,9 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
def __init__(self,
base_operator,
u,
- diag=None,
+ diag_update=None,
v=None,
- is_diag_positive=None,
+ is_diag_update_positive=None,
is_non_singular=None,
is_self_adjoint=None,
is_positive_definite=None,
@@ -151,13 +152,14 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
`LinearOperator`. This is `L` above.
u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
This is `U` above.
- diag: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype` as
- `base_operator`. This is the diagonal of `D` above.
+ diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
+ as `base_operator`. This is the diagonal of `D` above.
Defaults to `D` being the identity operator.
v: Optional `Tensor` of same `dtype` as `u` and shape `[B1,...,Bb, N, K]`
Defaults to `v = u`, in which case the perturbation is symmetric.
If `M != N`, then `v` must be set since the perturbation is not square.
- is_diag_positive: Python `bool`. If `True`, expect `diag > 0`.
+ is_diag_update_positive: Python `bool`.
+ If `True`, expect `diag_update > 0`.
is_non_singular: Expect that this operator is non-singular.
Default is `None`, unless `is_positive_definite` is auto-set to be
`True` (see below).
@@ -166,8 +168,8 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
and `v = None` (meaning `u=v`), in which case this defaults to `True`.
is_positive_definite: Expect that this operator is positive definite.
Default is `None`, unless `base_operator` is positive-definite
- `v = None` (meaning `u=v`), and `is_diag_positive`, in which case this
- defaults to `True`.
+ `v = None` (meaning `u=v`), and `is_diag_update_positive`, in which case
+ this defaults to `True`.
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`.
@@ -177,10 +179,10 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
# TODO(langmore) support complex types.
# Complex types are not allowed due to tf.cholesky() requiring float.
# If complex dtypes are allowed, we update the following
- # 1. is_diag_positive should still imply that `diag > 0`, but we need to
- # remind the user that this implies diag is real. This is needed because
- # if diag has non-zero imaginary part, it will not be self-adjoint
- # positive definite.
+ # 1. is_diag_update_positive should still imply that `diag > 0`, but we need
+ # to remind the user that this implies diag is real. This is needed
+ # because if diag has non-zero imaginary part, it will not be
+ # self-adjoint positive definite.
dtype = base_operator.dtype
allowed_dtypes = [dtypes.float32, dtypes.float64]
if dtype not in allowed_dtypes:
@@ -188,17 +190,17 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
"Argument matrix must have dtype in %s. Found: %s"
% (allowed_dtypes, dtype))
- if diag is None:
- if is_diag_positive is False:
+ if diag_update is None:
+ if is_diag_update_positive is False:
raise ValueError(
"Default diagonal is the identity, which is positive. However, "
- "user set 'is_diag_positive' to False.")
- is_diag_positive = True
+ "user set 'is_diag_update_positive' to False.")
+ is_diag_update_positive = True
# In this case, we can use a Cholesky decomposition to help us solve/det.
self._use_cholesky = (
base_operator.is_positive_definite and base_operator.is_self_adjoint
- and is_diag_positive
+ and is_diag_update_positive
and v is None)
# Possibly auto-set some characteristic flags from None to True.
@@ -223,7 +225,7 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
is_positive_definite = True
is_self_adjoint = True
- values = base_operator.graph_parents + [u, diag, v]
+ values = base_operator.graph_parents + [u, diag_update, v]
with ops.name_scope(name, values=values):
# Create U and V.
@@ -233,14 +235,16 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
else:
self._v = ops.convert_to_tensor(v, name="v")
- if diag is None:
- self._diag = None
+ if diag_update is None:
+ self._diag_update = None
else:
- self._diag = ops.convert_to_tensor(diag, name="diag")
+ self._diag_update = ops.convert_to_tensor(
+ diag_update, name="diag_update")
# Create base_operator L.
self._base_operator = base_operator
- graph_parents = base_operator.graph_parents + [self.u, self._diag, self.v]
+ graph_parents = base_operator.graph_parents + [
+ self.u, self._diag_update, self.v]
graph_parents = [p for p in graph_parents if p is not None]
super(LinearOperatorUDVHUpdate, self).__init__(
@@ -253,11 +257,11 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
name=name)
# Create the diagonal operator D.
- self._set_diag_operators(diag, is_diag_positive)
- self._is_diag_positive = is_diag_positive
+ self._set_diag_operators(diag_update, is_diag_update_positive)
+ self._is_diag_update_positive = is_diag_update_positive
contrib_tensor_util.assert_same_float_dtype(
- (base_operator, self.u, self.v, self._diag))
+ (base_operator, self.u, self.v, self._diag_update))
self._check_shapes()
# Pre-compute the so-called "capacitance" matrix
@@ -278,18 +282,18 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
self.base_operator.domain_dimension.assert_is_compatible_with(
uv_shape[-2])
- if self._diag is not None:
- uv_shape[-1].assert_is_compatible_with(self._diag.get_shape()[-1])
+ if self._diag_update is not None:
+ uv_shape[-1].assert_is_compatible_with(self._diag_update.get_shape()[-1])
array_ops.broadcast_static_shape(
- batch_shape, self._diag.get_shape()[:-1])
+ batch_shape, self._diag_update.get_shape()[:-1])
- def _set_diag_operators(self, diag, is_diag_positive):
- """Set attributes self._diag and self._diag_operator."""
- if diag is not None:
+ def _set_diag_operators(self, diag_update, is_diag_update_positive):
+ """Set attributes self._diag_update and self._diag_operator."""
+ if diag_update is not None:
self._diag_operator = linear_operator_diag.LinearOperatorDiag(
- self._diag, is_positive_definite=is_diag_positive)
+ self._diag_update, is_positive_definite=is_diag_update_positive)
self._diag_inv_operator = linear_operator_diag.LinearOperatorDiag(
- 1. / self._diag, is_positive_definite=is_diag_positive)
+ 1. / self._diag_update, is_positive_definite=is_diag_update_positive)
else:
if self.u.get_shape()[-1].value is not None:
r = self.u.get_shape()[-1].value
@@ -310,14 +314,14 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
return self._v
@property
- def is_diag_positive(self):
+ def is_diag_update_positive(self):
"""If this operator is `A = L + U D V^H`, this hints `D > 0` elementwise."""
- return self._is_diag_positive
+ return self._is_diag_update_positive
@property
- def diag_arg(self):
+ def diag_update(self):
"""If this operator is `A = L + U D V^H`, this is the diagonal of `D`."""
- return self._diag
+ return self._diag_update
@property
def diag_operator(self):