diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-20 06:44:27 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-20 06:44:27 +0800 |
commit | fb2918f81053e15801e08d1a90cf7960b6d219e9 (patch) | |
tree | dfdc13d42a9bd8ca58c3ae395a79eacbc184e730 /tensorflow/python/estimator | |
parent | c7fcdf847750b364629299579c19be39576c6b04 (diff) |
TST: introduce test case from upstream/master
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 1497d4253b..23687a738b 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -1099,6 +1099,41 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): learning_rate=1.0, max_depth=1) + def testTreeComplexityIsSetCorrectly(self): + input_fn = _make_train_input_fn(is_classification=True) + + num_steps = 10 + # Tree complexity is set but no pruning. + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + tree_complexity=1e-3) + with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'): + est.train(input_fn, steps=num_steps) + + # Pruning but no tree complexity. + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + pruning_mode='pre') + with self.assertRaisesRegexp(ValueError, + 'tree_complexity must be positive'): + est.train(input_fn, steps=num_steps) + + # All is good. + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + pruning_mode='pre', + tree_complexity=1e-3) + est.train(input_fn, steps=num_steps) + class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase): """Test debug/model explainability outputs for individual predictions. |