aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2017-05-10 20:12:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 10:55:29 -0700
commit8f8855adc604333180bf57f11e00dc70e679974d (patch)
tree7e144f3602a8923710436bff1c5b71e6694b9659
parent947fc02a3bfc9bbddac37be2a252c9b7f1b3598b (diff)
Cleanup of a few things in distributions library. Noticed while putting
together change for MVNFullCovariance. * Using assert_none_equal * Removing redundant asserts in mvn_tril. PiperOrigin-RevId: 155707155
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/ops/distribution_util.py9
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_diag.py5
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_tril.py17
-rw-r--r--tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py5
7 files changed, 39 insertions, 21 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
index 406cd4ebbe..3f4582eb7e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py
@@ -103,6 +103,14 @@ class MultivariateNormalDiagTest(test.TestCase):
self.assertAllClose(cov_mat, np.cov(samps.T),
atol=0.05, rtol=0.05)
+ def testSingularScaleRaises(self):
+ mu = [-1., 1]
+ diag = [1., 0]
+ with self.test_session():
+ dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
+ with self.assertRaisesOpError("Singular"):
+ dist.sample().eval()
+
def testSampleWithBroadcastScale(self):
# mu corresponds to a 2-batch of 3-variate normals
mu = np.zeros([2, 3])
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
index 8c6980ca57..685f32883d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py
@@ -151,6 +151,14 @@ class MultivariateNormalTriLTest(test.TestCase):
self.assertAllClose(sample_values.mean(axis=0), mu, atol=1e-2)
self.assertAllClose(np.cov(sample_values, rowvar=0), sigma, atol=0.06)
+ def testSingularScaleRaises(self):
+ with self.test_session():
+ mu = None
+ chol = [[1., 0.], [0., 0.]]
+ mvn = ds.MultivariateNormalTriL(mu, chol, validate_args=True)
+ with self.assertRaisesOpError("Singular operator"):
+ mvn.sample().eval()
+
def testSampleWithSampleShape(self):
with self.test_session():
mu = self._rng.rand(3, 5, 2)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
index fc04acd90c..c355adeedb 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_laplace_diag_test.py
@@ -92,6 +92,14 @@ class VectorLaplaceDiagTest(test.TestCase):
self.assertAllClose(cov_mat, np.cov(samps.T),
atol=0.05, rtol=0.05)
+ def testSingularScaleRaises(self):
+ mu = [-1., 1]
+ diag = [1., 0]
+ with self.test_session():
+ dist = ds.VectorLaplaceDiag(mu, diag, validate_args=True)
+ with self.assertRaisesOpError("Singular"):
+ dist.sample().eval()
+
def testSampleWithBroadcastScale(self):
# mu corresponds to a 2-batch of 3-variate normals
mu = np.zeros([2, 3])
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 370b4f4a06..5e3b42dd2a 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -25,7 +25,6 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util
from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import
@@ -45,13 +44,11 @@ def make_diag_scale(loc, scale_diag, scale_identity_multiplier,
check_ops.assert_positive(
x, message="diagonal part must be positive"),
], x)
- # TODO(b/35157376): Use `assert_none_equal` once it exists.
return control_flow_ops.with_dependencies([
- check_ops.assert_greater(
- math_ops.abs(x),
+ check_ops.assert_none_equal(
+ x,
array_ops.zeros([], x.dtype),
- message="diagonal part must be non-zero"),
- ], x)
+ message="diagonal part must be non-zero")], x)
with ops.name_scope(name, "make_diag_scale",
values=[loc, scale_diag, scale_identity_multiplier]):
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index e57145b42f..163cf75d99 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -197,11 +197,14 @@ class MultivariateNormalDiag(
with ops.name_scope(name):
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
+ # No need to validate_args while making diag_scale. The returned
+ # LinearOperatorDiag has an assert_non_singular method that is called by
+ # the Bijector.
scale = distribution_util.make_diag_scale(
loc=loc,
scale_diag=scale_diag,
scale_identity_multiplier=scale_identity_multiplier,
- validate_args=validate_args,
+ validate_args=False,
assert_positive=False)
super(MultivariateNormalDiag, self).__init__(
loc=loc,
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index dd5933f62b..d662b25e1e 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -21,10 +21,6 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import util as distribution_util
@@ -55,7 +51,7 @@ class MultivariateNormalTriL(
where:
* `loc` is a vector in `R^k`,
- * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
+ * `scale` is a matrix in `R^{k x k}`, `covariance = scale @ scale.T`,
* `Z` denotes the normalization constant, and,
* `||y||**2` denotes the squared Euclidean norm of `y`.
@@ -191,14 +187,9 @@ class MultivariateNormalTriL(
is_positive_definite=True,
assert_proper_shapes=validate_args)
else:
- if validate_args:
- scale_tril = control_flow_ops.with_dependencies([
- # TODO(b/35157376): Use `assert_none_equal` once it exists.
- check_ops.assert_greater(
- math_ops.abs(array_ops.matrix_diag_part(scale_tril)),
- array_ops.zeros([], scale_tril.dtype),
- message="`scale_tril` must have non-zero diagonal"),
- ], scale_tril)
+ # No need to validate that scale_tril is non-singular.
+ # LinearOperatorTriL has an assert_non_singular method that is called
+ # by the Bijector.
scale = linalg.LinearOperatorTriL(
scale_tril,
is_non_singular=True,
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index 2dfda81a15..0e3867809a 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -214,11 +214,14 @@ class VectorLaplaceDiag(
with ops.name_scope(name):
with ops.name_scope("init", values=[
loc, scale_diag, scale_identity_multiplier]):
+ # No need to validate_args while making diag_scale. The returned
+ # LinearOperatorDiag has an assert_non_singular method that is called by
+ # the Bijector.
scale = distribution_util.make_diag_scale(
loc=loc,
scale_diag=scale_diag,
scale_identity_multiplier=scale_identity_multiplier,
- validate_args=validate_args,
+ validate_args=False,
assert_positive=False)
super(VectorLaplaceDiag, self).__init__(
loc=loc,