diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-19 15:21:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-19 15:29:25 -0700 |
commit | b2b98a5ad1b647b77cb42761671cd9b3cf0e87b6 (patch) | |
tree | eabd186d74aace535112edde70b7948d5eaa78cc /tensorflow/python/estimator | |
parent | 237c6ccae40005e3b6199731c45e1c9f5cd86c5f (diff) |
Boosted trees: Add error messages when tree complexity parameter is not properly set.
PiperOrigin-RevId: 213706101
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 10 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 35 |
2 files changed, 42 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 36048a2bfd..756d32d03f 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -422,9 +422,13 @@ class _EnsembleGrower(object): self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str( tree_hparams.pruning_mode) - if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING - and tree_hparams.tree_complexity <= 0): - raise ValueError('For pruning, tree_complexity must be positive.') + if tree_hparams.tree_complexity > 0: + if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING: + raise ValueError( + 'Tree complexity have no effect unless pruning mode is chosen.') + else: + if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING: + raise ValueError('For pruning, tree_complexity must be positive.') # pylint: enable=protected-access @abc.abstractmethod diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 9409cb5cc7..d4cb3e27d0 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -564,6 +564,41 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id) self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold) + def testTreeComplexityIsSetCorrectly(self): + input_fn = _make_train_input_fn(is_classification=True) + + num_steps = 10 + # Tree complexity is set but no pruning. + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + tree_complexity=1e-3) + with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'): + est.train(input_fn, steps=num_steps) + + # Pruning but no tree complexity. + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + pruning_mode='pre') + with self.assertRaisesRegexp(ValueError, + 'tree_complexity must be positive'): + est.train(input_fn, steps=num_steps) + + # All is good. + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=1, + max_depth=5, + pruning_mode='pre', + tree_complexity=1e-3) + est.train(input_fn, steps=num_steps) + class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase): """Test debug/model explainability outputs for individual predictions. |