Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees.py')
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py | 420
1 file changed, 310 insertions(+), 110 deletions(-)
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 8afef1b65a..3292e2724d 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -17,7 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
 import collections
+import functools
 
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator import model_fn
@@ -44,12 +46,13 @@ from tensorflow.python.util.tf_export import estimator_export
 # TODO(nponomareva): Reveal pruning params here.
 _TreeHParams = collections.namedtuple('TreeHParams', [
     'n_trees', 'max_depth', 'learning_rate', 'l1', 'l2', 'tree_complexity',
-    'min_node_weight'
+    'min_node_weight', 'center_bias'
 ])
 
 _HOLD_FOR_MULTI_CLASS_SUPPORT = object()
 _HOLD_FOR_MULTI_DIM_SUPPORT = object()
 _DUMMY_NUM_BUCKETS = -1
+_DUMMY_NODE_ID = -1
 
 
 def _get_transformed_features(features, sorted_feature_columns):
@@ -279,7 +282,9 @@ class _CacheTrainingStatesUsingHashTable(object):
     """Returns cached_tree_ids, cached_node_ids, cached_logits."""
     cached_tree_ids, cached_node_ids, cached_logits = array_ops.split(
         lookup_ops.lookup_table_find_v2(
-            self._table_ref, self._example_ids, default_value=[0.0, 0.0, 0.0]),
+            self._table_ref,
+            self._example_ids,
+            default_value=[0.0, _DUMMY_NODE_ID, 0.0]),
         [1, 1, self._logits_dimension],
         axis=1)
     cached_tree_ids = array_ops.squeeze(
@@ -330,7 +335,7 @@ class _CacheTrainingStatesUsingVariables(object):
         array_ops.zeros([batch_size], dtype=dtypes.int32),
         name='tree_ids_cache')
     self._node_ids = _local_variable(
-        array_ops.zeros([batch_size], dtype=dtypes.int32),
+        _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
         name='node_ids_cache')
     self._logits = _local_variable(
         array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -380,6 +385,249 @@ class _StopAtAttemptsHook(session_run_hook.SessionRunHook):
       run_context.request_stop()
 
 
+def _get_max_splits(tree_hparams):
+  """Calculates the max possible number of splits based on tree params."""
+  # Maximum number of splits possible in the whole tree = 2**max_depth - 1.
+  max_splits = (1 << tree_hparams.max_depth) - 1
+  return max_splits
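A quick sanity check of the arithmetic above: a binary tree of depth `max_depth` has at most `2**max_depth - 1` internal (split) nodes, which is exactly what the bit shift computes. A minimal illustration in plain Python, with hypothetical `TreeHParams` values:

    import collections

    TreeHParams = collections.namedtuple('TreeHParams', ['max_depth'])

    for depth in (1, 2, 6):
        hparams = TreeHParams(max_depth=depth)
        # Same formula as _get_max_splits above.
        max_splits = (1 << hparams.max_depth) - 1
        print(depth, max_splits)  # 1 -> 1, 2 -> 3, 6 -> 63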
+ """ + + @abc.abstractmethod + def grow_tree(self, stats_summaries_list, feature_ids_list, + last_layer_nodes_range): + """Grows a tree, if ready, based on provided statistics. + + Args: + stats_summaries_list: List of stats summary tensors, representing sums of + gradients and hessians for each feature bucket. + feature_ids_list: a list of lists of feature ids for each bucket size. + last_layer_nodes_range: A tensor representing ids of the nodes in the + current layer, to be split. + + Returns: + An op for growing a tree. + """ + + # ============= Helper methods =========== + + def _center_bias_fn(self, center_bias_var, mean_gradients, mean_hessians): + """Updates the ensembles and cache (if needed) with logits prior.""" + continue_centering = boosted_trees_ops.center_bias( + self._tree_ensemble.resource_handle, + mean_gradients=mean_gradients, + mean_hessians=mean_hessians, + l1=self._tree_hparams.l1, + l2=self._tree_hparams.l2) + return center_bias_var.assign(continue_centering) + + def _grow_tree_from_stats_summaries(self, stats_summaries_list, + feature_ids_list, last_layer_nodes_range): + """Updates ensemble based on the best gains from stats summaries.""" + node_ids_per_feature = [] + gains_list = [] + thresholds_list = [] + left_node_contribs_list = [] + right_node_contribs_list = [] + all_feature_ids = [] + assert len(stats_summaries_list) == len(feature_ids_list) + + max_splits = _get_max_splits(self._tree_hparams) + + for i, feature_ids in enumerate(feature_ids_list): + (numeric_node_ids_per_feature, numeric_gains_list, + numeric_thresholds_list, numeric_left_node_contribs_list, + numeric_right_node_contribs_list) = ( + boosted_trees_ops.calculate_best_gains_per_feature( + node_id_range=last_layer_nodes_range, + stats_summary_list=stats_summaries_list[i], + l1=self._tree_hparams.l1, + l2=self._tree_hparams.l2, + tree_complexity=self._tree_hparams.tree_complexity, + min_node_weight=self._tree_hparams.min_node_weight, + max_splits=max_splits)) + + all_feature_ids += feature_ids + node_ids_per_feature += numeric_node_ids_per_feature + gains_list += numeric_gains_list + thresholds_list += numeric_thresholds_list + left_node_contribs_list += numeric_left_node_contribs_list + right_node_contribs_list += numeric_right_node_contribs_list + + grow_op = boosted_trees_ops.update_ensemble( + # Confirm if local_tree_ensemble or tree_ensemble should be used. + self._tree_ensemble.resource_handle, + feature_ids=all_feature_ids, + node_ids=node_ids_per_feature, + gains=gains_list, + thresholds=thresholds_list, + left_node_contribs=left_node_contribs_list, + right_node_contribs=right_node_contribs_list, + learning_rate=self._tree_hparams.learning_rate, + max_depth=self._tree_hparams.max_depth, + pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING) + return grow_op + + +class _InMemoryEnsembleGrower(_EnsembleGrower): + """A base class for ensemble growers.""" + + def __init__(self, tree_ensemble, tree_hparams): + + super(_InMemoryEnsembleGrower, self).__init__( + tree_ensemble=tree_ensemble, tree_hparams=tree_hparams) + + def center_bias(self, center_bias_var, gradients, hessians): + # For in memory, we already have a full batch of gradients and hessians, + # so just take a mean and proceed with centering. 
+
+
+class _InMemoryEnsembleGrower(_EnsembleGrower):
+  """An in-memory grower: assumes the entire dataset arrives as one batch."""
+
+  def __init__(self, tree_ensemble, tree_hparams):
+    super(_InMemoryEnsembleGrower, self).__init__(
+        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+
+  def center_bias(self, center_bias_var, gradients, hessians):
+    # For in memory, we already have a full batch of gradients and hessians,
+    # so just take a mean and proceed with centering.
+    mean_gradients = array_ops.expand_dims(
+        math_ops.reduce_mean(gradients, 0), 0)
+    mean_hessians = array_ops.expand_dims(math_ops.reduce_mean(hessians, 0), 0)
+    return self._center_bias_fn(center_bias_var, mean_gradients, mean_hessians)
+
+  def grow_tree(self, stats_summaries_list, feature_ids_list,
+                last_layer_nodes_range):
+    # For in memory, we already have full data in one batch, so we can grow the
+    # tree immediately.
+    return self._grow_tree_from_stats_summaries(
+        stats_summaries_list, feature_ids_list, last_layer_nodes_range)
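Why taking the mean of per-example gradients is enough to center the bias: for squared-error loss, each example's gradient is prediction minus label, so driving the mean gradient to zero makes the root prediction equal the mean label, which is exactly the behavior the center_bias docstrings promise for regression. A hand-computed illustration in plain Python (not the center_bias kernel itself):

    labels = [1.0, 2.0, 6.0]
    bias = 0.0
    # Newton step for squared loss: grad = bias - y, hess = 1 per example.
    mean_grad = sum(bias - y for y in labels) / len(labels)   # -3.0
    mean_hess = 1.0
    bias -= mean_grad / mean_hess                             # bias becomes 3.0
    assert bias == sum(labels) / len(labels)                  # the label mean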
+
+
+class _AccumulatorEnsembleGrower(_EnsembleGrower):
+  """An accumulator-based grower for distributed (not-in-memory) training."""
+
+  def __init__(self, tree_ensemble, tree_hparams, stamp_token,
+               n_batches_per_layer, bucket_size_list, is_chief):
+    super(_AccumulatorEnsembleGrower, self).__init__(
+        tree_ensemble=tree_ensemble, tree_hparams=tree_hparams)
+    self._stamp_token = stamp_token
+    self._n_batches_per_layer = n_batches_per_layer
+    self._bucket_size_list = bucket_size_list
+    self._is_chief = is_chief
+
+  def center_bias(self, center_bias_var, gradients, hessians):
+    # In the distributed (not-in-memory) case, we need to accumulate enough
+    # batches before proceeding with centering bias.
+
+    # Create an accumulator.
+    bias_dependencies = []
+    bias_accumulator = data_flow_ops.ConditionalAccumulator(
+        dtype=dtypes.float32,
+        # The stats consist of grads and hessians means only.
+        # TODO(nponomareva): this will change for multiclass.
+        shape=[2, 1],
+        shared_name='bias_accumulator')
+
+    grads_and_hess = array_ops.stack([gradients, hessians], axis=0)
+    grads_and_hess = math_ops.reduce_mean(grads_and_hess, axis=1)
+
+    apply_grad = bias_accumulator.apply_grad(grads_and_hess, self._stamp_token)
+    bias_dependencies.append(apply_grad)
+
+    # Center bias if enough batches were processed.
+    with ops.control_dependencies(bias_dependencies):
+      if not self._is_chief:
+        return control_flow_ops.no_op()
+
+      def center_bias_from_accumulator():
+        accumulated = array_ops.unstack(bias_accumulator.take_grad(1), axis=0)
+        return self._center_bias_fn(center_bias_var,
+                                    array_ops.expand_dims(accumulated[0], 0),
+                                    array_ops.expand_dims(accumulated[1], 0))
+
+      center_bias_op = control_flow_ops.cond(
+          math_ops.greater_equal(bias_accumulator.num_accumulated(),
+                                 self._n_batches_per_layer),
+          center_bias_from_accumulator,
+          control_flow_ops.no_op,
+          name='wait_until_n_batches_for_bias_accumulated')
+      return center_bias_op
+
+  def grow_tree(self, stats_summaries_list, feature_ids_list,
+                last_layer_nodes_range):
+    # In the distributed (not-in-memory) case, we need to accumulate enough
+    # batches before proceeding with building a tree layer.
+    max_splits = _get_max_splits(self._tree_hparams)
+
+    # Prepare accumulators.
+    accumulators = []
+    dependencies = []
+    for i, feature_ids in enumerate(feature_ids_list):
+      stats_summaries = stats_summaries_list[i]
+      accumulator = data_flow_ops.ConditionalAccumulator(
+          dtype=dtypes.float32,
+          # The stats consist of grads and hessians (the last dimension).
+          shape=[len(feature_ids), max_splits, self._bucket_size_list[i], 2],
+          shared_name='numeric_stats_summary_accumulator_' + str(i))
+      accumulators.append(accumulator)
+
+      apply_grad = accumulator.apply_grad(
+          array_ops.stack(stats_summaries, axis=0), self._stamp_token)
+      dependencies.append(apply_grad)
+
+    # Grow the tree if enough batches have been accumulated.
+    with ops.control_dependencies(dependencies):
+      if not self._is_chief:
+        return control_flow_ops.no_op()
+
+      min_accumulated = math_ops.reduce_min(
+          array_ops.stack([acc.num_accumulated() for acc in accumulators]))
+
+      def grow_tree_from_accumulated_summaries_fn():
+        """Updates tree with the best layer from accumulated summaries."""
+        # Take out the accumulated summaries from the accumulator and grow.
+        stats_summaries_list = [
+            array_ops.unstack(accumulator.take_grad(1), axis=0)
+            for accumulator in accumulators
+        ]
+        grow_op = self._grow_tree_from_stats_summaries(
+            stats_summaries_list, feature_ids_list, last_layer_nodes_range)
+        return grow_op
+
+      grow_model = control_flow_ops.cond(
+          math_ops.greater_equal(min_accumulated, self._n_batches_per_layer),
+          grow_tree_from_accumulated_summaries_fn,
+          control_flow_ops.no_op,
+          name='wait_until_n_batches_accumulated')
+      return grow_model
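The growers above lean on `tf.ConditionalAccumulator` semantics that are easy to miss: `apply_grad` tags each contribution with a stamp (contributions with a stale stamp are dropped), `num_accumulated` counts contributions since the last take, and `take_grad(1)` returns the mean of everything accumulated, resets the count, and bumps the internal step. A minimal standalone sketch of that life cycle, assuming TF 1.x graph mode:

    import tensorflow as tf

    acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=[2],
                                    shared_name='demo_accumulator')
    apply_op = acc.apply_grad([1.0, 3.0], local_step=0)
    count = acc.num_accumulated()
    mean = acc.take_grad(1)  # blocks until >= 1 grad, returns the average

    with tf.Session() as sess:
        sess.run(apply_op)
        sess.run(acc.apply_grad([3.0, 5.0], local_step=0))
        print(sess.run(count))  # 2
        print(sess.run(mean))   # [2.0, 4.0], the element-wise mean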
+
+
 def _bt_model_fn(
     features,
     labels,
@@ -425,8 +673,8 @@ def _bt_model_fn(
     ValueError: mode or params are invalid, or features has the wrong type.
   """
   is_single_machine = (config.num_worker_replicas <= 1)
-
   sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
+  center_bias = tree_hparams.center_bias
   if train_in_memory:
     assert n_batches_per_layer == 1, (
         'When train_in_memory is enabled, input_fn should return the entire '
@@ -437,11 +685,6 @@ def _bt_model_fn(
     raise ValueError('train_in_memory is supported only for '
                      'non-distributed training.')
   worker_device = control_flow_ops.no_op().device
-  # maximum number of splits possible in the whole tree =2^(D-1)-1
-  # TODO(youngheek): perhaps storage could be optimized by storing stats with
-  # the dimension max_splits_per_layer, instead of max_splits (for the entire
-  # tree).
-  max_splits = (1 << tree_hparams.max_depth) - 1
   train_op = []
   with ops.name_scope(name) as name:
     # Prepare.
@@ -469,6 +712,9 @@ def _bt_model_fn(
 
     # Create Ensemble resources.
     tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+    # Variable that determines whether bias centering is needed.
+    center_bias_var = variable_scope.variable(
+        initial_value=center_bias, name='center_bias_needed', trainable=False)
     # Create logits.
     if mode != model_fn.ModeKeys.TRAIN:
       logits = boosted_trees_ops.predict(
@@ -489,6 +735,7 @@ def _bt_model_fn(
       # TODO(soroush): Do partial updates if this becomes a bottleneck.
       ensemble_reload = local_tree_ensemble.deserialize(
           *tree_ensemble.serialize())
+
     if training_state_cache:
       cached_tree_ids, cached_node_ids, cached_logits = (
           training_state_cache.lookup())
@@ -497,9 +744,10 @@ def _bt_model_fn(
       batch_size = array_ops.shape(labels)[0]
       cached_tree_ids, cached_node_ids, cached_logits = (
           array_ops.zeros([batch_size], dtype=dtypes.int32),
-          array_ops.zeros([batch_size], dtype=dtypes.int32),
+          _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
          array_ops.zeros(
              [batch_size, head.logits_dimension], dtype=dtypes.float32))
+
    with ops.control_dependencies([ensemble_reload]):
      (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
       last_layer_nodes_range) = local_tree_ensemble.get_states()
@@ -513,13 +761,20 @@ def _bt_model_fn(
          cached_node_ids=cached_node_ids,
          bucketized_features=input_feature_list,
          logits_dimension=head.logits_dimension)
+
    logits = cached_logits + partial_logits

    # Create training graph.
    def _train_op_fn(loss):
      """Run one training iteration."""
      if training_state_cache:
-        train_op.append(training_state_cache.insert(tree_ids, node_ids, logits))
+        # Don't cache logits while bias centering is still in progress.
+        train_op.append(
+            control_flow_ops.cond(
+                center_bias_var, control_flow_ops.no_op,
+                lambda: training_state_cache.insert(tree_ids, node_ids, logits)))
+
      if closed_form_grad_and_hess_fn:
        gradients, hessians = closed_form_grad_and_hess_fn(logits, labels)
      else:
@@ -527,6 +782,11 @@ def _bt_model_fn(
        hessians = gradients_impl.gradients(
            gradients, logits, name='Hessians')[0]

+      # TODO(youngheek): perhaps storage could be optimized by storing stats
+      # with the dimension max_splits_per_layer, instead of max_splits (for the
+      # entire tree).
+      max_splits = _get_max_splits(tree_hparams)
+
      stats_summaries_list = []
      for i, feature_ids in enumerate(feature_ids_list):
        num_buckets = bucket_size_list[i]
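The `stats_summaries_list` built here (via a per-feature summary op whose call site is elided from this hunk) is what the accumulator shape `[len(feature_ids), max_splits, num_buckets, 2]` has to match: per feature, a grid of (node, bucket) cells, each holding a gradient sum and a hessian sum. A NumPy sketch of that layout, with invented toy values:

    import numpy as np

    max_splits, num_buckets = 3, 4          # toy sizes, not real hparams
    node_ids = np.array([0, 0, 1])          # leaf each example currently sits in
    buckets = np.array([2, 2, 0])           # bucketized feature value per example
    grads = np.array([0.5, -1.0, 2.0])
    hess = np.array([1.0, 1.0, 1.0])

    summary = np.zeros([max_splits, num_buckets, 2])  # one feature's summary
    for n, b, g, h in zip(node_ids, buckets, grads, hess):
        summary[n, b] += [g, h]             # scatter-add per (node, bucket)

    print(summary[0, 2])  # [-0.5, 2.0]: sums for node 0, bucket 2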
@@ -543,103 +803,28 @@ def _bt_model_fn(
         ]
         stats_summaries_list.append(summaries)
 
-      accumulators = []
-
-      def grow_tree_from_stats_summaries(stats_summaries_list,
-                                         feature_ids_list):
-        """Updates ensemble based on the best gains from stats summaries."""
-        node_ids_per_feature = []
-        gains_list = []
-        thresholds_list = []
-        left_node_contribs_list = []
-        right_node_contribs_list = []
-        all_feature_ids = []
-
-        assert len(stats_summaries_list) == len(feature_ids_list)
-
-        for i, feature_ids in enumerate(feature_ids_list):
-          (numeric_node_ids_per_feature, numeric_gains_list,
-           numeric_thresholds_list, numeric_left_node_contribs_list,
-           numeric_right_node_contribs_list) = (
-               boosted_trees_ops.calculate_best_gains_per_feature(
-                   node_id_range=last_layer_nodes_range,
-                   stats_summary_list=stats_summaries_list[i],
-                   l1=tree_hparams.l1,
-                   l2=tree_hparams.l2,
-                   tree_complexity=tree_hparams.tree_complexity,
-                   min_node_weight=tree_hparams.min_node_weight,
-                   max_splits=max_splits))
-
-          all_feature_ids += feature_ids
-          node_ids_per_feature += numeric_node_ids_per_feature
-          gains_list += numeric_gains_list
-          thresholds_list += numeric_thresholds_list
-          left_node_contribs_list += numeric_left_node_contribs_list
-          right_node_contribs_list += numeric_right_node_contribs_list
-
-        grow_op = boosted_trees_ops.update_ensemble(
-            # Confirm if local_tree_ensemble or tree_ensemble should be used.
-            tree_ensemble.resource_handle,
-            feature_ids=all_feature_ids,
-            node_ids=node_ids_per_feature,
-            gains=gains_list,
-            thresholds=thresholds_list,
-            left_node_contribs=left_node_contribs_list,
-            right_node_contribs=right_node_contribs_list,
-            learning_rate=tree_hparams.learning_rate,
-            max_depth=tree_hparams.max_depth,
-            pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
-        return grow_op
-
       if train_in_memory and is_single_machine:
-        train_op.append(distribute_lib.increment_var(global_step))
-        train_op.append(
-            grow_tree_from_stats_summaries(stats_summaries_list,
-                                           feature_ids_list))
+        grower = _InMemoryEnsembleGrower(tree_ensemble, tree_hparams)
       else:
-        dependencies = []
-
-        for i, feature_ids in enumerate(feature_ids_list):
-          stats_summaries = stats_summaries_list[i]
-          accumulator = data_flow_ops.ConditionalAccumulator(
-              dtype=dtypes.float32,
-              # The stats consist of grads and hessians (the last dimension).
-              shape=[len(feature_ids), max_splits, bucket_size_list[i], 2],
-              shared_name='numeric_stats_summary_accumulator_' + str(i))
-          accumulators.append(accumulator)
-
-          apply_grad = accumulator.apply_grad(
-              array_ops.stack(stats_summaries, axis=0), stamp_token)
-          dependencies.append(apply_grad)
-
-        def grow_tree_from_accumulated_summaries_fn():
-          """Updates the tree with the best layer from accumulated summaries."""
-          # Take out the accumulated summaries from the accumulator and grow.
-          stats_summaries_list = []
-
-          stats_summaries_list = [
-              array_ops.unstack(accumulator.take_grad(1), axis=0)
-              for accumulator in accumulators
-          ]
-
-          grow_op = grow_tree_from_stats_summaries(stats_summaries_list,
-                                                   feature_ids_list)
-          return grow_op
-
-        with ops.control_dependencies(dependencies):
-          train_op.append(distribute_lib.increment_var(global_step))
-          if config.is_chief:
-            min_accumulated = math_ops.reduce_min(
-                array_ops.stack(
-                    [acc.num_accumulated() for acc in accumulators]))
-
-            train_op.append(
-                control_flow_ops.cond(
-                    math_ops.greater_equal(min_accumulated,
-                                           n_batches_per_layer),
-                    grow_tree_from_accumulated_summaries_fn,
-                    control_flow_ops.no_op,
-                    name='wait_until_n_batches_accumulated'))
+        grower = _AccumulatorEnsembleGrower(tree_ensemble, tree_hparams,
+                                            stamp_token, n_batches_per_layer,
+                                            bucket_size_list, config.is_chief)
+
+      update_model = control_flow_ops.cond(
+          center_bias_var,
+          functools.partial(
+              grower.center_bias,
+              center_bias_var,
+              gradients,
+              hessians,
+          ),
+          functools.partial(grower.grow_tree, stats_summaries_list,
+                            feature_ids_list, last_layer_nodes_range))
+      train_op.append(update_model)
+
+      with ops.control_dependencies([update_model]):
+        increment_global = distribute_lib.increment_var(global_step)
+        train_op.append(increment_global)
 
       return control_flow_ops.group(train_op, name='train_op')
@@ -739,7 +924,8 @@ class BoostedTreesClassifier(estimator.Estimator):
                l2_regularization=0.,
                tree_complexity=0.,
                min_node_weight=0.,
-               config=None):
+               config=None,
+               center_bias=False):
     """Initializes a `BoostedTreesClassifier` instance.
 
     Example:
@@ -807,6 +993,13 @@ class BoostedTreesClassifier(estimator.Estimator):
         split to be considered. The value will be compared with
         sum(leaf_hessian)/(batch_size * n_batches_per_layer).
       config: `RunConfig` object to configure the runtime settings.
+      center_bias: Whether bias centering needs to occur. Bias centering refers
+        to the first node in the very first tree returning the prediction that
+        is aligned with the original labels distribution. For example, for
+        regression problems, the first node will return the mean of the labels.
+        For binary classification problems, it will return a logit for a prior
+        probability of label 1.
+
 
     Raises:
       ValueError: when wrong arguments are given or unsupported functionalities
@@ -821,7 +1014,7 @@ class BoostedTreesClassifier(estimator.Estimator):
     # HParams for the model.
     tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
                                 l1_regularization, l2_regularization,
-                                tree_complexity, min_node_weight)
+                                tree_complexity, min_node_weight, center_bias)
 
     def _model_fn(features, labels, mode, config):
       return _bt_model_fn(  # pylint: disable=protected-access
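For context on how the new flag surfaces to users, a minimal training sketch follows; the feature name, bucket boundaries, and `input_fn` are invented for illustration (boosted-trees estimators accept bucketized/indicator columns, not raw numeric ones):

    import numpy as np
    import tensorflow as tf

    def input_fn():
        # Toy dataset; a real input_fn would read from files.
        features = {'age': np.array([25., 40., 60., 30.], dtype=np.float32)}
        labels = np.array([0, 1, 1, 0], dtype=np.int32)
        return tf.data.Dataset.from_tensors((features, labels))

    bucketized_age = tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column('age'), boundaries=[30, 50])

    classifier = tf.estimator.BoostedTreesClassifier(
        feature_columns=[bucketized_age],
        n_batches_per_layer=1,   # the whole toy dataset is one batch
        n_trees=10,
        max_depth=3,
        center_bias=True)        # the new flag introduced by this change

    classifier.train(input_fn=input_fn, max_steps=20)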
@@ -864,7 +1057,8 @@ class BoostedTreesRegressor(estimator.Estimator):
                l2_regularization=0.,
                tree_complexity=0.,
                min_node_weight=0.,
-               config=None):
+               config=None,
+               center_bias=False):
     """Initializes a `BoostedTreesRegressor` instance.
 
     Example:
@@ -925,6 +1119,12 @@ class BoostedTreesRegressor(estimator.Estimator):
         split to be considered. The value will be compared with
         sum(leaf_hessian)/(batch_size * n_batches_per_layer).
       config: `RunConfig` object to configure the runtime settings.
+      center_bias: Whether bias centering needs to occur. Bias centering refers
+        to the first node in the very first tree returning the prediction that
+        is aligned with the original labels distribution. For example, for
+        regression problems, the first node will return the mean of the labels.
+        For binary classification problems, it will return a logit for a prior
+        probability of label 1.
 
     Raises:
       ValueError: when wrong arguments are given or unsupported functionalities
@@ -938,7 +1138,7 @@ class BoostedTreesRegressor(estimator.Estimator):
     # HParams for the model.
     tree_hparams = _TreeHParams(n_trees, max_depth, learning_rate,
                                 l1_regularization, l2_regularization,
-                                tree_complexity, min_node_weight)
+                                tree_complexity, min_node_weight, center_bias)
 
     def _model_fn(features, labels, mode, config):
       return _bt_model_fn(  # pylint: disable=protected-access
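A worked instance of the binary-classification case described in the center_bias docstring: if 90% of the training labels are 1, the centered root should emit the prior log-odds, so later trees only need to learn deviations from that prior. In plain Python:

    import math

    p = 0.9                             # prior probability of label 1
    bias_logit = math.log(p / (1 - p))  # ~2.197, the centered root's output
    # The sigmoid of that logit recovers the prior probability.
    assert abs(1 / (1 + math.exp(-bias_logit)) - p) < 1e-12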