author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-29 08:36:14 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-29 08:38:52 -0700
commit    cff06379c2e1ac01de3b3c0ca32c3a3037d5b833 (patch)
tree      844dffffb19c7d35ba4a8d1f668fc13366676577 /tensorflow/contrib/distributions
parent    8ecf1ebc5d83e66b29a07113b53c49ef8264703c (diff)
Generalize assert_true_mean_equal and assert_true_mean_equal_two_sample to assert_true_mean_in_interval.
PiperOrigin-RevId: 198400265
Diffstat (limited to 'tensorflow/contrib/distributions')
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py   33
-rw-r--r--  tensorflow/contrib/distributions/python/ops/statistical_testing.py                 131
2 files changed, 122 insertions, 42 deletions
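
For context, a minimal usage sketch of the new interval assertion, mirroring the one-sample test added below. It is illustrative only: it assumes a TF 1.x environment where tensorflow.contrib.distributions is still available, and the sample size and interval endpoints are taken from the new test case rather than being recommendations.

import numpy as np
import tensorflow as tf
from tensorflow.contrib.distributions.python.ops import statistical_testing as st

# IID samples from the standard uniform distribution, whose true mean is 0.5.
samples = np.random.RandomState(seed=0).uniform(size=5000).astype(np.float32)

# Passes: the true mean 0.5 lies inside the expected interval [0.4, 0.6].
check_interval = st.assert_true_mean_in_interval_by_dkwm(
    samples, low=0., high=1.,
    expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6)

# After this change, the equality assertion is the same check applied to the
# degenerate interval [expected, expected].
check_equal = st.assert_true_mean_equal_by_dkwm(
    samples, low=0., high=1., expected=0.5, false_fail_rate=1e-6)

with tf.Session() as sess:
  sess.run([check_interval, check_equal])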
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
index ce6cf702d5..4a5a6b5ae1 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/statistical_testing_test.py
@@ -129,16 +129,41 @@ class StatisticalTestingTest(test.TestCase):
# Test that the test assertion confirms that the mean of the
# standard uniform distribution is not 0.4.
- with self.assertRaisesOpError("Mean confidence interval too high"):
+ with self.assertRaisesOpError("true mean greater than expected"):
sess.run(st.assert_true_mean_equal_by_dkwm(
samples, 0., 1., 0.4, false_fail_rate=1e-6))
# Test that the test assertion confirms that the mean of the
# standard uniform distribution is not 0.6.
- with self.assertRaisesOpError("Mean confidence interval too low"):
+ with self.assertRaisesOpError("true mean smaller than expected"):
sess.run(st.assert_true_mean_equal_by_dkwm(
samples, 0., 1., 0.6, false_fail_rate=1e-6))
+ def test_dkwm_mean_in_interval_one_sample_assertion(self):
+ rng = np.random.RandomState(seed=0)
+ num_samples = 5000
+
+ # Test that the test assertion agrees that the mean of the standard
+ # uniform distribution is between 0.4 and 0.6.
+ samples = rng.uniform(size=num_samples).astype(np.float32)
+ self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+ samples, 0., 1.,
+ expected_low=0.4, expected_high=0.6, false_fail_rate=1e-6))
+
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is not between 0.2 and 0.4.
+ with self.assertRaisesOpError("true mean greater than expected"):
+ self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+ samples, 0., 1.,
+ expected_low=0.2, expected_high=0.4, false_fail_rate=1e-6))
+
+ # Test that the test assertion confirms that the mean of the
+ # standard uniform distribution is not between 0.6 and 0.8.
+ with self.assertRaisesOpError("true mean smaller than expected"):
+ self.evaluate(st.assert_true_mean_in_interval_by_dkwm(
+ samples, 0., 1.,
+ expected_low=0.6, expected_high=0.8, false_fail_rate=1e-6))
+
def test_dkwm_mean_two_sample_assertion(self):
rng = np.random.RandomState(seed=0)
num_samples = 4000
@@ -172,7 +197,7 @@ class StatisticalTestingTest(test.TestCase):
# Test that the test assertion confirms that the mean of the
# standard uniform distribution is different from the mean of beta(2, 1).
beta_high_samples = rng.beta(2, 1, size=num_samples).astype(np.float32)
- with self.assertRaisesOpError("samples1 has a smaller mean"):
+ with self.assertRaisesOpError("true mean smaller than expected"):
sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
samples1, 0., 1.,
beta_high_samples, 0., 1.,
@@ -190,7 +215,7 @@ class StatisticalTestingTest(test.TestCase):
# Test that the test assertion confirms that the mean of the
# standard uniform distribution is different from the mean of beta(1, 2).
beta_low_samples = rng.beta(1, 2, size=num_samples).astype(np.float32)
- with self.assertRaisesOpError("samples2 has a smaller mean"):
+ with self.assertRaisesOpError("true mean greater than expected"):
sess.run(st.assert_true_mean_equal_by_dkwm_two_sample(
samples1, 0., 1.,
beta_low_samples, 0., 1.,
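
Before the implementation diff below, a plain-Python illustration of the interval-intersection logic the new assertion relies on. The helper name and numeric values here are hypothetical, not part of the change; in the real code the bounds min_mean and max_mean come from true_mean_confidence_interval_by_dkwm.

def check_true_mean_in_interval(min_mean, max_mean, expected_low, expected_high):
  # [min_mean, max_mean] intersects [expected_low, expected_high] iff
  # max_mean >= expected_low and min_mean <= expected_high; by De Morgan this
  # is the negation of (max_mean < expected_low or min_mean > expected_high).
  if max_mean < expected_low:
    raise ValueError("Confidence interval does not intersect: "
                     "true mean smaller than expected")
  if min_mean > expected_high:
    raise ValueError("Confidence interval does not intersect: "
                     "true mean greater than expected")

# A confidence interval around 0.5 intersects [0.4, 0.6] ...
check_true_mean_in_interval(0.46, 0.54, 0.4, 0.6)    # passes silently
# ... but lies entirely above [0.2, 0.4], so the check reports
# "true mean greater than expected", matching the test expectations above.
# check_true_mean_in_interval(0.46, 0.54, 0.2, 0.4)  # would raise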
diff --git a/tensorflow/contrib/distributions/python/ops/statistical_testing.py b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
index 9c69435fac..3ea9a331c7 100644
--- a/tensorflow/contrib/distributions/python/ops/statistical_testing.py
+++ b/tensorflow/contrib/distributions/python/ops/statistical_testing.py
@@ -140,6 +140,7 @@ __all__ = [
"assert_true_mean_equal_by_dkwm",
"min_discrepancy_of_true_means_detectable_by_dkwm",
"min_num_samples_for_dkwm_mean_test",
+ "assert_true_mean_in_interval_by_dkwm",
"assert_true_mean_equal_by_dkwm_two_sample",
"min_discrepancy_of_true_means_detectable_by_dkwm_two_sample",
"min_num_samples_for_dkwm_mean_two_sample_test",
@@ -454,20 +455,8 @@ def assert_true_mean_equal_by_dkwm(
with ops.name_scope(
name, "assert_true_mean_equal_by_dkwm",
[samples, low, high, expected, false_fail_rate]):
- samples = ops.convert_to_tensor(samples, name="samples")
- low = ops.convert_to_tensor(low, name="low")
- high = ops.convert_to_tensor(high, name="high")
- expected = ops.convert_to_tensor(expected, name="expected")
- false_fail_rate = ops.convert_to_tensor(
- false_fail_rate, name="false_fail_rate")
- samples = _check_shape_dominates(samples, [low, high, expected])
- min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
- samples, low, high, error_rate=false_fail_rate)
- less_op = check_ops.assert_less(
- min_mean, expected, message="Mean confidence interval too high")
- with ops.control_dependencies([less_op]):
- return check_ops.assert_greater(
- max_mean, expected, message="Mean confidence interval too low")
+ return assert_true_mean_in_interval_by_dkwm(
+ samples, low, high, expected, expected, false_fail_rate)
def min_discrepancy_of_true_means_detectable_by_dkwm(
@@ -505,12 +494,15 @@ def min_discrepancy_of_true_means_detectable_by_dkwm(
some scalar distribution supported on `[low[i], high[i]]` is enough
to detect a difference in means of size `discr[i]` or more.
Specifically, we guarantee that (a) if the true mean is the expected
- mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
- most `false_fail_rate / K` (which amounts to `false_fail_rate` if
- applied to the whole batch at once), and (b) if the true mean
- differs from the expected mean by at least `discr[i]`,
- `assert_true_mean_equal_by_dkwm` will pass with probability at most
- `false_pass_rate`.
+ mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+ probability at most `false_fail_rate / K` (which amounts to
+ `false_fail_rate` if applied to the whole batch at once), and (b) if
+ the true mean differs from the expected mean (resp. falls outside
+ the expected interval) by at least `discr[i]`,
+ `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+ probability at most `false_pass_rate`.
The detectable discrepancy scales as
@@ -578,12 +570,15 @@ def min_num_samples_for_dkwm_mean_test(
some scalar distribution supported on `[low[i], high[i]]` is enough
to detect a difference in means of size `discrepancy[i]` or more.
Specifically, we guarantee that (a) if the true mean is the expected
- mean, `assert_true_mean_equal_by_dkwm` will fail with probability at
- most `false_fail_rate / K` (which amounts to `false_fail_rate` if
- applied to the whole batch at once), and (b) if the true mean
- differs from the expected mean by at least `discrepancy[i]`,
- `assert_true_mean_equal_by_dkwm` will pass with probability at most
- `false_pass_rate`.
+ mean (resp. in the expected interval), then `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will fail with
+ probability at most `false_fail_rate / K` (which amounts to
+ `false_fail_rate` if applied to the whole batch at once), and (b) if
+ the true mean differs from the expected mean (resp. falls outside
+ the expected interval) by at least `discrepancy[i]`,
+ `assert_true_mean_equal_by_dkwm`
+ (resp. `assert_true_mean_in_interval_by_dkwm`) will pass with
+ probability at most `false_pass_rate`.
The required number of samples scales
as `O((high[i] - low[i])**2)`, `O(-log(false_fail_rate/K))`,
@@ -610,6 +605,76 @@ def min_num_samples_for_dkwm_mean_test(
return math_ops.maximum(n1, n2)
+def assert_true_mean_in_interval_by_dkwm(
+ samples, low, high, expected_low, expected_high,
+ false_fail_rate=1e-6, name=None):
+ """Asserts the mean of the given distribution is in the given interval.
+
+ More precisely, fails if there is enough evidence (using the
+ [Dvoretzky-Kiefer-Wolfowitz-Massart inequality]
+ (https://en.wikipedia.org/wiki/CDF-based_nonparametric_confidence_interval))
+ that the mean of the distribution from which the given samples are
+ drawn is _outside_ the given interval with statistical significance
+ `false_fail_rate` or stronger, otherwise passes. If you also want
+ to check that you are gathering enough evidence that a pass is not
+ spurious, see `min_num_samples_for_dkwm_mean_test` and
+ `min_discrepancy_of_true_means_detectable_by_dkwm`.
+
+ Note that `false_fail_rate` is a total false failure rate for all
+ the assertions in the batch. As such, if the batch is nontrivial,
+ the assertion will insist on stronger evidence to fail any one member.
+
+ Args:
+ samples: Floating-point `Tensor` of samples from the distribution(s)
+ of interest. Entries are assumed IID across the 0th dimension.
+ The other dimensions must broadcast with `low` and `high`.
+ The support is bounded: `low <= samples <= high`.
+ low: Floating-point `Tensor` of lower bounds on the distributions'
+ supports.
+ high: Floating-point `Tensor` of upper bounds on the distributions'
+ supports.
+ expected_low: Floating-point `Tensor` of lower bounds on the
+ expected true means.
+ expected_high: Floating-point `Tensor` of upper bounds on the
+ expected true means.
+ false_fail_rate: *Scalar* floating-point `Tensor` admissible total
+ rate of mistakes.
+ name: A name for this operation (optional).
+
+ Returns:
+ check: Op that raises `InvalidArgumentError` if any expected mean
+ interval does not overlap with the corresponding confidence
+ interval.
+ """
+ with ops.name_scope(
+ name, "assert_true_mean_in_interval_by_dkwm",
+ [samples, low, high, expected_low, expected_high, false_fail_rate]):
+ samples = ops.convert_to_tensor(samples, name="samples")
+ low = ops.convert_to_tensor(low, name="low")
+ high = ops.convert_to_tensor(high, name="high")
+ expected_low = ops.convert_to_tensor(expected_low, name="expected_low")
+ expected_high = ops.convert_to_tensor(expected_high, name="expected_high")
+ false_fail_rate = ops.convert_to_tensor(
+ false_fail_rate, name="false_fail_rate")
+ samples = _check_shape_dominates(
+ samples, [low, high, expected_low, expected_high])
+ min_mean, max_mean = true_mean_confidence_interval_by_dkwm(
+ samples, low, high, false_fail_rate)
+ # Assert that the interval [min_mean, max_mean] intersects the
+ # interval [expected_low, expected_high]. This is true if
+ # max_mean >= expected_low and min_mean <= expected_high.
+ # By DeMorgan's law, that's also equivalent to
+ # not (max_mean < expected_low or min_mean > expected_high),
+ # which is a way of saying the two intervals are not disjoint.
+ check_confidence_interval_can_intersect = check_ops.assert_greater_equal(
+ max_mean, expected_low, message="Confidence interval does not "
+ "intersect: true mean smaller than expected")
+ with ops.control_dependencies([check_confidence_interval_can_intersect]):
+ return check_ops.assert_less_equal(
+ min_mean, expected_high, message="Confidence interval does not "
+ "intersect: true mean greater than expected")
+
+
def assert_true_mean_equal_by_dkwm_two_sample(
samples1, low1, high1, samples2, low2, high2,
false_fail_rate=1e-6, name=None):
@@ -676,20 +741,10 @@ def assert_true_mean_equal_by_dkwm_two_sample(
# and sample counts should be valid; however, because the intervals
# scale as O(-log(false_fail_rate)), there doesn't seem to be much
# room to win.
- min_mean_1, max_mean_1 = true_mean_confidence_interval_by_dkwm(
- samples1, low1, high1, false_fail_rate / 2.)
min_mean_2, max_mean_2 = true_mean_confidence_interval_by_dkwm(
samples2, low2, high2, false_fail_rate / 2.)
- # I want to assert
- # not (max_mean_1 < min_mean_2 or min_mean_1 > max_mean_2),
- # but I think I only have and-combination of asserts, so use DeMorgan.
- check_confidence_intervals_can_intersect = check_ops.assert_greater_equal(
- max_mean_1, min_mean_2, message="Confidence intervals do not "
- "intersect: samples1 has a smaller mean than samples2")
- with ops.control_dependencies([check_confidence_intervals_can_intersect]):
- return check_ops.assert_less_equal(
- min_mean_1, max_mean_2, message="Confidence intervals do not "
- "intersect: samples2 has a smaller mean than samples1")
+ return assert_true_mean_in_interval_by_dkwm(
+ samples1, low1, high1, min_mean_2, max_mean_2, false_fail_rate / 2.)
def min_discrepancy_of_true_means_detectable_by_dkwm_two_sample(