Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees.py')
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py | 420
1 file changed, 310 insertions(+), 110 deletions(-)
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 8afef1b65a..3292e2724d 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
+import functools
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -44,12 +46,13 @@ from tensorflow.python.util.tf_export import estimator_export
# TODO(nponomareva): Reveal pruning params here.
_TreeHParams = collections.namedtuple('TreeHParams', [
'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
- 'min_node_weight'
+ 'min_node_weight', 'center_bias'
])
_HOLD_FOR_MULTI_CLASS_SUPPORT = object()
_HOLD_FOR_MULTI_DIM_SUPPORT = object()
_DUMMY_NUM_BUCKETS = -1
+_DUMMY_NODE_ID = -1
def _get_transformed_features(features, sorted_feature_columns):
@@ -279,7 +282,9 @@ class _CacheTrainingStatesUsingHashTable(object):
"""Returns cached_tree_ids, cached_node_ids, cached_logits."""
cached_tree_ids, cached_node_ids, cached_logits = array_ops.split(
lookup_ops.lookup_table_find_v2(
- self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]),
+ self._table_ref,
+ self._example_ids,
+ default_value=[0.0, _DUMMY_NODE_ID, 0.0]),
[1, 1, self._logits_dimension],
axis=1)
cached_tree_ids = array_ops.squeeze(
@@ -330,7 +335,7 @@ class _CacheTrainingStatesUsingVariables(object):
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
self._node_ids = _local_variable(
- array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
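
Note: the switch above from zero-initialized node ids to the _DUMMY_NODE_ID sentinel matters because node id 0 is a real node (the root), so a zero-filled cache entry would be indistinguishable from an example legitimately cached at the root. A toy sketch of the distinction (a plain dict standing in for the cache; example keys hypothetical):

_DUMMY_NODE_ID = -1
cache = {'example_a': _DUMMY_NODE_ID,  # never cached yet
         'example_b': 0}               # genuinely cached at the root node
print(cache['example_a'] == _DUMMY_NODE_ID)  # -> True: treat as a cache miss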
@@ -380,6 +385,249 @@ class _StopAtAttemptsHook(session_run_hook.SessionRunHook):
run_context.request_stop()
+def _get_max_splits(tree_hparams):
+ """Calculates the max possible number of splits based on tree params."""
+  # Maximum number of splits possible in the whole tree is 2**max_depth - 1.
+ max_splits = (1 << tree_hparams.max_depth) - 1
+ return max_splits
+
+
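As a quick sanity check on _get_max_splits: a tree of max_depth D has at most 2**D - 1 split nodes. A minimal sketch (plain Python; the Hparams namedtuple here is a hypothetical stand-in for _TreeHParams):

import collections

Hparams = collections.namedtuple('Hparams', ['max_depth'])
for depth in (1, 2, 6):
    print((1 << Hparams(max_depth=depth).max_depth) - 1)
# -> 1, 3, 63: a stump has one split; a depth-6 tree allows up to 63 splits.
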
+class _EnsembleGrower(object):
+ """Abstract base class for different types of ensemble growers.
+
+  Use it to obtain training ops for growing the tree and centering the bias,
+  depending on the implementation (for example, in-memory or accumulator-based
+  distributed training):
+ grower = ...create subclass grower(tree_ensemble, tree_hparams)
+ grow_op = grower.grow_tree(stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range)
+ training_ops.append(grow_op)
+ """
+
+ def __init__(self, tree_ensemble, tree_hparams):
+ """Initializes a grower object.
+
+ Args:
+ tree_ensemble: A TreeEnsemble variable.
+      tree_hparams: A _TreeHParams collections.namedtuple of hyperparameters.
+ """
+ self._tree_ensemble = tree_ensemble
+ self._tree_hparams = tree_hparams
+
+ @abc.abstractmethod
+ def center_bias(self, center_bias_var, gradients, hessians):
+ """Centers bias, if ready, based on statistics.
+
+ Args:
+      center_bias_var: A variable that will be updated when bias centering
+        has finished.
+ gradients: A rank 2 tensor of gradients.
+ hessians: A rank 2 tensor of hessians.
+
+ Returns:
+ An operation for centering bias.
+ """
+
+ @abc.abstractmethod
+ def grow_tree(self, stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range):
+ """Grows a tree, if ready, based on provided statistics.
+
+ Args:
+ stats_summaries_list: List of stats summary tensors, representing sums of
+ gradients and hessians for each feature bucket.
+      feature_ids_list: A list of lists of feature ids for each bucket size.
+ last_layer_nodes_range: A tensor representing ids of the nodes in the
+ current layer, to be split.
+
+ Returns:
+ An op for growing a tree.
+ """
+
+ # ============= Helper methods ===========
+
+ def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians):
+    """Updates the ensemble and cache (if needed) with the logits prior."""
+ continue_centering = boosted_trees_ops.center_bias(
+ self._tree_ensemble.resource_handle,
+ mean_gradients=mean_gradients,
+ mean_hessians=mean_hessians,
+ l1=self._tree_hparams.l1,
+ l2=self._tree_hparams.l2)
+ return center_bias_var.assign(continue_centering)
+
+ def _grow_tree_from_stats_summaries(self, stats_summaries_list,
+ feature_ids_list, last_layer_nodes_range):
+ """Updates ensemble based on the best gains from stats summaries."""
+ node_ids_per_feature = []
+ gains_list = []
+ thresholds_list = []
+ left_node_contribs_list = []
+ right_node_contribs_list = []
+ all_feature_ids = []
+ assert len(stats_summaries_list) == len(feature_ids_list)
+
+ max_splits = _get_max_splits(self._tree_hparams)
+
+ for i, feature_ids in enumerate(feature_ids_list):
+ (numeric_node_ids_per_feature, numeric_gains_list,
+ numeric_thresholds_list, numeric_left_node_contribs_list,
+ numeric_right_node_contribs_list) = (
+ boosted_trees_ops.calculate_best_gains_per_feature(
+ node_id_range=last_layer_nodes_range,
+ stats_summary_list=stats_summaries_list[i],
+ l1=self._tree_hparams.l1,
+ l2=self._tree_hparams.l2,
+ tree_complexity=self._tree_hparams.tree_complexity,
+ min_node_weight=self._tree_hparams.min_node_weight,
+ max_splits=max_splits))
+
+ all_feature_ids += feature_ids
+ node_ids_per_feature += numeric_node_ids_per_feature
+ gains_list += numeric_gains_list
+ thresholds_list += numeric_thresholds_list
+ left_node_contribs_list += numeric_left_node_contribs_list
+ right_node_contribs_list += numeric_right_node_contribs_list
+
+ grow_op = boosted_trees_ops.update_ensemble(
+        # TODO: confirm whether to use local_tree_ensemble or tree_ensemble.
+ self._tree_ensemble.resource_handle,
+ feature_ids=all_feature_ids,
+ node_ids=node_ids_per_feature,
+ gains=gains_list,
+ thresholds=thresholds_list,
+ left_node_contribs=left_node_contribs_list,
+ right_node_contribs=right_node_contribs_list,
+ learning_rate=self._tree_hparams.learning_rate,
+ max_depth=self._tree_hparams.max_depth,
+ pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
+ return grow_op
+
+
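To make the grower contract concrete, here is a minimal hypothetical subclass (invented for illustration only; the two real implementations follow below):

class _NoOpEnsembleGrower(_EnsembleGrower):
  """Hypothetical grower that neither centers bias nor grows the tree."""

  def center_bias(self, center_bias_var, gradients, hessians):
    # A real implementation returns an op updating center_bias_var.
    return control_flow_ops.no_op()

  def grow_tree(self, stats_summaries_list, feature_ids_list,
                last_layer_nodes_range):
    # A real implementation returns an op adding a layer to the ensemble.
    return control_flow_ops.no_op()
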
+class _InMemoryEnsembleGrower(_EnsembleGrower):
+  """Grows the ensemble in memory, from a single full-data batch."""
+
+ def __init__(self, tree_ensemble, tree_hparams):
+ super(_InMemoryEnsembleGrower, self).__init__(
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+
+ def center_bias(self, center_bias_var, gradients, hessians):
+    # In the in-memory case, we already have a full batch of gradients and
+    # hessians, so we just take their means and proceed with centering.
+ mean_gradients = array_ops.expand_dims(
+ math_ops.reduce_mean(gradients, 0), 0)
+    mean_hessians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
+    return self._center_bias_fn(center_bias_var, mean_gradients, mean_hessians)
+
+ def grow_tree(self, stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range):
+    # In the in-memory case, the full data arrives in a single batch, so we
+    # can grow the tree immediately.
+ return self._grow_tree_from_stats_summaries(
+ stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+
+
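The reduction in center_bias above is a pure shape change: a full batch of rank-2 per-example statistics collapses to the single [1, logits_dimension] row that the centering op expects. A NumPy sketch (batch size hypothetical):

import numpy as np

gradients = np.random.randn(1024, 1).astype(np.float32)    # [batch, logits_dim]
mean_gradients = np.expand_dims(gradients.mean(axis=0), 0)
print(mean_gradients.shape)  # -> (1, 1)
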
+class _AccumulatorEnsembleGrower(_EnsembleGrower):
+  """Grows the ensemble with accumulators, for distributed training."""
+
+ def __init__(self, tree_ensemble, tree_hparams, stamp_token,
+ n_batches_per_layer, bucket_size_list, is_chief):
+ super(_AccumulatorEnsembleGrower, self).__init__(
+ tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+ self._stamp_token = stamp_token
+ self._n_batches_per_layer = n_batches_per_layer
+ self._bucket_size_list = bucket_size_list
+ self._is_chief = is_chief
+
+ def center_bias(self, center_bias_var, gradients, hessians):
+    # Outside of the in-memory case, we need to accumulate enough batches
+    # before proceeding with centering the bias.
+
+ # Create an accumulator.
+ bias_dependencies = []
+ bias_accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+        # The stats consist of the gradient and hessian means only.
+        # TODO(nponomareva): this will change for multiclass.
+ shape=[2, 1],
+ shared_name='bias_accumulator')
+
+ grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
+ grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
+
+ apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+ bias_dependencies.append(apply_grad)
+
+    # Center the bias once enough batches have been processed.
+ with ops.control_dependencies(bias_dependencies):
+ if not self._is_chief:
+ return control_flow_ops.no_op()
+
+ def center_bias_from_accumulator():
+ accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
+ return self._center_bias_fn(center_bias_var,
+ array_ops.expand_dims(accumulated[0], 0),
+ array_ops.expand_dims(accumulated[1], 0))
+
+ center_bias_op = control_flow_ops.cond(
+ math_ops.greater_equal(bias_accumulator.num_accumulated(),
+ self._n_batches_per_layer),
+ center_bias_from_accumulator,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_for_bias_accumulated')
+ return center_bias_op
+
+ def grow_tree(self, stats_summaries_list, feature_ids_list,
+ last_layer_nodes_range):
+    # Outside of the in-memory case, we need to accumulate enough batches
+    # before proceeding with building a tree layer.
+ max_splits = _get_max_splits(self._tree_hparams)
+
+ # Prepare accumulators.
+ accumulators = []
+ dependencies = []
+ for i, feature_ids in enumerate(feature_ids_list):
+ stats_summaries = stats_summaries_list[i]
+ accumulator = data_flow_ops.ConditionalAccumulator(
+ dtype=dtypes.float32,
+ # The stats consist of grads and hessians (the last dimension).
+ shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+ shared_name='numeric_stats_summary_accumulator_' + str(i))
+ accumulators.append(accumulator)
+
+ apply_grad = accumulator.apply_grad(
+ array_ops.stack(stats_summaries, axis=0), self._stamp_token)
+ dependencies.append(apply_grad)
+
+    # Grow the tree once enough batches have been accumulated.
+ with ops.control_dependencies(dependencies):
+ if not self._is_chief:
+ return control_flow_ops.no_op()
+
+ min_accumulated = math_ops.reduce_min(
+ array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+
+ def grow_tree_from_accumulated_summaries_fn():
+      """Updates the tree with the best layer from accumulated summaries."""
+ # Take out the accumulated summaries from the accumulator and grow.
+      stats_summaries_list = [
+          array_ops.unstack(accumulator.take_grad(1), axis=0)
+          for accumulator in accumulators
+      ]
+ grow_op = self._grow_tree_from_stats_summaries(
+ stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+ return grow_op
+
+ grow_model = control_flow_ops.cond(
+ math_ops.greater_equal(min_accumulated, self._n_batches_per_layer),
+ grow_tree_from_accumulated_summaries_fn,
+ control_flow_ops.no_op,
+ name='wait_until_n_batches_accumulated')
+ return grow_model
+
+
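Both accumulator-based methods above rely on the same tf.ConditionalAccumulator pattern: workers apply one gradient per batch tagged with the current stamp, and take_grad(n) blocks until at least n fresh gradients have arrived, then returns their mean. A self-contained sketch, assuming the TF 1.x session API:

import tensorflow as tf

acc = tf.ConditionalAccumulator(
    dtype=tf.float32, shape=[2], shared_name='demo_accumulator')
apply_op_1 = acc.apply_grad([1.0, 3.0], local_step=0)
apply_op_2 = acc.apply_grad([3.0, 5.0], local_step=0)
take_op = acc.take_grad(num_required=2)  # mean of >= 2 applied gradients

with tf.Session() as sess:
  sess.run([apply_op_1, apply_op_2])
  print(sess.run(take_op))  # -> [2. 4.], the element-wise mean
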
def _bt_model_fn(
features,
labels,
@@ -425,8 +673,8 @@ def _bt_model_fn(
ValueError: mode or params are invalid, or features has the wrong type.
"""
is_single_machine = (config.num_worker_replicas <= 1)
-
sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
+ center_bias = tree_hparams.center_bias
if train_in_memory:
assert n_batches_per_layer == 1, (
'When train_in_memory is enabled, input_fn should return the entire '
@@ -437,11 +685,6 @@ def _bt_model_fn(
raise ValueError('train_in_memory is supported only for '
'non-distributed training.')
worker_device = control_flow_ops.no_op().device
- # maximum number of splits possible in the whole tree =2^(D-1)-1
- # TODO(youngheek): perhaps storage could be optimized by storing stats with
- # the dimension max_splits_per_layer, instead of max_splits (for the entire
- # tree).
- max_splits = (1 << tree_hparams.max_depth) - 1
train_op = []
with ops.name_scope(name) as name:
# Prepare.
@@ -469,6 +712,9 @@ def _bt_model_fn(
# Create Ensemble resources.
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ # Variable that determines whether bias centering is needed.
+ center_bias_var = variable_scope.variable(
+ initial_value=center_bias, name='center_bias_needed', trainable=False)
# Create logits.
if mode != model_fn.ModeKeys.TRAIN:
logits = boosted_trees_ops.predict(
@@ -489,6 +735,7 @@ def _bt_model_fn(
# TODO(soroush): Do partial updates if this becomes a bottleneck.
ensemble_reload = local_tree_ensemble.deserialize(
*tree_ensemble.serialize())
+
if training_state_cache:
cached_tree_ids, cached_node_ids, cached_logits = (
training_state_cache.lookup())
@@ -497,9 +744,10 @@ def _bt_model_fn(
batch_size = array_ops.shape(labels)[0]
cached_tree_ids, cached_node_ids, cached_logits = (
array_ops.zeros([batch_size], dtype=dtypes.int32),
- array_ops.zeros([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
array_ops.zeros(
[batch_size, head.logits_dimension], dtype=dtypes.float32))
+
with ops.control_dependencies([ensemble_reload]):
(stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
last_layer_nodes_range) = local_tree_ensemble.get_states()
@@ -513,13 +761,20 @@ def _bt_model_fn(
cached_node_ids=cached_node_ids,
bucketized_features=input_feature_list,
logits_dimension=head.logits_dimension)
+
logits = cached_logits + partial_logits
# Create training graph.
def _train_op_fn(loss):
"""Run one training iteration."""
if training_state_cache:
- train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
+        # Cache logits only once bias centering (if enabled) has completed.
+ train_op.append(
+ control_flow_ops.cond(
+ center_bias_var, control_flow_ops.no_op,
+ lambda: training_state_cache.insert(tree_ids, node_ids, logits))
+ )
+
if closed_form_grad_and_hess_fn:
gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
else:
@@ -527,6 +782,11 @@ def _bt_model_fn(
hessians = gradients_impl.gradients(
gradients, logits, name='Hessians')[0]
+ # TODO(youngheek): perhaps storage could be optimized by storing stats
+ # with the dimension max_splits_per_layer, instead of max_splits (for the
+ # entire tree).
+ max_splits = _get_max_splits(tree_hparams)
+
stats_summaries_list = []
for i, feature_ids in enumerate(feature_ids_list):
num_buckets = bucket_size_list[i]
@@ -543,103 +803,28 @@ def _bt_model_fn(
]
stats_summaries_list.append(summaries)
- accumulators = []
-
- def grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list):
- """Updates ensemble based on the best gains from stats summaries."""
- node_ids_per_feature = []
- gains_list = []
- thresholds_list = []
- left_node_contribs_list = []
- right_node_contribs_list = []
- all_feature_ids = []
-
- assert len(stats_summaries_list) == len(feature_ids_list)
-
- for i, feature_ids in enumerate(feature_ids_list):
- (numeric_node_ids_per_feature, numeric_gains_list,
- numeric_thresholds_list, numeric_left_node_contribs_list,
- numeric_right_node_contribs_list) = (
- boosted_trees_ops.calculate_best_gains_per_feature(
- node_id_range=last_layer_nodes_range,
- stats_summary_list=stats_summaries_list[i],
- l1=tree_hparams.l1,
- l2=tree_hparams.l2,
- tree_complexity=tree_hparams.tree_complexity,
- min_node_weight=tree_hparams.min_node_weight,
- max_splits=max_splits))
-
- all_feature_ids += feature_ids
- node_ids_per_feature += numeric_node_ids_per_feature
- gains_list += numeric_gains_list
- thresholds_list += numeric_thresholds_list
- left_node_contribs_list += numeric_left_node_contribs_list
- right_node_contribs_list += numeric_right_node_contribs_list
-
- grow_op = boosted_trees_ops.update_ensemble(
- # Confirm if local_tree_ensemble or tree_ensemble should be used.
- tree_ensemble.resource_handle,
- feature_ids=all_feature_ids,
- node_ids=node_ids_per_feature,
- gains=gains_list,
- thresholds=thresholds_list,
- left_node_contribs=left_node_contribs_list,
- right_node_contribs=right_node_contribs_list,
- learning_rate=tree_hparams.learning_rate,
- max_depth=tree_hparams.max_depth,
- pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
- return grow_op
-
if train_in_memory and is_single_machine:
- train_op.append(distribute_lib.increment_var(global_step))
- train_op.append(
- grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list))
+ grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
else:
- dependencies = []
-
- for i, feature_ids in enumerate(feature_ids_list):
- stats_summaries = stats_summaries_list[i]
- accumulator = data_flow_ops.ConditionalAccumulator(
- dtype=dtypes.float32,
- # The stats consist of grads and hessians (the last dimension).
- shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
- shared_name='numeric_stats_summary_accumulator_' + str(i))
- accumulators.append(accumulator)
-
- apply_grad = accumulator.apply_grad(
- array_ops.stack(stats_summaries, axis=0), stamp_token)
- dependencies.append(apply_grad)
-
- def grow_tree_from_accumulated_summaries_fn():
- """Updates the tree with the best layer from accumulated summaries."""
- # Take out the accumulated summaries from the accumulator and grow.
- stats_summaries_list = []
-
- stats_summaries_list = [
- array_ops.unstack(accumulator.take_grad(1), axis=0)
- for accumulator in accumulators
- ]
-
- grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
- feature_ids_list)
- return grow_op
-
- with ops.control_dependencies(dependencies):
- train_op.append(distribute_lib.increment_var(global_step))
- if config.is_chief:
- min_accumulated = math_ops.reduce_min(
- array_ops.stack(
- [acc.num_accumulated() for acc in accumulators]))
-
- train_op.append(
- control_flow_ops.cond(
- math_ops.greater_equal(min_accumulated,
- n_batches_per_layer),
- grow_tree_from_accumulated_summaries_fn,
- control_flow_ops.no_op,
- name='wait_until_n_batches_accumulated'))
+ grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+ stamp_token, n_batches_per_layer,
+ bucket_size_list, config.is_chief)
+
+ update_model = control_flow_ops.cond(
+ center_bias_var,
+ functools.partial(
+ grower.center_bias,
+ center_bias_var,
+ gradients,
+ hessians,
+ ),
+ functools.partial(grower.grow_tree, stats_summaries_list,
+ feature_ids_list, last_layer_nodes_range))
+ train_op.append(update_model)
+
+ with ops.control_dependencies([update_model]):
+ increment_global = distribute_lib.increment_var(global_step)
+ train_op.append(increment_global)
return control_flow_ops.group(train_op, name='train_op')
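
Note on the functools.partial wrappers above: control_flow_ops.cond takes zero-argument callables, so the grower methods (which take arguments) must be wrapped into thunks first. A plain-Python sketch of the dispatch, with stand-ins rather than real TF ops:

import functools

def center_bias(var, grads, hess):
  return 'center_bias_op'

def grow_tree(summaries, ids, node_range):
  return 'grow_op'

def cond(pred, true_fn, false_fn):  # stand-in for control_flow_ops.cond
  return true_fn() if pred else false_fn()

true_fn = functools.partial(center_bias, 'var', 'grads', 'hess')
false_fn = functools.partial(grow_tree, 'summaries', 'ids', 'range')
print(cond(True, true_fn, false_fn))   # -> center_bias_op
print(cond(False, true_fn, false_fn))  # -> grow_op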
@@ -739,7 +924,8 @@ class BoostedTreesClassifier(estimator.Estimator):
l2_regularization=0.,
tree_complexity=0.,
min_node_weight=0.,
- config=None):
+ config=None,
+ center_bias=False):
"""Initializes a `BoostedTreesClassifier` instance.
Example:
@@ -807,6 +993,13 @@ class BoostedTreesClassifier(estimator.Estimator):
split to be considered. The value will be compared with
sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
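
For the binary-classification case described in the center_bias docstring above, the value the bias converges to is simply the log-odds of the positive class. A back-of-the-envelope NumPy sketch (labels hypothetical):

import numpy as np

labels = np.array([0, 0, 0, 1])     # hypothetical training labels
p = labels.mean()                   # prior P(label == 1) = 0.25
prior_logit = np.log(p / (1 - p))   # -> log(1/3) ~= -1.0986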
@@ -821,7 +1014,7 @@ class BoostedTreesClassifier(estimator.Estimator):
# HParams for the model.
tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access
@@ -864,7 +1057,8 @@ class BoostedTreesRegressor(estimator.Estimator):
l2_regularization=0.,
tree_complexity=0.,
min_node_weight=0.,
- config=None):
+ config=None,
+ center_bias=False):
"""Initializes a `BoostedTreesRegressor` instance.
Example:
@@ -925,6 +1119,12 @@ class BoostedTreesRegressor(estimator.Estimator):
split to be considered. The value will be compared with
sum(leaf_hessian)/(batch_size * n_batches_per_layer).
config: `RunConfig` object to configure the runtime settings.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
Raises:
ValueError: when wrong arguments are given or unsupported functionalities
@@ -938,7 +1138,7 @@ class BoostedTreesRegressor(estimator.Estimator):
# HParams for the model.
tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
l1_regularization, l2_regularization,
- tree_complexity, min_node_weight)
+ tree_complexity, min_node_weight, center_bias)
def _model_fn(features, labels, mode, config):
return _bt_model_fn( # pylint: disable=protected-access