aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-26 13:20:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 13:24:52 -0800
commitd1910fa9eb274717719c4dcff3247498ea30caa4 (patch)
tree2f5e8aece7643911621601d68a9718f061626c2d /tensorflow/contrib/boosted_trees/python
parentf6a53e7abd54afdff4d1377535d61dbc1efd174c (diff)
Add more tests to validate the bucket boundaries for
inputs with equal distributions. PiperOrigin-RevId: 183435084
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python')
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
index eefa7ef0dc..81f58de28c 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py
@@ -183,11 +183,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
self.assertEqual(num_quantiles + 1, len(buckets))
self.assertAllEqual([2030, 2040, 2050, 2060], buckets)
- def _testStreamingQuantileBucketsHelper(self, inputs):
+ def _testStreamingQuantileBucketsHelper(
+ self, inputs, num_quantiles=3, expected_buckets=None):
"""Helper to test quantile buckets on different inputs."""
- # Use 3 quantiles, 4 boundaries for simplicity.
- num_quantiles = 3
# set generate_quantiles to True since the test will generate fewer
# boundaries otherwise.
with self.test_session() as sess:
@@ -213,7 +212,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
buckets, are_ready_flush = (sess.run(
[buckets, are_ready_flush]))
self.assertEqual(True, are_ready_flush)
+ # By default, use 3 quantiles, 4 boundaries for simplicity.
self.assertEqual(num_quantiles + 1, len(buckets))
+ if expected_buckets:
+ self.assertAllEqual(buckets, expected_buckets)
def testStreamingQuantileBucketsRepeatedSingleValue(self):
inputs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
@@ -231,6 +233,28 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase):
inputs = [5]
self._testStreamingQuantileBucketsHelper(inputs)
+ def testStreamingQuantileBucketsEqualDistributionInSequence(self):
+ # Input pattern is of the form [1, 1, 1, 2, 2, 2, 3, 3, 3, ...]
+ ones = 100 * [1]
+ inputs = []
+ for i in range(1, 101):
+ inputs += [i * k for k in ones]
+ # Expect 100 equally spaced buckets.
+ expected_buckets = range(1, 101)
+ self._testStreamingQuantileBucketsHelper(
+ inputs, num_quantiles=99, expected_buckets=expected_buckets)
+
+ def testStreamingQuantileBucketsEqualDistributionInterleaved(self):
+ # Input pattern is of the form [1, 2, 3, 1, 2, 3, 1, 2, 3, ...]
+ sequence = range(1, 101)
+ inputs = []
+ for _ in range(1, 101):
+ inputs += sequence
+ # Expect 100 equally spaced buckets.
+ expected_buckets = range(1, 101)
+ self._testStreamingQuantileBucketsHelper(
+ inputs, num_quantiles=99, expected_buckets=expected_buckets)
+
def testStreamingQuantileBuckets(self):
"""Sets up the quantile summary op test as follows.