author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-07-26 11:38:07 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-26 11:49:38 -0700
commit    91ff408ecd52dd167d966c9df222e840f1d43f8f (patch)
tree      642ab26958e3748cecc5f22c9b0e9a850fc31868 /tensorflow/contrib/estimator
parent    c83525a1887ac3d7c03d4d25351e421cd90069a4 (diff)
Boosted trees: Expose pruning mode as a parameter for the GBDT estimator
PiperOrigin-RevId: 206193733
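
For context, a minimal usage sketch of the new parameter follows. This is an illustration, not part of the change: feature_columns, head, and input_fn are hypothetical placeholders, and any pruning mode other than 'none' requires tree_complexity > 0.

from tensorflow.contrib.estimator.python.estimator import boosted_trees

# Hypothetical inputs for illustration only.
est = boosted_trees._BoostedTreesEstimator(
    feature_columns=feature_columns,
    n_batches_per_layer=1,
    head=head,
    n_trees=2,
    max_depth=5,
    tree_complexity=0.001,  # assumed value; must be > 0 when pruning is enabled
    pruning_mode='pre')  # new parameter: 'none' (default), 'pre', or 'post'
est.train(input_fn, steps=100)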
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/boosted_trees.py      | 30
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py | 68
2 files changed, 92 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 43bfcffd79..7ed77bcce6 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -50,7 +50,8 @@ class _BoostedTreesEstimator(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesEstimator` instance.
Args:
@@ -89,13 +90,18 @@ class _BoostedTreesEstimator(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', or 'post' to indicate no pruning,
+ pre-pruning (do not split a node if not enough gain is observed), or
+ post-pruning (build the tree up to max depth, then prune branches with
+ negative gain). For pre- and post-pruning, you MUST provide a
+ tree_complexity > 0.
"""
# pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -129,7 +135,8 @@ def boosted_trees_classifier_train_in_memory(
min_node_weight=0.,
config=None,
train_hooks=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Trains a boosted tree classifier with in memory dataset.
Example:
@@ -208,6 +215,11 @@ def boosted_trees_classifier_train_in_memory(
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', or 'post' to indicate no pruning,
+ pre-pruning (do not split a node if not enough gain is observed), or
+ post-pruning (build the tree up to max depth, then prune branches with
+ negative gain). For pre- and post-pruning, you MUST provide a
+ tree_complexity > 0.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -228,7 +240,7 @@ def boosted_trees_classifier_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -269,7 +281,8 @@ def boosted_trees_regressor_train_in_memory(
min_node_weight=0.,
config=None,
train_hooks=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Trains a boosted tree regressor with in memory dataset.
Example:
@@ -341,6 +354,11 @@ def boosted_trees_regressor_train_in_memory(
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', or 'post' to indicate no pruning,
+ pre-pruning (do not split a node if not enough gain is observed), or
+ post-pruning (build the tree up to max depth, then prune branches with
+ negative gain). For pre- and post-pruning, you MUST provide a
+ tree_complexity > 0.
Returns:
a `BoostedTreesRegressor` instance created with the given arguments and
@@ -360,7 +378,7 @@ def boosted_trees_regressor_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
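
The in-memory training helpers accept the same parameter. A sketch under the same assumptions (train_input_fn and feature_columns are hypothetical placeholders), mirroring the tests below:

# Hypothetical in-memory classifier with post-pruning enabled.
est = boosted_trees.boosted_trees_classifier_train_in_memory(
    train_input_fn=train_input_fn,
    feature_columns=feature_columns,
    n_trees=1,
    max_depth=5,
    tree_complexity=0.01,  # must be > 0 for 'pre' or 'post' pruning
    pruning_mode='post')
eval_res = est.evaluate(input_fn=train_input_fn, steps=1)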
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index 999c2aa5e2..b1581f3750 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -136,6 +136,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['average_loss'], 0.614642)
+ def testTrainAndEvaluateEstimatorWithPrePruning(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ tree_complexity=0.001,
+ pruning_mode='pre')
+
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ # We actually stop after 2*depth*n_trees steps (via a hook), since with
+ # pre-pruning we could not grow 2 full trees of depth 5.
+ self._assert_checkpoint(
+ est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 3.83943)
+
+ def testTrainAndEvaluateEstimatorWithPostPruning(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ tree_complexity=0.001,
+ pruning_mode='post')
+
+ # It will stop after 10 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ self._assert_checkpoint(
+ est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.37652)
+
def testInferEstimator(self):
train_input_fn = _make_train_input_fn(is_classification=False)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -231,6 +274,31 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertAllClose([[0], [1], [1], [0], [0]],
[pred['class_ids'] for pred in predictions])
+ def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_classifier_train_in_memory(
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=0.01)
+ # We actually stop after 2*depth*n_trees steps (via a hook), since with
+ # pre-pruning we could not grow 1 full tree of depth 5.
+ self._assert_checkpoint(
+ est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11)
+
+ # Check evaluate and predict.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+ # Validate predictions.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+
def testBinaryClassifierTrainInMemoryWithDataset(self):
train_input_fn = _make_train_input_fn_dataset(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(