diff options
author | 2016-07-19 17:46:36 -0800 | |
---|---|---|
committer | 2016-07-19 19:03:31 -0700 | |
commit | 8c413daa09318c6ad021eb830650e3c66ee90891 (patch) | |
tree | 2f8a05855400d531178c992c00bb75de5de51622 | |
parent | f3a613e9db95958316569d74748d4fdb632ffbb4 (diff) |
strict/strict_statistics -> validate_args/allow_nan_stats
Change: 127900496
23 files changed, 355 insertions, 324 deletions
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py index 2e1c818b70..c636a4d060 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bernoulli_test.py @@ -169,7 +169,7 @@ class BernoulliTest(tf.test.TestCase): def testEntropyWithBatch(self): p = [[0.1, 0.7], [0.2, 0.6]] - dist = tf.contrib.distributions.Bernoulli(p=p, strict=False) + dist = tf.contrib.distributions.Bernoulli(p=p, validate_args=False) with self.test_session(): self.assertAllClose(dist.entropy().eval(), [[entropy(0.1), entropy(0.7)], [entropy(0.2), entropy(0.6)]]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py index 2252ac57f7..6f26232091 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/beta_test.py @@ -193,11 +193,11 @@ class BetaTest(tf.test.TestCase): with self.assertRaisesOpError('Condition x < y.*'): dist.mode().eval() - def testBetaMode_disable_strict_statistics(self): + def testBetaMode_enable_allow_nan_stats(self): with tf.Session(): a = np.array([1., 2, 3]) b = np.array([2., 4, 1.2]) - dist = tf.contrib.distributions.Beta(a, b, strict_statistics=False) + dist = tf.contrib.distributions.Beta(a, b, allow_nan_stats=True) expected_mode = (a - 1)/(a + b - 2) expected_mode[0] = np.nan @@ -206,7 +206,7 @@ class BetaTest(tf.test.TestCase): a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) - dist = tf.contrib.distributions.Beta(a, b, strict_statistics=False) + dist = tf.contrib.distributions.Beta(a, b, allow_nan_stats=True) expected_mode = (a - 1)/(a + b - 2) expected_mode[0] = np.nan diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py index 95534ed4c7..aec5b85699 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_multinomial_test.py @@ -272,7 +272,7 @@ class DirichletMultinomialTest(tf.test.TestCase): counts = [[1., 0], [0., -1]] # counts should be non-negative. n = [-5.3] # n should be a non negative integer equal to counts.sum. dist = tf.contrib.distributions.DirichletMultinomial( - n, alpha, strict=False) + n, alpha, validate_args=False) dist.pmf(counts).eval() # Should not raise. diff --git a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py index 028dde15ea..4368ccc10c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/dirichlet_test.py @@ -157,11 +157,11 @@ class DirichletTest(tf.test.TestCase): with self.assertRaisesOpError('Condition x < y.*'): dirichlet.mode().eval() - def testDirichletMode_disable_strict_statistics(self): + def testDirichletMode_enable_allow_nan_stats(self): with self.test_session(): alpha = np.array([1., 2, 3]) dirichlet = tf.contrib.distributions.Dirichlet( - alpha=alpha, strict_statistics=False) + alpha=alpha, allow_nan_stats=True) expected_mode = (alpha - 1)/(np.sum(alpha) - 3) expected_mode[0] = np.nan diff --git a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py index 5c56befd42..e1434cae91 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py @@ -118,33 +118,33 @@ class GammaTest(tf.test.TestCase): self.assertEqual(gamma.mean().get_shape(), (3,)) self.assertAllClose(gamma.mean().eval(), expected_means) - def testGammaModeStrictStatsIsTrueWorksWhenAllBatchMembersAreDefined(self): + def testGammaModeAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): with self.test_session(): alpha_v = np.array([5.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) gamma = tf.contrib.distributions.Gamma( - alpha=alpha_v, beta=beta_v) # strict_statistics=True is the default. + alpha=alpha_v, beta=beta_v) expected_modes = (alpha_v - 1) / beta_v self.assertEqual(gamma.mode().get_shape(), (3,)) self.assertAllClose(gamma.mode().eval(), expected_modes) - def testGammaModeStrictStatsTrueRaisesForUndefinedBatchMembers(self): + def testGammaModeAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): with self.test_session(): # Mode will not be defined for the first entry. alpha_v = np.array([0.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) gamma = tf.contrib.distributions.Gamma( - alpha=alpha_v, beta=beta_v) # strict_statistics=True is the default. + alpha=alpha_v, beta=beta_v) with self.assertRaisesOpError('x < y'): gamma.mode().eval() - def testGammaModeStrictStatsIsFalseReturnsNaNforUndefinedBatchMembers(self): + def testGammaModeAllowNanStatsIsTrueReturnsNaNforUndefinedBatchMembers(self): with self.test_session(): # Mode will not be defined for the first entry. alpha_v = np.array([0.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) gamma = tf.contrib.distributions.Gamma( - alpha=alpha_v, beta=beta_v, strict_statistics=False) + alpha=alpha_v, beta=beta_v, allow_nan_stats=True) expected_modes = (alpha_v - 1) / beta_v expected_modes[0] = np.nan self.assertEqual(gamma.mode().get_shape(), (3,)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py index 9a0fb7d486..e3acc1b84c 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py @@ -124,18 +124,18 @@ class InverseGammaTest(tf.test.TestCase): alpha_v = np.array([5.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) inv_gamma = tf.contrib.distributions.InverseGamma( - alpha=alpha_v, beta=beta_v) # strict_statistics=True is the default. + alpha=alpha_v, beta=beta_v) expected_means = stats.invgamma.mean(alpha_v, scale=beta_v) self.assertEqual(inv_gamma.mean().get_shape(), (3,)) self.assertAllClose(inv_gamma.mean().eval(), expected_means) - def testInverseGammaMeanStrictStats(self): + def testInverseGammaMeanAllowNanStats(self): with self.test_session(): # Mean will not be defined for the first entry. alpha_v = np.array([1.0, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) inv_gamma = tf.contrib.distributions.InverseGamma( - alpha=alpha_v, beta=beta_v) # strict_statistics=True is the default. + alpha=alpha_v, beta=beta_v) with self.assertRaisesOpError('x < y'): inv_gamma.mean().eval() @@ -146,7 +146,7 @@ class InverseGammaTest(tf.test.TestCase): beta_v = np.array([1.0, 2.0, 4.0, 5.0]) inv_gamma = tf.contrib.distributions.InverseGamma(alpha=alpha_v, beta=beta_v, - strict_statistics=False) + allow_nan_stats=True) expected_means = beta_v / (alpha_v - 1) expected_means[0] = np.nan expected_means[1] = np.nan @@ -163,7 +163,7 @@ class InverseGammaTest(tf.test.TestCase): self.assertEqual(inv_gamma.variance().get_shape(), (3,)) self.assertAllClose(inv_gamma.variance().eval(), expected_variances) - def testInverseGammaVarianceStrictStats(self): + def testInverseGammaVarianceAllowNanStats(self): with self.test_session(): alpha_v = np.array([1.5, 3.0, 2.5]) beta_v = np.array([1.0, 4.0, 5.0]) @@ -178,7 +178,7 @@ class InverseGammaTest(tf.test.TestCase): beta_v = np.array([1.0, 4.0, 5.0]) inv_gamma = tf.contrib.distributions.InverseGamma(alpha=alpha_v, beta=beta_v, - strict_statistics=False) + allow_nan_stats=True) expected_variances = stats.invgamma.var(alpha_v, scale=beta_v) expected_variances[0] = np.nan self.assertEqual(inv_gamma.variance().get_shape(), (3,)) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py index e93ff10ef5..176c78398f 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py @@ -229,38 +229,38 @@ class StudentTTest(tf.test.TestCase): _check2d_rows(tf.contrib.distributions.StudentT( df=7., mu=3., sigma=[[2.], [3.], [4.]])) - def testMeanStrictStatisticsIsTrueWorksWhenAllBatchMembersAreDefined(self): + def testMeanAllowNanStatsIsFalseWorksWhenAllBatchMembersAreDefined(self): with tf.Session(): mu = [1., 3.3, 4.4] student = tf.contrib.distributions.StudentT( df=[3., 5., 7.], mu=mu, - sigma=[3., 2., 1.]) # strict_statistics=True is the default. + sigma=[3., 2., 1.]) mean = student.mean().eval() self.assertAllClose([1., 3.3, 4.4], mean) - def testMeanStrictStatisticsIsTrueRaisesWhenBatchMemberIsUndefined(self): + def testMeanAllowNanStatsIsFalseRaisesWhenBatchMemberIsUndefined(self): with tf.Session(): mu = [1., 3.3, 4.4] student = tf.contrib.distributions.StudentT( df=[0.5, 5., 7.], mu=mu, - sigma=[3., 2., 1.]) # strict_statistics=True is the default. + sigma=[3., 2., 1.]) with self.assertRaisesOpError('x < y'): student.mean().eval() - def testMeanStrictStatisticsIsFalseReturnsNaNForUndefinedBatchMembers(self): + def testMeanAllowNanStatsIsTrueReturnsNaNForUndefinedBatchMembers(self): with tf.Session(): mu = [-2, 0., 1., 3.3, 4.4] student = tf.contrib.distributions.StudentT( df=[0.5, 1., 3., 5., 7.], mu=mu, sigma=[5., 4., 3., 2., 1.], - strict_statistics=False) + allow_nan_stats=True) mean = student.mean().eval() self.assertAllClose([np.nan, np.nan, 1., 3.3, 4.4], mean) - def testVarianceStrictStatisticsFalseReturnsNaNforUndefinedBatchMembers(self): + def testVarianceAllowNanStatsTrueReturnsNaNforUndefinedBatchMembers(self): with tf.Session(): # df = 0.5 ==> undefined mean ==> undefined variance. # df = 1.5 ==> infinite variance. @@ -268,7 +268,7 @@ class StudentTTest(tf.test.TestCase): mu = [-2, 0., 1., 3.3, 4.4] sigma = [5., 4., 3., 2., 1.] student = tf.contrib.distributions.StudentT( - df=df, mu=mu, sigma=sigma, strict_statistics=False) + df=df, mu=mu, sigma=sigma, allow_nan_stats=True) var = student.variance().eval() ## scipy uses inf for variance when the mean is undefined. When mean is # undefined we say variance is undefined as well. So test the first @@ -281,7 +281,7 @@ class StudentTTest(tf.test.TestCase): stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)] self.assertAllClose(expected_var, var) - def testVarianceStrictStatisticsTrueGivesCorrectValueForDefinedBatchMembers( + def testVarianceAllowNanStatsFalseGivesCorrectValueForDefinedBatchMembers( self): with tf.Session(): # df = 1.5 ==> infinite variance. @@ -289,25 +289,25 @@ class StudentTTest(tf.test.TestCase): mu = [0., 1., 3.3, 4.4] sigma = [4., 3., 2., 1.] student = tf.contrib.distributions.StudentT( - df=df, mu=mu, sigma=sigma) # strict_statistics=True is the default. + df=df, mu=mu, sigma=sigma) var = student.variance().eval() expected_var = [ stats.t.var(d, loc=m, scale=s) for (d, m, s) in zip(df, mu, sigma)] self.assertAllClose(expected_var, var) - def testVarianceStrictStatisticsTrueRaisesForUndefinedBatchMembers(self): + def testVarianceAllowNanStatsFalseRaisesForUndefinedBatchMembers(self): with tf.Session(): # df <= 1 ==> variance not defined student = tf.contrib.distributions.StudentT( - df=1.0, mu=0.0, sigma=1.0) # strict_statistics=True is the default. + df=1.0, mu=0.0, sigma=1.0) with self.assertRaisesOpError('x < y'): student.variance().eval() with tf.Session(): # df <= 1 ==> variance not defined student = tf.contrib.distributions.StudentT( - df=0.5, mu=0.0, sigma=1.0) # strict_statistics=True is the default. + df=0.5, mu=0.0, sigma=1.0) with self.assertRaisesOpError('x < y'): student.variance().eval() diff --git a/tensorflow/contrib/distributions/python/ops/bernoulli.py b/tensorflow/contrib/distributions/python/ops/bernoulli.py index 4c7b85a120..b3259b2867 100644 --- a/tensorflow/contrib/distributions/python/ops/bernoulli.py +++ b/tensorflow/contrib/distributions/python/ops/bernoulli.py @@ -42,8 +42,13 @@ class Bernoulli(distribution.Distribution): * log_cdf """ - def __init__(self, logits=None, p=None, dtype=dtypes.int32, strict=True, - strict_statistics=True, name="Bernoulli"): + def __init__(self, + logits=None, + p=None, + dtype=dtypes.int32, + validate_args=True, + allow_nan_stats=False, + name="Bernoulli"): """Construct Bernoulli distributions. Args: @@ -55,21 +60,21 @@ class Bernoulli(distribution.Distribution): event. Each entry in the `Tensor` parameterizes an independent Bernoulli distribution. dtype: dtype for samples. - strict: Whether to assert that `0 <= p <= 1`. If not strict, `log_pmf` may - return nans. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args, + `log_pmf` may return nans. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: A name for this distribution. Raises: ValueError: If p and logits are passed, or if neither are passed. """ - self._strict_statistics = strict_statistics + self._allow_nan_stats = allow_nan_stats self._name = name self._dtype = dtype - self._strict = strict + self._validate_args = validate_args check_op = check_ops.assert_less_equal if p is None and logits is None: raise ValueError("Must pass p or logits.") @@ -84,8 +89,8 @@ class Bernoulli(distribution.Distribution): elif logits is None: with ops.name_scope(name): with ops.name_scope("p"): - with ops.control_dependencies( - [check_op(p, 1.), check_op(0., p)] if strict else []): + with ops.control_dependencies([check_op(p, 1.), check_op(0., p)] if + validate_args else []): self._p = array_ops.identity(p) with ops.name_scope("logits"): self._logits = math_ops.log(self._p) - math_ops.log(1. - self._p) @@ -96,14 +101,14 @@ class Bernoulli(distribution.Distribution): self._event_shape = array_ops.constant([], dtype=dtypes.int32) @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): diff --git a/tensorflow/contrib/distributions/python/ops/beta.py b/tensorflow/contrib/distributions/python/ops/beta.py index 31dd8e6753..38fa29e824 100644 --- a/tensorflow/contrib/distributions/python/ops/beta.py +++ b/tensorflow/contrib/distributions/python/ops/beta.py @@ -97,7 +97,8 @@ class Beta(distribution.Distribution): ``` """ - def __init__(self, a, b, strict=True, strict_statistics=True, name="Beta"): + def __init__(self, a, b, validate_args=True, allow_nan_stats=False, + name="Beta"): """Initialize a batch of Beta distributions. Args: @@ -108,12 +109,12 @@ class Beta(distribution.Distribution): b: Positive `float` or `double` tensor with shape broadcastable to `[N1,..., Nm]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different Beta distributions. - strict: Whether to assert valid values for parameters `a` and `b`, and - `x` in `prob` and `log_prob`. If False, correct behavior is not + validate_args: Whether to assert valid values for parameters `a` and `b`, + and `x` in `prob` and `log_prob`. If False, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class. @@ -130,7 +131,7 @@ class Beta(distribution.Distribution): with ops.op_scope([a, b], name): with ops.control_dependencies([ check_ops.assert_positive(a), - check_ops.assert_positive(b)] if strict else []): + check_ops.assert_positive(b)] if validate_args else []): a = array_ops.identity(a, name="a") b = array_ops.identity(b, name="b") @@ -143,8 +144,8 @@ class Beta(distribution.Distribution): self._get_batch_shape = self._a_b_sum.get_shape() self._get_event_shape = tensor_shape.TensorShape([]) - self._strict = strict - self._strict_statistics = strict_statistics + self._validate_args = validate_args + self._allow_nan_stats = allow_nan_stats @property def a(self): @@ -167,14 +168,14 @@ class Beta(distribution.Distribution): return self._a_b_sum.dtype @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args def batch_shape(self, name="batch_shape"): """Batch dimensions of this instance as a 1-D int32 `Tensor`. @@ -249,7 +250,7 @@ class Beta(distribution.Distribution): Note that the mode for the Beta distribution is only defined when `a > 1`, `b > 1`. This returns the mode when `a > 1` and `b > 1`, - and NaN otherwise. If `self.strict_statistics` is `True`, an exception + and NaN otherwise. If `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`. Args: @@ -266,17 +267,17 @@ class Beta(distribution.Distribution): one = math_ops.cast(1, self.dtype) mode = (a - 1)/ (a_b_sum - 2) - if self.strict_statistics: - return control_flow_ops.with_dependencies([ - check_ops.assert_less(one, a), - check_ops.assert_less(one, b)], mode) - else: + if self.allow_nan_stats: return math_ops.select( math_ops.logical_and( math_ops.greater(a, 1), math_ops.greater(b, 1)), mode, (constant_op.constant(float("NaN"), dtype=self.dtype) * array_ops.ones_like(a_b_sum, dtype=self.dtype))) + else: + return control_flow_ops.with_dependencies([ + check_ops.assert_less(one, a), + check_ops.assert_less(one, b)], mode) def entropy(self, name="entropy"): """Entropy of the distribution in nats.""" @@ -389,5 +390,5 @@ class Beta(distribution.Distribution): dependencies = [ check_ops.assert_positive(x), check_ops.assert_less(x, math_ops.cast( - 1, self.dtype))] if self.strict else [] + 1, self.dtype))] if self.validate_args else [] return control_flow_ops.with_dependencies(dependencies, x) diff --git a/tensorflow/contrib/distributions/python/ops/categorical.py b/tensorflow/contrib/distributions/python/ops/categorical.py index c9a4b81cd1..8b0714055a 100644 --- a/tensorflow/contrib/distributions/python/ops/categorical.py +++ b/tensorflow/contrib/distributions/python/ops/categorical.py @@ -45,8 +45,8 @@ class Categorical(distribution.Distribution): self, logits, dtype=dtypes.int32, - strict=True, - strict_statistics=True, + validate_args=True, + allow_nan_stats=False, name="Categorical"): """Initialize Categorical distributions using class log-probabilities. @@ -56,17 +56,17 @@ class Categorical(distribution.Distribution): index into a batch of independent distributions and the last dimension indexes into the classes. dtype: The type of the event samples (default: int32). - strict: Unused in this distribution. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Unused in this distribution. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: A name for this distribution (optional). """ - self._strict_statistics = strict_statistics + self._allow_nan_stats = allow_nan_stats self._name = name self._dtype = dtype - self._strict = strict + self._validate_args = validate_args with ops.op_scope([logits], name): self._logits = ops.convert_to_tensor(logits, name="logits") logits_shape = array_ops.shape(self._logits) @@ -76,14 +76,14 @@ class Categorical(distribution.Distribution): self._num_classes = array_ops.gather(logits_shape, self._batch_rank) @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): diff --git a/tensorflow/contrib/distributions/python/ops/chi2.py b/tensorflow/contrib/distributions/python/ops/chi2.py index eeae24bcd6..6439eec294 100644 --- a/tensorflow/contrib/distributions/python/ops/chi2.py +++ b/tensorflow/contrib/distributions/python/ops/chi2.py @@ -34,32 +34,36 @@ class Chi2(gamma.Gamma): with Chi2(df) = Gamma(df/2, 1/2). """ - def __init__(self, df, strict=True, strict_statistics=True, name="Chi2"): + def __init__(self, + df, + validate_args=True, + allow_nan_stats=False, + name="Chi2"): """Construct Chi2 distributions with parameter `df`. Args: df: `float` or `double` tensor, the degrees of freedom of the distribution(s). `df` must contain only positive values. - strict: Whether to assert that `df > 0`, and that `x > 0` in the - methods `prob(x)` and `log_prob(x)`. If `strict` is False + validate_args: Whether to assert that `df > 0`, and that `x > 0` in the + methods `prob(x)` and `log_prob(x)`. If `validate_args` is False and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. """ # Even though all stats of chi2 are defined for valid parameters, this is # not true in the parent class "gamma." therefore, passing - # strict_statistics=True + # allow_nan_stats=False # through to the parent class results in unnecessary asserts. with ops.op_scope([df], name): df = ops.convert_to_tensor(df) self._df = df super(Chi2, self).__init__(alpha=df / 2, beta=math_ops.cast(0.5, dtype=df.dtype), - strict=strict, - strict_statistics=strict_statistics) + validate_args=validate_args, + allow_nan_stats=allow_nan_stats) @property def df(self): diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet.py b/tensorflow/contrib/distributions/python/ops/dirichlet.py index 0f916b64a3..8337980b56 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet.py @@ -119,7 +119,10 @@ class Dirichlet(distribution.Distribution): ``` """ - def __init__(self, alpha, strict=True, strict_statistics=True, + def __init__(self, + alpha, + validate_args=True, + allow_nan_stats=False, name="Dirichlet"): """Initialize a batch of Dirichlet distributions. @@ -127,12 +130,12 @@ class Dirichlet(distribution.Distribution): alpha: Positive `float` or `double` tensor with shape broadcastable to `[N1,..., Nm, k]` `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different `k` class Dirichlet distributions. - strict: Whether to assert valid values for parameters `alpha` and + validate_args: Whether to assert valid values for parameters `alpha` and `x` in `prob` and `log_prob`. If False, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class. @@ -150,8 +153,9 @@ class Dirichlet(distribution.Distribution): with ops.op_scope([alpha], name): alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps") with ops.control_dependencies([ - check_ops.assert_positive(alpha), - check_ops.assert_rank_at_least(alpha, 1)] if strict else []): + check_ops.assert_positive(alpha), check_ops.assert_rank_at_least( + alpha, 1) + ] if validate_args else []): alpha = array_ops.identity(alpha, name="alpha") self._alpha = alpha @@ -164,8 +168,8 @@ class Dirichlet(distribution.Distribution): self._get_batch_shape = self._alpha_0.get_shape() self._get_event_shape = self._alpha.get_shape().with_rank_at_least(1)[-1:] - self._strict = strict - self._strict_statistics = strict_statistics + self._validate_args = validate_args + self._allow_nan_stats = allow_nan_stats @property def alpha(self): @@ -183,14 +187,14 @@ class Dirichlet(distribution.Distribution): return self._alpha.dtype @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args def batch_shape(self, name="batch_shape"): """Batch dimensions of this instance as a 1-D int32 `Tensor`. @@ -274,7 +278,7 @@ class Dirichlet(distribution.Distribution): Note that the mode for the Beta distribution is only defined when `alpha > 1`. This returns the mode when `alpha > 1`, - and NaN otherwise. If `self.strict_statistics` is `True`, an exception + and NaN otherwise. If `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`. Args: @@ -290,15 +294,16 @@ class Dirichlet(distribution.Distribution): array_ops.expand_dims(self._alpha_0, -1) - math_ops.cast( self.event_shape()[0], self.dtype)) - if self.strict_statistics: - return control_flow_ops.with_dependencies([ - check_ops.assert_less(one, self._alpha)], mode) - else: + if self.allow_nan_stats: return math_ops.select( math_ops.greater(self._alpha, 1), mode, (constant_op.constant(float("NaN"), dtype=self.dtype) * array_ops.ones_like(self._alpha, dtype=self.dtype))) + else: + return control_flow_ops.with_dependencies([ + check_ops.assert_less(one, self._alpha) + ], mode) def entropy(self, name="entropy"): """Entropy of the distribution in nats.""" @@ -401,7 +406,7 @@ class Dirichlet(distribution.Distribution): x = ops.convert_to_tensor(x, name="x_before_deps") candidate_one = math_ops.reduce_sum(x, reduction_indices=[-1]) one = math_ops.cast(1., self.dtype) - dependencies = [check_ops.assert_positive(x), - check_ops.assert_less(x, one), - _assert_close(one, candidate_one)] if self.strict else [] + dependencies = [check_ops.assert_positive(x), check_ops.assert_less(x, one), + _assert_close(one, candidate_one) + ] if self.validate_args else [] return control_flow_ops.with_dependencies(dependencies, x) diff --git a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py index fb86d8e612..c20590ce35 100644 --- a/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py +++ b/tensorflow/contrib/distributions/python/ops/dirichlet_multinomial.py @@ -133,8 +133,8 @@ class DirichletMultinomial(distribution.Distribution): n, alpha, allow_arbitrary_counts=False, - strict=True, - strict_statistics=True, + validate_args=True, + allow_nan_stats=False, name='DirichletMultinomial'): """Initialize a batch of DirichletMultinomial distributions. @@ -149,14 +149,14 @@ class DirichletMultinomial(distribution.Distribution): allow_arbitrary_counts: Boolean. This represents whether the pmf/cdf allows for the `counts` tensor to be non-integral values. The pmf/cdf are functions that can be evaluated at non-integral values, - but are only a distribution over non-negative integers. If `strict` is - `False`, this assertion is turned off. - strict: Whether to assert valid values for parameters `alpha` and `n`, and - `x` in `prob` and `log_prob`. If False, correct behavior is not - guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + but are only a distribution over non-negative integers. If + `validate_args` is `False`, this assertion is turned off. + validate_args: Whether to assert valid values for parameters `alpha` and + `n`, and `x` in `prob` and `log_prob`. If False, correct behavior is + not guaranteed. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class. @@ -171,8 +171,8 @@ class DirichletMultinomial(distribution.Distribution): dist = DirichletMultinomial([3., 4], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) ``` """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args self._name = name self._allow_arbitrary_counts = allow_arbitrary_counts with ops.op_scope([n, alpha], name): @@ -208,14 +208,14 @@ class DirichletMultinomial(distribution.Distribution): return self._alpha @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): @@ -364,7 +364,7 @@ class DirichletMultinomial(distribution.Distribution): def _check_counts(self, counts): """Check counts for proper shape, values, then return tensor version.""" counts = ops.convert_to_tensor(counts, name='counts') - if not self.strict: + if not self.validate_args: return counts candidate_n = math_ops.reduce_sum(counts, reduction_indices=[-1]) dependencies = [check_ops.assert_non_negative(counts), @@ -378,7 +378,7 @@ class DirichletMultinomial(distribution.Distribution): def _check_alpha(self, alpha): alpha = ops.convert_to_tensor(alpha, name='alpha') - if not self.strict: + if not self.validate_args: return alpha return control_flow_ops.with_dependencies( [check_ops.assert_rank_at_least(alpha, 1), @@ -386,7 +386,7 @@ class DirichletMultinomial(distribution.Distribution): def _check_n(self, n): n = ops.convert_to_tensor(n, name='n') - if not self.strict: + if not self.validate_args: return n return control_flow_ops.with_dependencies( [check_ops.assert_non_negative(n), _assert_integer_form(n)], n) diff --git a/tensorflow/contrib/distributions/python/ops/distribution.py b/tensorflow/contrib/distributions/python/ops/distribution.py index 2e52f578b2..48f8a1c077 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution.py +++ b/tensorflow/contrib/distributions/python/ops/distribution.py @@ -116,11 +116,11 @@ class Distribution(object): b = tf.exp(tf.matmul(logits, weights_b)) # Will raise exception if ANY batch member has a < 1 or b < 1. - dist = distributions.beta(a, b, strict_statistics=True) # default is True + dist = distributions.beta(a, b, allow_nan_stats=False) # default is False mode = dist.mode().eval() # Will return NaN for batch members with either a < 1 or b < 1. - dist = distributions.beta(a, b, strict_statistics=False) + dist = distributions.beta(a, b, allow_nan_stats=True) mode = dist.mode().eval() ``` @@ -129,16 +129,16 @@ class Distribution(object): ```python # Will raise an exception if any Op is run. negative_a = -1.0 * a # beta distribution by definition has a > 0. - dist = distributions.beta(negative_a, b, strict_statistics=False) + dist = distributions.beta(negative_a, b, allow_nan_stats=True) dist.mean().eval() ``` """ @abc.abstractproperty - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - # return self._strict_statistics + # return self._allow_nan_stats # Notes: # # When it makes sense, return +- infinity for statistics. E.g. the variance @@ -150,19 +150,19 @@ class Distribution(object): # it is either + or - infinity), so the variance = E[(X - mean)^2] is also # undefined. # - # Distributions should be initialized with a kwarg "strict_statistics" with + # Distributions should be initialized with a kwarg "allow_nan_stats" with # the following docstring (refer to above docstring note on undefined # statistics for more detail). - # strict_statistics: Boolean, default True. If True, raise an exception if + # allow_nan_stats: Boolean, default False. If False, raise an exception if # a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - # If False, batch members with valid parameters leading to undefined + # If True, batch members with valid parameters leading to undefined # statistics will return NaN for this statistic. pass @abc.abstractproperty - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - # return self._strict. + # return self._validate_args. pass @abc.abstractproperty diff --git a/tensorflow/contrib/distributions/python/ops/exponential.py b/tensorflow/contrib/distributions/python/ops/exponential.py index 84fbde2c23..84dfe98321 100644 --- a/tensorflow/contrib/distributions/python/ops/exponential.py +++ b/tensorflow/contrib/distributions/python/ops/exponential.py @@ -39,24 +39,24 @@ class Exponential(gamma.Gamma): """ def __init__( - self, lam, strict=True, strict_statistics=True, name="Exponential"): + self, lam, validate_args=True, allow_nan_stats=False, name="Exponential"): """Construct Exponential distribution with parameter `lam`. Args: lam: `float` or `double` tensor, the rate of the distribution(s). `lam` must contain only positive values. - strict: Whether to assert that `lam > 0`, and that `x > 0` in the - methods `prob(x)` and `log_prob(x)`. If `strict` is False + validate_args: Whether to assert that `lam > 0`, and that `x > 0` in the + methods `prob(x)` and `log_prob(x)`. If `validate_args` is False and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. """ # Even though all statistics of are defined for valid inputs, this is not # true in the parent class "Gamma." Therefore, passing - # strict_statistics=True + # allow_nan_stats=False # through to the parent class results in unnecessary asserts. with ops.op_scope([lam], name): lam = ops.convert_to_tensor(lam) @@ -64,8 +64,8 @@ class Exponential(gamma.Gamma): super(Exponential, self).__init__( alpha=math_ops.cast(1.0, dtype=lam.dtype), beta=lam, - strict_statistics=strict_statistics, - strict=strict) + allow_nan_stats=allow_nan_stats, + validate_args=validate_args) @property def lam(self): diff --git a/tensorflow/contrib/distributions/python/ops/gamma.py b/tensorflow/contrib/distributions/python/ops/gamma.py index b20a757061..1f733ceda1 100644 --- a/tensorflow/contrib/distributions/python/ops/gamma.py +++ b/tensorflow/contrib/distributions/python/ops/gamma.py @@ -57,8 +57,12 @@ class Gamma(distribution.Distribution): """ - def __init__( - self, alpha, beta, strict=True, strict_statistics=True, name="Gamma"): + def __init__(self, + alpha, + beta, + validate_args=True, + allow_nan_stats=False, + name="Gamma"): """Construct Gamma distributions with parameters `alpha` and `beta`. The parameters `alpha` and `beta` must be shaped in a way that supports @@ -71,25 +75,24 @@ class Gamma(distribution.Distribution): beta: `float` or `double` tensor, the inverse scale params of the distribution(s). beta must contain only positive values. - strict: Whether to assert that `a > 0, b > 0`, and that `x > 0` in the - methods `prob(x)` and `log_prob(x)`. If `strict` is False + validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in + the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. Raises: TypeError: if `alpha` and `beta` are different dtypes. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.op_scope([alpha, beta], name) as scope: self._name = scope - with ops.control_dependencies( - [check_ops.assert_positive(alpha), check_ops.assert_positive(beta)] - if strict else []): + with ops.control_dependencies([check_ops.assert_positive( + alpha), check_ops.assert_positive(beta)] if validate_args else []): alpha = array_ops.identity(alpha, name="alpha") beta = array_ops.identity(beta, name="beta") @@ -103,14 +106,14 @@ class Gamma(distribution.Distribution): self._beta = beta @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): @@ -191,7 +194,7 @@ class Gamma(distribution.Distribution): """Mode of each batch member. The mode of a gamma distribution is `(alpha - 1) / beta` when `alpha > 1`, - and `NaN` otherwise. If `self.strict_statistics` is `True`, an exception + and `NaN` otherwise. If `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`. Args: @@ -205,14 +208,14 @@ class Gamma(distribution.Distribution): with ops.name_scope(self.name): with ops.op_scope([alpha, beta], name): mode_if_defined = (alpha - 1.0) / beta - if self.strict_statistics: - one = ops.convert_to_tensor(1.0, dtype=self.dtype) - return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, alpha)], mode_if_defined) - else: + if self.allow_nan_stats: alpha_ge_1 = alpha >= 1.0 nan = np.nan * self._ones() return math_ops.select(alpha_ge_1, mode_if_defined, nan) + else: + one = ops.convert_to_tensor(1.0, dtype=self.dtype) + return control_flow_ops.with_dependencies( + [check_ops.assert_less(one, alpha)], mode_if_defined) def variance(self, name="variance"): """Variance of each batch member.""" @@ -244,9 +247,8 @@ class Gamma(distribution.Distribution): alpha = self._alpha beta = self._beta x = ops.convert_to_tensor(x) - x = control_flow_ops.with_dependencies( - [check_ops.assert_positive(x)] if self.strict else [], - x) + x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if + self.validate_args else [], x) contrib_tensor_util.assert_same_float_dtype(tensors=[x,], dtype=self.dtype) @@ -281,9 +283,8 @@ class Gamma(distribution.Distribution): with ops.name_scope(self.name): with ops.op_scope([self._alpha, self._beta, x], name): x = ops.convert_to_tensor(x) - x = control_flow_ops.with_dependencies( - [check_ops.assert_positive(x)] if self.strict else [], - x) + x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if + self.validate_args else [], x) contrib_tensor_util.assert_same_float_dtype(tensors=[x,], dtype=self.dtype) # Note that igamma returns the regularized incomplete gamma function, diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index d239789489..a23f6df571 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -60,8 +60,8 @@ class InverseGamma(distribution.Distribution): def __init__(self, alpha, beta, - strict=True, - strict_statistics=True, + validate_args=True, + allow_nan_stats=False, name="InverseGamma"): """Construct InverseGamma distributions with parameters `alpha` and `beta`. @@ -74,24 +74,24 @@ class InverseGamma(distribution.Distribution): alpha must contain only positive values. beta: `float` or `double` tensor, the scale params of the distribution(s). beta must contain only positive values. - strict: Whether to assert that `a > 0, b > 0`, and that `x > 0` in the - methods `prob(x)` and `log_prob(x)`. If `strict` is False + validate_args: Whether to assert that `a > 0, b > 0`, and that `x > 0` in + the methods `prob(x)` and `log_prob(x)`. If `validate_args` is False and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prepend to all ops created by this distribution. Raises: TypeError: if `alpha` and `beta` are different dtypes. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.op_scope([alpha, beta], name) as scope: self._name = scope with ops.control_dependencies([check_ops.assert_positive( - alpha), check_ops.assert_positive(beta)] if strict else []): + alpha), check_ops.assert_positive(beta)] if validate_args else []): alpha = array_ops.identity(alpha, name="alpha") beta = array_ops.identity(beta, name="beta") @@ -105,14 +105,14 @@ class InverseGamma(distribution.Distribution): self._beta = beta @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): @@ -187,8 +187,8 @@ class InverseGamma(distribution.Distribution): """Mean of each batch member. The mean of an inverse gamma distribution is `beta / (alpha - 1)`, - when `alpha > 1`, and `NaN` otherwise. If `self.strict_statistics` is - `True`, an exception will be raised rather than returning `NaN` + when `alpha > 1`, and `NaN` otherwise. If `self.allow_nan_stats` is + `False`, an exception will be raised rather than returning `NaN` Args: name: A name to give this op. @@ -201,14 +201,14 @@ class InverseGamma(distribution.Distribution): with ops.name_scope(self.name): with ops.op_scope([alpha, beta], name): mean_if_defined = beta / (alpha - 1.0) - if self.strict_statistics: - one = ops.convert_to_tensor(1.0, dtype=self.dtype) - return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, alpha)], mean_if_defined) - else: + if self.allow_nan_stats: alpha_gt_1 = alpha > 1.0 nan = np.nan * self._ones() return math_ops.select(alpha_gt_1, mean_if_defined, nan) + else: + one = ops.convert_to_tensor(1.0, dtype=self.dtype) + return control_flow_ops.with_dependencies( + [check_ops.assert_less(one, alpha)], mean_if_defined) def mode(self, name="mode"): """Mode of each batch member. @@ -229,7 +229,7 @@ class InverseGamma(distribution.Distribution): """Variance of each batch member. Variance for inverse gamma is defined only for `alpha > 2`. If - `self.strict_statistics` is `True`, an exception will be raised rather + `self.allow_nan_stats` is `False`, an exception will be raised rather than returning `NaN`. Args: @@ -245,14 +245,14 @@ class InverseGamma(distribution.Distribution): var_if_defined = (math_ops.square(self._beta) / (math_ops.square(self._alpha - 1.0) * (self._alpha - 2.0))) - if self.strict_statistics: - two = ops.convert_to_tensor(2.0, dtype=self.dtype) - return control_flow_ops.with_dependencies( - [check_ops.assert_less(two, alpha)], var_if_defined) - else: + if self.allow_nan_stats: alpha_gt_2 = alpha > 2.0 nan = np.nan * self._ones() return math_ops.select(alpha_gt_2, var_if_defined, nan) + else: + two = ops.convert_to_tensor(2.0, dtype=self.dtype) + return control_flow_ops.with_dependencies( + [check_ops.assert_less(two, alpha)], var_if_defined) def log_prob(self, x, name="log_prob"): """Log prob of observations in `x` under these InverseGamma distribution(s). @@ -273,7 +273,7 @@ class InverseGamma(distribution.Distribution): beta = self._beta x = ops.convert_to_tensor(x) x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if - self.strict else [], x) + self.validate_args else [], x) contrib_tensor_util.assert_same_float_dtype(tensors=[x,], dtype=self.dtype) @@ -309,7 +309,7 @@ class InverseGamma(distribution.Distribution): with ops.op_scope([self._alpha, self._beta, x], name): x = ops.convert_to_tensor(x) x = control_flow_ops.with_dependencies([check_ops.assert_positive(x)] if - self.strict else [], x) + self.validate_args else [], x) contrib_tensor_util.assert_same_float_dtype(tensors=[x,], dtype=self.dtype) # Note that igammac returns the upper regularized incomplete gamma diff --git a/tensorflow/contrib/distributions/python/ops/laplace.py b/tensorflow/contrib/distributions/python/ops/laplace.py index 1e5e867293..ee6aa81c0f 100644 --- a/tensorflow/contrib/distributions/python/ops/laplace.py +++ b/tensorflow/contrib/distributions/python/ops/laplace.py @@ -48,13 +48,12 @@ class Laplace(distribution.Distribution): distributions spliced together "back-to-back." """ - def __init__( - self, - loc, - scale, - strict=True, - strict_statistics=True, - name="Laplace"): + def __init__(self, + loc, + scale, + validate_args=True, + allow_nan_stats=False, + name="Laplace"): """Construct Laplace distribution with parameters `loc` and `scale`. The parameters `loc` and `scale` must be shaped in a way that supports @@ -65,24 +64,25 @@ class Laplace(distribution.Distribution): of the distribution. scale: `float` or `double`, positive-valued tensor which characterzes the spread of the distribution. - strict: Whether to validate input with asserts. If `strict` is `False`, - and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to validate input with asserts. If `validate_args` + is `False`, and the inputs are invalid, correct behavior is not + guaranteed. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: TypeError: if `loc` and `scale` are of different dtype. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.op_scope([loc, scale], name): loc = ops.convert_to_tensor(loc) scale = ops.convert_to_tensor(scale) - with ops.control_dependencies( - [check_ops.assert_positive(scale)] if strict else []): + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): self._name = name self._loc = array_ops.identity(loc, name="loc") self._scale = array_ops.identity(scale, name="scale") @@ -92,14 +92,14 @@ class Laplace(distribution.Distribution): contrib_tensor_util.assert_same_float_dtype((loc, scale)) @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py index 36efdf161f..d47ae6c019 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn.py +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -89,14 +89,13 @@ class MultivariateNormalOperatorPD(distribution.Distribution): """ - def __init__( - self, - mu, - cov, - allow_nan=False, - strict=True, - strict_statistics=True, - name="MultivariateNormalCov"): + def __init__(self, + mu, + cov, + allow_nan=False, + validate_args=True, + allow_nan_stats=False, + name="MultivariateNormalCov"): """Multivariate Normal distributions on `R^k`. User must provide means `mu`, and an instance of `OperatorPDBase`, `cov`, @@ -110,19 +109,20 @@ class MultivariateNormalOperatorPD(distribution.Distribution): a statistic (e.g. mean/mode/etc...) is undefined for any batch member. If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. - strict: Whether to validate input with asserts. If `strict` is `False`, - and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to validate input with asserts. If `validate_args` + is `False`, and the inputs are invalid, correct behavior is not + guaranteed. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: TypeError: If `mu` and `cov` are different dtypes. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.name_scope(name): with ops.op_scope([mu] + cov.inputs, "init"): self._cov = cov @@ -139,7 +139,7 @@ class MultivariateNormalOperatorPD(distribution.Distribution): "mu and cov must have the same dtype. Found mu.dtype = %s, " "cov.dtype = %s" % (mu.dtype, cov.dtype)) - if not self.strict: + if not self.validate_args: return mu else: assert_compatible_shapes = control_flow_ops.group( @@ -160,14 +160,14 @@ class MultivariateNormalOperatorPD(distribution.Distribution): return control_flow_ops.with_dependencies([assert_compatible_shapes], mu) @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property def dtype(self): @@ -441,13 +441,12 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD): """ - def __init__( - self, - mu, - chol, - strict=True, - strict_statistics=True, - name="MultivariateNormalCholesky"): + def __init__(self, + mu, + chol, + validate_args=True, + allow_nan_stats=False, + name="MultivariateNormalCholesky"): """Multivariate Normal distributions on `R^k`. User must provide means `mu` and `chol` which holds the (batch) Cholesky @@ -458,20 +457,25 @@ class MultivariateNormalCholesky(MultivariateNormalOperatorPD): `b >= 0`. chol: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape `[N1,...,Nb, k, k]`. - strict: Whether to validate input with asserts. If `strict` is `False`, + validate_args: Whether to validate input with asserts. If `validate_args` + is `False`, and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: TypeError: If `mu` and `chol` are different dtypes. """ - cov = operator_pd_cholesky.OperatorPDCholesky(chol, verify_pd=strict) + cov = operator_pd_cholesky.OperatorPDCholesky(chol, verify_pd=validate_args) super(MultivariateNormalCholesky, self).__init__( - mu, cov, strict_statistics=strict_statistics, strict=strict, name=name) + mu, + cov, + allow_nan_stats=allow_nan_stats, + validate_args=validate_args, + name=name) class MultivariateNormalFull(MultivariateNormalOperatorPD): @@ -519,13 +523,12 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD): """ - def __init__( - self, - mu, - sigma, - strict=True, - strict_statistics=True, - name="MultivariateNormalFull"): + def __init__(self, + mu, + sigma, + validate_args=True, + allow_nan_stats=False, + name="MultivariateNormalFull"): """Multivariate Normal distributions on `R^k`. User must provide means `mu` and `sigma`, the mean and covariance. @@ -535,17 +538,22 @@ class MultivariateNormalFull(MultivariateNormalOperatorPD): `b >= 0`. sigma: `(N+2)-D` `Tensor` with same `dtype` as `mu` and shape `[N1,...,Nb, k, k]`. - strict: Whether to validate input with asserts. If `strict` is `False`, - and the inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to validate input with asserts. If `validate_args` + is `False`, and the inputs are invalid, correct behavior is not + guaranteed. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: TypeError: If `mu` and `sigma` are different dtypes. """ - cov = operator_pd_full.OperatorPDFull(sigma, verify_pd=strict) + cov = operator_pd_full.OperatorPDFull(sigma, verify_pd=validate_args) super(MultivariateNormalFull, self).__init__( - mu, cov, strict_statistics=strict_statistics, strict=strict, name=name) + mu, + cov, + allow_nan_stats=allow_nan_stats, + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/normal.py b/tensorflow/contrib/distributions/python/ops/normal.py index f083c6f40e..dff8c7fdbb 100644 --- a/tensorflow/contrib/distributions/python/ops/normal.py +++ b/tensorflow/contrib/distributions/python/ops/normal.py @@ -80,8 +80,12 @@ class Normal(distribution.Distribution): """ - def __init__( - self, mu, sigma, strict=True, strict_statistics=True, name="Normal"): + def __init__(self, + mu, + sigma, + validate_args=True, + allow_nan_stats=False, + name="Normal"): """Construct Normal distributions with mean and stddev `mu` and `sigma`. The parameters `mu` and `sigma` must be shaped in a way that supports @@ -91,24 +95,24 @@ class Normal(distribution.Distribution): mu: `float` or `double` tensor, the means of the distribution(s). sigma: `float` or `double` tensor, the stddevs of the distribution(s). sigma must contain only positive values. - strict: Whether to assert that `sigma > 0`. If `strict` is False, - correct output is not guaranteed when input is invalid. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to assert that `sigma > 0`. If `validate_args` is + False, correct output is not guaranteed when input is invalid. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: TypeError: if mu and sigma are different dtypes. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.op_scope([mu, sigma], name): mu = ops.convert_to_tensor(mu) sigma = ops.convert_to_tensor(sigma) - with ops.control_dependencies( - [check_ops.assert_positive(sigma)] if strict else []): + with ops.control_dependencies([check_ops.assert_positive(sigma)] if + validate_args else []): self._name = name self._mu = array_ops.identity(mu, name="mu") self._sigma = array_ops.identity(sigma, name="sigma") @@ -118,14 +122,14 @@ class Normal(distribution.Distribution): contrib_tensor_util.assert_same_float_dtype((mu, sigma)) @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py index ecea8184bc..e5fa624ddc 100644 --- a/tensorflow/contrib/distributions/python/ops/student_t.py +++ b/tensorflow/contrib/distributions/python/ops/student_t.py @@ -84,14 +84,13 @@ class StudentT(distribution.Distribution): ``` """ - def __init__( - self, - df, - mu, - sigma, - strict=True, - strict_statistics=True, - name="StudentT"): + def __init__(self, + df, + mu, + sigma, + validate_args=True, + allow_nan_stats=False, + name="StudentT"): """Construct Student's t distributions. The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`. @@ -106,23 +105,23 @@ class StudentT(distribution.Distribution): sigma: `float` or `double` tensor, the scaling factor for the distribution(s). `sigma` must contain only positive values. Note that `sigma` is not the standard deviation of this distribution. - strict: Whether to assert that `df > 0, sigma > 0`. If `strict` is False - and inputs are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to assert that `df > 0, sigma > 0`. If + `validate_args` is False and inputs are invalid, correct behavior is not + guaranteed. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to give Ops created by the initializer. Raises: TypeError: if mu and sigma are different dtypes. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.op_scope([df, mu, sigma], name) as scope: - with ops.control_dependencies( - [check_ops.assert_positive(df), check_ops.assert_positive(sigma)] - if strict else []): + with ops.control_dependencies([check_ops.assert_positive( + df), check_ops.assert_positive(sigma)] if validate_args else []): self._df = ops.convert_to_tensor(df, name="df") self._mu = ops.convert_to_tensor(mu, name="mu") self._sigma = ops.convert_to_tensor(sigma, name="sigma") @@ -133,14 +132,14 @@ class StudentT(distribution.Distribution): self._get_event_shape = tensor_shape.TensorShape([]) @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): @@ -169,7 +168,7 @@ class StudentT(distribution.Distribution): """Mean of the distribution. The mean of Student's T equals `mu` if `df > 1`, otherwise it is `NaN`. If - `self.strict_statistics=True`, then an exception will be raised rather than + `self.allow_nan_stats=False`, then an exception will be raised rather than returning `NaN`. Args: @@ -181,14 +180,14 @@ class StudentT(distribution.Distribution): with ops.name_scope(self.name): with ops.op_scope([self._mu], name): result_if_defined = self._mu * self._ones() - if self.strict_statistics: - one = ops.convert_to_tensor(1.0, dtype=self.dtype) - return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, self._df)], result_if_defined) - else: + if self.allow_nan_stats: df_gt_1 = self._df > self._ones() nan = np.nan + self._zeros() return math_ops.select(df_gt_1, result_if_defined, nan) + else: + one = ops.convert_to_tensor(1.0, dtype=self.dtype) + return control_flow_ops.with_dependencies( + [check_ops.assert_less(one, self._df)], result_if_defined) def mode(self, name="mode"): with ops.name_scope(self.name): @@ -207,7 +206,7 @@ class StudentT(distribution.Distribution): ``` The NaN state occurs because mean is undefined for `df <= 1`, and if - `self.strict_statistics` is `True`, an exception will be raised if any batch + `self.allow_nan_stats` is `False`, an exception will be raised if any batch members fall into this state. Args: @@ -227,15 +226,15 @@ class StudentT(distribution.Distribution): result_where_finite, self._zeros() + np.inf) - if self.strict_statistics: - one = ops.convert_to_tensor(1.0, self.dtype) - return control_flow_ops.with_dependencies( - [check_ops.assert_less(one, self._df)], result_where_defined) - else: + if self.allow_nan_stats: return math_ops.select( (self._zeros() + self._df > 1), result_where_defined, self._zeros() + np.nan) + else: + one = ops.convert_to_tensor(1.0, self.dtype) + return control_flow_ops.with_dependencies( + [check_ops.assert_less(one, self._df)], result_where_defined) def std(self, name="std"): with ops.name_scope(self.name): diff --git a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py index 4340b3e8a3..185741b217 100644 --- a/tensorflow/contrib/distributions/python/ops/transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/transformed_distribution.py @@ -243,12 +243,12 @@ class TransformedDistribution(distribution.Distribution): return self._base_dist.is_reparameterized @property - def strict_statistics(self): - return self._base_dist.strict_statistics + def allow_nan_stats(self): + return self._base_dist.allow_nan_stats @property - def strict(self): - return self._base_dist.strict + def validate_args(self): + return self._base_dist.validate_args @property def is_continuous(self): diff --git a/tensorflow/contrib/distributions/python/ops/uniform.py b/tensorflow/contrib/distributions/python/ops/uniform.py index e7e685cd9d..eb196a3ea9 100644 --- a/tensorflow/contrib/distributions/python/ops/uniform.py +++ b/tensorflow/contrib/distributions/python/ops/uniform.py @@ -37,8 +37,12 @@ class Uniform(distribution.Distribution): The PDF of this distribution is constant between [`a`, `b`], and 0 elsewhere. """ - def __init__( - self, a=0.0, b=1.0, strict=True, strict_statistics=True, name="Uniform"): + def __init__(self, + a=0.0, + b=1.0, + validate_args=True, + allow_nan_stats=False, + name="Uniform"): """Construct Uniform distributions with `a` and `b`. The parameters `a` and `b` must be shaped in a way that supports @@ -65,22 +69,22 @@ class Uniform(distribution.Distribution): Args: a: `float` or `double` tensor, the minimum endpoint. b: `float` or `double` tensor, the maximum endpoint. Must be > `a`. - strict: Whether to assert that `a > b`. If `strict` is False and inputs - are invalid, correct behavior is not guaranteed. - strict_statistics: Boolean, default True. If True, raise an exception if + validate_args: Whether to assert that `a > b`. If `validate_args` is False + and inputs are invalid, correct behavior is not guaranteed. + allow_nan_stats: Boolean, default False. If False, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. - If False, batch members with valid parameters leading to undefined + If True, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. name: The name to prefix Ops created by this distribution class. Raises: - InvalidArgumentError: if `a >= b` and `strict=True`. + InvalidArgumentError: if `a >= b` and `validate_args=True`. """ - self._strict_statistics = strict_statistics - self._strict = strict + self._allow_nan_stats = allow_nan_stats + self._validate_args = validate_args with ops.op_scope([a, b], name): - with ops.control_dependencies( - [check_ops.assert_less(a, b)] if strict else []): + with ops.control_dependencies([check_ops.assert_less(a, b)] if + validate_args else []): a = array_ops.identity(a, name="a") b = array_ops.identity(b, name="b") @@ -93,14 +97,14 @@ class Uniform(distribution.Distribution): contrib_tensor_util.assert_same_float_dtype((a, b)) @property - def strict_statistics(self): + def allow_nan_stats(self): """Boolean describing behavior when a stat is undefined for batch member.""" - return self._strict_statistics + return self._allow_nan_stats @property - def strict(self): + def validate_args(self): """Boolean describing behavior on invalid input.""" - return self._strict + return self._validate_args @property def name(self): |