aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-20 06:44:27 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-20 06:44:27 +0800
commitfb2918f81053e15801e08d1a90cf7960b6d219e9 (patch)
treedfdc13d42a9bd8ca58c3ae395a79eacbc184e730 /tensorflow/python/estimator
parentc7fcdf847750b364629299579c19be39576c6b04 (diff)
TST: introduce test case from upstream/master
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 1497d4253b..23687a738b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -1099,6 +1099,41 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
learning_rate=1.0,
max_depth=1)
+ def testTreeComplexityIsSetCorrectly(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ num_steps = 10
+ # Tree complexity is set but no pruning.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ tree_complexity=1e-3)
+ with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'):
+ est.train(input_fn, steps=num_steps)
+
+ # Pruning but no tree complexity.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre')
+ with self.assertRaisesRegexp(ValueError,
+ 'tree_complexity must be positive'):
+ est.train(input_fn, steps=num_steps)
+
+ # All is good.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=1e-3)
+ est.train(input_fn, steps=num_steps)
+
class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
"""Test debug/model explainability outputs for individual predictions.