aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 15:21:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 15:29:25 -0700
commitb2b98a5ad1b647b77cb42761671cd9b3cf0e87b6 (patch)
treeeabd186d74aace535112edde70b7948d5eaa78cc /tensorflow/python/estimator
parent237c6ccae40005e3b6199731c45e1c9f5cd86c5f (diff)
Boosted trees: Add error messages when tree complexity parameter is not properly set.
PiperOrigin-RevId: 213706101
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py10
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py35
2 files changed, 42 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 36048a2bfd..756d32d03f 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -422,9 +422,13 @@ class _EnsembleGrower(object):
self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
tree_hparams.pruning_mode)
- if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
- and tree_hparams.tree_complexity <= 0):
- raise ValueError('For pruning, tree_complexity must be positive.')
+ if tree_hparams.tree_complexity > 0:
+ if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError(
+ 'Tree complexity have no effect unless pruning mode is chosen.')
+ else:
+ if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING:
+ raise ValueError('For pruning, tree_complexity must be positive.')
# pylint: enable=protected-access
@abc.abstractmethod
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 9409cb5cc7..d4cb3e27d0 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -564,6 +564,41 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testTreeComplexityIsSetCorrectly(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ num_steps = 10
+ # Tree complexity is set but no pruning.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ tree_complexity=1e-3)
+ with self.assertRaisesRegexp(ValueError, 'Tree complexity have no effect'):
+ est.train(input_fn, steps=num_steps)
+
+ # Pruning but no tree complexity.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre')
+ with self.assertRaisesRegexp(ValueError,
+ 'tree_complexity must be positive'):
+ est.train(input_fn, steps=num_steps)
+
+ # All is good.
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=1e-3)
+ est.train(input_fn, steps=num_steps)
+
class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
"""Test debug/model explainability outputs for individual predictions.