aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-26 11:38:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 11:49:38 -0700
commit91ff408ecd52dd167d966c9df222e840f1d43f8f (patch)
tree642ab26958e3748cecc5f22c9b0e9a850fc31868
parentc83525a1887ac3d7c03d4d25351e421cd90069a4 (diff)
Boosted trees: Revealing pruning mode as one of the parameters for a gbdt estimator
PiperOrigin-RevId: 206193733
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py30
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py68
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py45
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py3
-rw-r--r--tensorflow/python/ops/boosted_trees_ops.py11
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt2
7 files changed, 141 insertions, 20 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 43bfcffd79..7ed77bcce6 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -50,7 +50,8 @@ class _BoostedTreesEstimator(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesEstimator` instance.
Args:
@@ -89,13 +90,18 @@ class _BoostedTreesEstimator(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
"""
# pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -129,7 +135,8 @@ def boosted_trees_classifier_train_in_memory(
min_node_weight=0.,
config=None,
train_hooks=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Trains a boosted tree classifier with in memory dataset.
Example:
@@ -208,6 +215,11 @@ def boosted_trees_classifier_train_in_memory(
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -228,7 +240,7 @@ def boosted_trees_classifier_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -269,7 +281,8 @@ def boosted_trees_regressor_train_in_memory(
min_node_weight=0.,
config=None,
train_hooks=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Trains a boosted tree regressor with in memory dataset.
Example:
@@ -341,6 +354,11 @@ def boosted_trees_regressor_train_in_memory(
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Returns:
a `BoostedTreesClassifier` instance created with the given arguments and
@@ -360,7 +378,7 @@ def boosted_trees_regressor_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index 999c2aa5e2..b1581f3750 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -136,6 +136,49 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['average_loss'], 0.614642)
+ def testTrainAndEvaluateEstimatorWithPrePruning(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ tree_complexity=0.001,
+ pruning_mode='pre')
+
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ # We stop actually after 2*depth*n_trees steps (via a hook) because we still
+ # could not grow 2 trees of depth 5 (due to pre-pruning).
+ self._assert_checkpoint(
+ est.model_dir, global_step=21, finalized_trees=0, attempted_layers=21)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 3.83943)
+
+ def testTrainAndEvaluateEstimatorWithPostPruning(self):
+ input_fn = _make_train_input_fn(is_classification=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ head=self._head,
+ max_depth=5,
+ tree_complexity=0.001,
+ pruning_mode='post')
+
+ # It will stop after 10 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+ self._assert_checkpoint(
+ est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10)
+ eval_res = est.evaluate(input_fn=input_fn, steps=1)
+ self.assertAllClose(eval_res['average_loss'], 2.37652)
+
def testInferEstimator(self):
train_input_fn = _make_train_input_fn(is_classification=False)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -231,6 +274,31 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertAllClose([[0], [1], [1], [0], [0]],
[pred['class_ids'] for pred in predictions])
+ def testBinaryClassifierTrainInMemoryAndEvalAndInferWithPrePruning(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.boosted_trees_classifier_train_in_memory(
+ train_input_fn=train_input_fn,
+ feature_columns=self._feature_columns,
+ n_trees=1,
+ max_depth=5,
+ pruning_mode='pre',
+ tree_complexity=0.01)
+ # We stop actually after 2*depth*n_trees steps (via a hook) because we still
+ # could not grow 1 trees of depth 5 (due to pre-pruning).
+ self._assert_checkpoint(
+ est.model_dir, global_step=11, finalized_trees=0, attempted_layers=11)
+
+ # Check evaluate and predict.
+ eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
+ self.assertAllClose(eval_res['accuracy'], 1.0)
+ # Validate predictions.
+ predictions = list(est.predict(input_fn=predict_input_fn))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
+
def testBinaryClassifierTrainInMemoryWithDataset(self):
train_input_fn = _make_train_input_fn_dataset(is_classification=True)
predict_input_fn = numpy_io.numpy_input_fn(
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 3292e2724d..8b423f76de 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -46,7 +46,7 @@ from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
- 'min_node_weight', 'center_bias'
+ 'min_node_weight', 'center_bias', 'pruning_mode'
])
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
@@ -410,9 +410,20 @@ class _EnsembleGrower(object):
Args:
tree_ensemble: A TreeEnsemble variable.
tree_hparams: TODO. collections.namedtuple for hyper parameters.
+ Raises:
+ ValueError: when pruning mode is invalid or pruning is used and no tree
+ complexity is set.
"""
self._tree_ensemble = tree_ensemble
self._tree_hparams = tree_hparams
+ # pylint: disable=protected-access
+ self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str(
+ tree_hparams.pruning_mode)
+
+ if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING
+ and tree_hparams.tree_complexity <= 0):
+ raise ValueError('For pruning, tree_complexity must be positive.')
+ # pylint: enable=protected-access
@abc.abstractmethod
def center_bias(self, center_bias_var, gradients, hessians):
@@ -500,7 +511,7 @@ class _EnsembleGrower(object):
right_node_contribs=right_node_contribs_list,
learning_rate=self._tree_hparams.learning_rate,
max_depth=self._tree_hparams.max_depth,
- pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
+ pruning_mode=self._pruning_mode_parsed)
return grow_op
@@ -675,6 +686,7 @@ def _bt_model_fn(
is_single_machine = (config.num_worker_replicas <= 1)
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
center_bias = tree_hparams.center_bias
+
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
@@ -925,7 +937,8 @@ class BoostedTreesClassifier(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesClassifier` instance.
Example:
@@ -999,7 +1012,11 @@ class BoostedTreesClassifier(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
-
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -1012,9 +1029,9 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes, weight_column, label_vocabulary=label_vocabulary)
# HParams for the model.
- tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
- l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_hparams = _TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
@@ -1058,7 +1075,8 @@ class BoostedTreesRegressor(estimator.Estimator):
tree_complexity=0.,
min_node_weight=0.,
config=None,
- center_bias=False):
+ center_bias=False,
+ pruning_mode='none'):
"""Initializes a `BoostedTreesRegressor` instance.
Example:
@@ -1125,6 +1143,11 @@ class BoostedTreesRegressor(estimator.Estimator):
regression problems, the first node will return the mean of the labels.
For binary classification problems, it will return a logit for a prior
probability of label 1.
+ pruning_mode: one of 'none', 'pre', 'post' to indicate no pruning, pre-
+ pruning (do not split a node if not enough gain is observed) and post
+ pruning (build the tree up to a max depth and then prune branches with
+ negative gain). For pre and post pruning, you MUST provide
+ tree_complexity >0.
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -1136,9 +1159,9 @@ class BoostedTreesRegressor(estimator.Estimator):
head = _create_regression_head(label_dimension, weight_column)
# HParams for the model.
- tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
- l1_regularization, l2_regularization,
- tree_complexity, min_node_weight, center_bias)
+ tree_hparams = _TreeHParams(
+ n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index f807641057..ec597e4686 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -1508,7 +1508,8 @@ class ModelFnTests(test_util.TensorFlowTestCase):
l2=0.01,
tree_complexity=0.,
min_node_weight=0.,
- center_bias=center_bias)
+ center_bias=center_bias,
+ pruning_mode='none')
estimator_spec = boosted_trees._bt_model_fn( # pylint:disable=protected-access
features=features,
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index 868a4f6b84..f7cbfe0312 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -37,8 +37,19 @@ from tensorflow.python.training import saver
class PruningMode(object):
+ """Class for working with Pruning modes."""
NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)
+ _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}
+
+ @classmethod
+ def from_str(cls, mode):
+ if mode in cls._map:
+ return cls._map[mode]
+ else:
+ raise ValueError('pruning_mode mode must be one of: {}'.format(', '.join(
+ sorted(cls._map))))
+
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
"""SaveableObject implementation for TreeEnsemble."""
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index 9dbb5d16a4..c23b04b4ef 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
}
member_method {
name: "eval_dir"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 34a30c2874..6878d28fff 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
}
member_method {
name: "eval_dir"