aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/contrib/estimator/python/estimator/boosted_trees.py 18
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt 8
-rw-r--r-- tensorflow/core/kernels/boosted_trees/stats_ops.cc 9
-rw-r--r-- tensorflow/core/ops/boosted_trees_ops.cc 1
-rw-r--r-- tensorflow/core/ops/compat/ops_history.v1.pbtxt 4
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees.py 85
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees_test.py 3
-rw-r--r-- tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py 51
-rw-r--r-- tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt 2
-rw-r--r-- tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt 2
10 files changed, 138 insertions, 45 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 314c54ed00..00356ce0ca 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -36,6 +36,7 @@ class _BoostedTreesEstimator(estimator.Estimator):
l1_regularization=0.,
l2_regularization=0.,
tree_complexity=0.,
+ min_node_weight=0.,
config=None):
"""Initializes a `BoostedTreesEstimator` instance.
@@ -65,13 +66,16 @@ class _BoostedTreesEstimator(estimator.Estimator):
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+ min_node_weight: minimum hessian a node must have for a split to be
+ considered. The value will be compared with sum(leaf_hessian)/
+ (batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
"""
# pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -96,6 +100,7 @@ def boosted_trees_classifier_train_in_memory(
l1_regularization=0.,
l2_regularization=0.,
tree_complexity=0.,
+ min_node_weight=0.,
config=None,
train_hooks=None):
"""Trains a boosted tree classifier with in memory dataset.
@@ -162,6 +167,9 @@ def boosted_trees_classifier_train_in_memory(
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+ min_node_weight: minimum hessian a node must have for a split to be
+ considered. The value will be compared with sum(leaf_hessian)/
+ (batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
train_hooks: a list of Hook instances to be passed to estimator.train().
@@ -184,7 +192,7 @@ def boosted_trees_classifier_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
@@ -220,6 +228,7 @@ def boosted_trees_regressor_train_in_memory(
l1_regularization=0.,
l2_regularization=0.,
tree_complexity=0.,
+ min_node_weight=0.,
config=None,
train_hooks=None):
"""Trains a boosted tree regressor with in memory dataset.
@@ -279,6 +288,9 @@ def boosted_trees_regressor_train_in_memory(
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+ min_node_weight: minimum hessian a node must have for a split to be
+ considered. The value will be compared with sum(leaf_hessian)/
+ (batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
train_hooks: a list of Hook instances to be passed to estimator.train().
@@ -300,7 +312,7 @@ def boosted_trees_regressor_train_in_memory(
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return canned_boosted_trees._bt_model_fn(
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
index 7f18c64574..3f181e91ce 100644
--- a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCalculateBestGainsPerFeature.pbtxt
@@ -31,6 +31,12 @@ END
adjustment to the gain, per leaf based.
END
}
+ in_arg {
+ name: "min_node_weight"
+ description: <<END
+minimum avg of hessians in a node required for the node to be considered for splitting.
+END
+ }
out_arg {
name: "node_ids_list"
description: <<END
@@ -84,4 +90,4 @@ In this manner, the output is the best split per features and per node, so that
The length of output lists are all of the same length, `num_features`.
The output shapes are compatible in a way that the first dimension of all tensors of all lists are the same and equal to the number of possible split nodes for each feature.
END
-}
+}
\ No newline at end of file
diff --git a/tensorflow/core/kernels/boosted_trees/stats_ops.cc b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
index 40f50333d3..6dfcd63ab3 100644
--- a/tensorflow/core/kernels/boosted_trees/stats_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/stats_ops.cc
@@ -60,6 +60,10 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
OP_REQUIRES_OK(context,
context->input("tree_complexity", &tree_complexity_t));
const auto tree_complexity = tree_complexity_t->scalar<float>()();
+ const Tensor* min_node_weight_t;
+ OP_REQUIRES_OK(context,
+ context->input("min_node_weight", &min_node_weight_t));
+ const auto min_node_weight = min_node_weight_t->scalar<float>()();
// Allocate output lists of tensors:
OpOutputList output_node_ids_list;
@@ -105,6 +109,11 @@ class BoostedTreesCalculateBestGainsPerFeatureOp : public OpKernel {
cum_grad.push_back(total_grad);
cum_hess.push_back(total_hess);
}
+ // Check if node has enough of average hessian.
+ if (total_hess < min_node_weight) {
+ // Do not split the node because not enough avg hessian.
+ continue;
+ }
float best_gain = std::numeric_limits<float>::lowest();
float best_bucket = 0;
float best_contrib_for_left = 0.0;
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 4d74e6d63a..88d6eaf819 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -40,6 +40,7 @@ REGISTER_OP("BoostedTreesCalculateBestGainsPerFeature")
.Input("l1: float")
.Input("l2: float")
.Input("tree_complexity: float")
+ .Input("min_node_weight: float")
.Attr("max_splits: int >= 1")
.Attr("num_features: int >= 1") // not passed but populated automatically.
.Output("node_ids_list: num_features * int32")
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 0af560010f..5bd37efac8 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -10867,6 +10867,10 @@ op {
name: "tree_complexity"
type: DT_FLOAT
}
+ input_arg {
+ name: "min_node_weight"
+ type: DT_FLOAT
+ }
output_arg {
name: "node_ids_list"
type: DT_INT32
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index d099d308f5..536bd2bf81 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -40,9 +40,11 @@ from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import tf_export
-_TreeHParams = collections.namedtuple(
- 'TreeHParams',
- ['n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity'])
+# TODO(nponomareva): Reveal pruning params here.
+_TreeHParams = collections.namedtuple('TreeHParams', [
+ 'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
+ 'min_node_weight'
+])
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
_HOLD_FOR_MULTI_DIM_SUPPORT = object()
@@ -397,6 +399,7 @@ def _bt_model_fn(
l1=tree_hparams.l1,
l2=tree_hparams.l2,
tree_complexity=tree_hparams.tree_complexity,
+ min_node_weight=tree_hparams.min_node_weight,
max_splits=max_splits))
grow_op = boosted_trees_ops.update_ensemble(
# Confirm if local_tree_ensemble or tree_ensemble should be used.
@@ -515,21 +518,21 @@ def _create_regression_head(label_dimension, weight_column=None):
class BoostedTreesClassifier(estimator.Estimator):
"""A Classifier for Tensorflow Boosted Trees models."""
- def __init__(
- self,
- feature_columns,
- n_batches_per_layer,
- model_dir=None,
- n_classes=_HOLD_FOR_MULTI_CLASS_SUPPORT,
- weight_column=None,
- label_vocabulary=None,
- n_trees=100,
- max_depth=6,
- learning_rate=0.1,
- l1_regularization=0.,
- l2_regularization=0.,
- tree_complexity=0.,
- config=None):
+ def __init__(self,
+ feature_columns,
+ n_batches_per_layer,
+ model_dir=None,
+ n_classes=_HOLD_FOR_MULTI_CLASS_SUPPORT,
+ weight_column=None,
+ label_vocabulary=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ min_node_weight=0.,
+ config=None):
"""Initializes a `BoostedTreesClassifier` instance.
Example:
@@ -593,6 +596,9 @@ class BoostedTreesClassifier(estimator.Estimator):
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+ min_node_weight: minimum hessian a node must have for a
+ split to be considered. The value will be compared with
+ sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
Raises:
@@ -606,9 +612,9 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes, weight_column, label_vocabulary=label_vocabulary)
# HParams for the model.
- tree_hparams = _TreeHParams(
- n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
+ l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
@@ -630,20 +636,20 @@ class BoostedTreesClassifier(estimator.Estimator):
class BoostedTreesRegressor(estimator.Estimator):
"""A Regressor for Tensorflow Boosted Trees models."""
- def __init__(
- self,
- feature_columns,
- n_batches_per_layer,
- model_dir=None,
- label_dimension=_HOLD_FOR_MULTI_DIM_SUPPORT,
- weight_column=None,
- n_trees=100,
- max_depth=6,
- learning_rate=0.1,
- l1_regularization=0.,
- l2_regularization=0.,
- tree_complexity=0.,
- config=None):
+ def __init__(self,
+ feature_columns,
+ n_batches_per_layer,
+ model_dir=None,
+ label_dimension=_HOLD_FOR_MULTI_DIM_SUPPORT,
+ weight_column=None,
+ n_trees=100,
+ max_depth=6,
+ learning_rate=0.1,
+ l1_regularization=0.,
+ l2_regularization=0.,
+ tree_complexity=0.,
+ min_node_weight=0.,
+ config=None):
"""Initializes a `BoostedTreesRegressor` instance.
Example:
@@ -700,6 +706,9 @@ class BoostedTreesRegressor(estimator.Estimator):
l2_regularization: regularization multiplier applied to the square weights
of the tree leafs.
tree_complexity: regularization factor to penalize trees with more leaves.
+ min_node_weight: minimum hessian a node must have for a
+ split to be considered. The value will be compared with
+ sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
Raises:
@@ -712,9 +721,9 @@ class BoostedTreesRegressor(estimator.Estimator):
head = _create_regression_head(label_dimension, weight_column)
# HParams for the model.
- tree_hparams = _TreeHParams(
- n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
- tree_complexity)
+ tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
+ l1_regularization, l2_regularization,
+ tree_complexity, min_node_weight)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 7823ef8410..56e67a6707 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -188,7 +188,8 @@ class ModelFnTests(test_util.TensorFlowTestCase):
learning_rate=0.1,
l1=0.,
l2=0.01,
- tree_complexity=0.)
+ tree_complexity=0.,
+ min_node_weight=0.)
def _get_expected_ensembles_for_classification(self):
first_round = """
diff --git a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
index 4d09cf94d4..f0bb84e69a 100644
--- a/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
+++ b/tensorflow/python/kernel_tests/boosted_trees/stats_ops_test.py
@@ -59,6 +59,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=0.0,
l2=0.0,
tree_complexity=0.0,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
@@ -106,6 +107,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=0.0,
l2=0.1,
tree_complexity=0.0,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
@@ -154,6 +156,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=l1,
l2=0.0,
tree_complexity=0.0,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[0, 1], [1, 1]], sess.run(thresholds_list))
@@ -205,6 +208,7 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
l1=0.0,
l2=l2,
tree_complexity=tree_complexity,
+ min_node_weight=0,
max_splits=max_splits)
self.assertAllEqual([[1, 2], [1, 2]], sess.run(node_ids_list))
@@ -220,6 +224,53 @@ class StatsOpsTest(test_util.TensorFlowTestCase):
self.assertAllClose([[[-.424658], [-.6]], [[-.043478], [.485294]]],
sess.run(right_node_contribs_list))
+ def testCalculateBestGainsWithMinNodeWEight(self):
+ """Testing Gain calculation with min_node_weight and no regularization."""
+ with self.test_session() as sess:
+ max_splits = 7
+ node_id_range = [1, 3] # node 1 through 2 will be processed.
+ stats_summary_list = [
+ [
+ [[0., 0.], [.08, .09], [0., 0.], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.15, .036], [.06, .07], [.1, .2]], # node 1
+ [[0., 0.], [-.33, .68], [0., 0.], [.3, .4]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 0
+ [
+ [[0., 0.], [0., 0.], [.08, .09], [0., 0.]], # node 0; ignored
+ [[0., 0.], [.3, .5], [-.05, .6], [.06, .07]], # node 1
+ [[.1, .1], [.2, .03], [-.4, .05], [.07, .08]], # node 2
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 3; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 4; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 5; ignored
+ [[0., 0.], [0., 0.], [0., 0.], [0., 0.]], # node 6; ignored
+ ], # feature 1
+ ] # num_features * shape=[max_splits, num_buckets, 2]
+
+ (node_ids_list, gains_list, thresholds_list, left_node_contribs_list,
+ right_node_contribs_list
+ ) = boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range,
+ stats_summary_list,
+ l1=0.0,
+ l2=0.0,
+ tree_complexity=0.0,
+ min_node_weight=1,
+ max_splits=max_splits)
+
+ # We can't split node 1 on feature 0 and node 2 on feature 1 because their
+ # summed hessians are below the min node weight.
+ self.assertAllEqual([[2], [1]], sess.run(node_ids_list))
+ self.assertAllClose([[0.384314], [0.098013]], sess.run(gains_list))
+ self.assertAllEqual([[1], [1]], sess.run(thresholds_list))
+ self.assertAllClose([[[0.4852941]], [[-.6]]],
+ sess.run(left_node_contribs_list))
+ self.assertAllClose([[[-0.75]], [[-0.014925]]],
+ sess.run(right_node_contribs_list))
+
def testMakeStatsSummarySimple(self):
"""Simple test for MakeStatsSummary."""
with self.test_session():
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
index fd9be8c759..53a903c239 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
}
member_method {
name: "evaluate"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
index 6b305be43f..ba17c90de2 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
@@ -21,7 +21,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\'], "
}
member_method {
name: "evaluate"