diff options
Diffstat (limited to 'tensorflow/python/estimator/canned/boosted_trees.py')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 383 |
1 files changed, 358 insertions, 25 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 19f18015e4..0278990cfc 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -21,8 +21,12 @@ import abc import collections import functools +import numpy as np + +from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2 from tensorflow.python.estimator import estimator -from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import boosted_trees_utils from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.feature_column import feature_column as feature_column_lib from tensorflow.python.framework import dtypes @@ -36,8 +40,10 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.array_ops import identity as tf_identity from tensorflow.python.ops.losses import losses from tensorflow.python.summary import summary +from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.util.tf_export import estimator_export @@ -191,14 +197,50 @@ def _calculate_num_features(sorted_feature_columns): return num_features +def _generate_feature_name_mapping(sorted_feature_columns): + """Return a list of feature name for feature ids. + + Args: + sorted_feature_columns: a list/set of tf.feature_column sorted by name. + + Returns: + feature_name_mapping: a list of feature names indexed by the feature ids. + + Raises: + ValueError: when unsupported features/columns are tried. + """ + names = [] + for column in sorted_feature_columns: + if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access + categorical_column = column.categorical_column + if isinstance(categorical_column, + feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access + for value in categorical_column.vocabulary_list: + names.append('{}:{}'.format(column.name, value)) + elif isinstance(categorical_column, + feature_column_lib._BucketizedColumn): # pylint:disable=protected-access + boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf] + for pair in zip(boundaries[:-1], boundaries[1:]): + names.append('{}:{}'.format(column.name, pair)) + else: + for num in range(categorical_column._num_buckets): # pylint:disable=protected-access + names.append('{}:{}'.format(column.name, num)) + elif isinstance(column, feature_column_lib._BucketizedColumn): + names.append(column.name) + else: + raise ValueError( + 'For now, only bucketized_column and indicator_column is supported ' + 'but got: {}'.format(column)) + return names + + def _cache_transformed_features(features, sorted_feature_columns, batch_size): """Transform features and cache, then returns (cached_features, cache_op).""" num_features = _calculate_num_features(sorted_feature_columns) cached_features = [ _local_variable( array_ops.zeros([batch_size], dtype=dtypes.int32), - name='cached_feature_{}'.format(i)) - for i in range(num_features) + name='cached_feature_{}'.format(i)) for i in range(num_features) ] are_features_cached = _local_variable(False, name='are_features_cached') @@ -228,8 +270,7 @@ def _cache_transformed_features(features, sorted_feature_columns, batch_size): return cached, cache_flip_op input_feature_list, cache_flip_op = control_flow_ops.cond( - are_features_cached, - lambda: (cached_features, control_flow_ops.no_op()), + are_features_cached, lambda: (cached_features, control_flow_ops.no_op()), cache_features_and_return) return input_feature_list, cache_flip_op @@ -263,8 +304,8 @@ class _CacheTrainingStatesUsingHashTable(object): elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype): empty_key = '' else: - raise ValueError('Unsupported example_id_feature dtype %s.' % - example_ids.dtype) + raise ValueError( + 'Unsupported example_id_feature dtype %s.' % example_ids.dtype) # Cache holds latest <tree_id, node_id, logits> for each example. # tree_id and node_id are both int32 but logits is a float32. # To reduce the overhead, we store all of them together as float32 and @@ -273,8 +314,8 @@ class _CacheTrainingStatesUsingHashTable(object): empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3]) self._example_ids = ops.convert_to_tensor(example_ids) if self._example_ids.shape.ndims not in (None, 1): - raise ValueError('example_id should have rank 1, but got %s' % - self._example_ids) + raise ValueError( + 'example_id should have rank 1, but got %s' % self._example_ids) self._logits_dimension = logits_dimension def lookup(self): @@ -334,7 +375,7 @@ class _CacheTrainingStatesUsingVariables(object): array_ops.zeros([batch_size], dtype=dtypes.int32), name='tree_ids_cache') self._node_ids = _local_variable( - _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32), + _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32), name='node_ids_cache') self._logits = _local_variable( array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32), @@ -422,9 +463,13 @@ class _EnsembleGrower(object): self._pruning_mode_parsed = boosted_trees_ops.PruningMode.from_str( tree_hparams.pruning_mode) - if (self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING - and tree_hparams.tree_complexity <= 0): - raise ValueError('For pruning, tree_complexity must be positive.') + if tree_hparams.tree_complexity > 0: + if self._pruning_mode_parsed == boosted_trees_ops.PruningMode.NO_PRUNING: + raise ValueError( + 'Tree complexity have no effect unless pruning mode is chosen.') + else: + if self._pruning_mode_parsed != boosted_trees_ops.PruningMode.NO_PRUNING: + raise ValueError('For pruning, tree_complexity must be positive.') # pylint: enable=protected-access @abc.abstractmethod @@ -719,7 +764,7 @@ def _bt_model_fn( tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) # Create logits. - if mode != model_fn.ModeKeys.TRAIN: + if mode != model_fn_lib.ModeKeys.TRAIN: input_feature_list = _get_transformed_features(features, sorted_feature_columns) logits = boosted_trees_ops.predict( @@ -886,6 +931,7 @@ def _bt_model_fn( labels=labels, train_op_fn=_train_op_fn, logits=logits) + # Add an early stop hook. estimator_spec = estimator_spec._replace( training_hooks=estimator_spec.training_hooks + @@ -927,8 +973,8 @@ def _create_classification_head_and_closed_form(n_classes, weight_column, label_vocabulary): """Creates a head for classifier and the closed form gradients/hessians.""" head = _create_classification_head(n_classes, weight_column, label_vocabulary) - if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None - and label_vocabulary is None): + if (n_classes == 2 and head.logits_dimension == 1 and + weight_column is None and label_vocabulary is None): # Use the closed-form gradients/hessians for 2 class. def _grad_and_hess_for_logloss(logits, labels): """A closed form gradient and hessian for logistic loss.""" @@ -961,8 +1007,282 @@ def _create_regression_head(label_dimension, weight_column=None): # pylint: enable=protected-access +def _compute_feature_importances_per_tree(tree, num_features): + """Computes the importance of each feature in the tree.""" + importances = np.zeros(num_features) + + for node in tree.nodes: + node_type = node.WhichOneof('node') + if node_type == 'bucketized_split': + feature_id = node.bucketized_split.feature_id + importances[feature_id] += node.metadata.gain + elif node_type == 'leaf': + assert node.metadata.gain == 0 + else: + raise ValueError('Unexpected split type %s', node_type) + + return importances + + +def _compute_feature_importances(tree_ensemble, num_features, normalize): + """Computes gain-based feature importances. + + The higher the value, the more important the feature. + + Args: + tree_ensemble: a trained tree ensemble, instance of proto + boosted_trees.TreeEnsemble. + num_features: The total number of feature ids. + normalize: If True, normalize the feature importances. + + Returns: + sorted_feature_idx: A list of feature_id which is sorted + by its feature importance. + feature_importances: A list of corresponding feature importances. + + Raises: + AssertionError: When normalize = True, if feature importances + contain negative value, or if normalization is not possible + (e.g. ensemble is empty or trees contain only a root node). + """ + tree_importances = [_compute_feature_importances_per_tree(tree, num_features) + for tree in tree_ensemble.trees] + tree_importances = np.array(tree_importances) + tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1) + feature_importances = np.sum(tree_importances * tree_weights, axis=0) + if normalize: + assert np.all(feature_importances >= 0), ('feature_importances ' + 'must be non-negative.') + normalizer = np.sum(feature_importances) + assert normalizer > 0, 'Trees are all empty or contain only a root node.' + feature_importances /= normalizer + + sorted_feature_idx = np.argsort(feature_importances)[::-1] + return sorted_feature_idx, feature_importances[sorted_feature_idx] + + +def _bt_explanations_fn(features, + head, + sorted_feature_columns, + name='boosted_trees'): + """Gradient Boosted Trees predict with explanations model_fn. + + Args: + features: dict of `Tensor`. + head: A `head_lib._Head` instance. + sorted_feature_columns: Sorted iterable of `feature_column._FeatureColumn` + model inputs. + name: Name used for the model. + + Returns: + An `EstimatorSpec` instance. + + Raises: + ValueError: mode or params are invalid, or features has the wrong type. + """ + mode = model_fn_lib.ModeKeys.PREDICT + with ops.name_scope(name) as name: + # Create Ensemble resources. + tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name) + + input_feature_list = _get_transformed_features(features, + sorted_feature_columns) + + logits = boosted_trees_ops.predict( + # For non-TRAIN mode, ensemble doesn't change after initialization, + # so no local copy is needed; using tree_ensemble directly. + tree_ensemble_handle=tree_ensemble.resource_handle, + bucketized_features=input_feature_list, + logits_dimension=head.logits_dimension) + + estimator_spec = head.create_estimator_spec( + features=features, + mode=mode, + labels=None, + train_op_fn=control_flow_ops.no_op, + logits=logits) + + debug_op = boosted_trees_ops.example_debug_outputs( + tree_ensemble.resource_handle, + bucketized_features=input_feature_list, + logits_dimension=head.logits_dimension) + estimator_spec.predictions[boosted_trees_utils._DEBUG_PROTO_KEY] = debug_op # pylint: disable=protected-access + return estimator_spec + + +class _BoostedTreesBase(estimator.Estimator): + """Base class for boosted trees estimators. + + This class is intended to keep tree-specific functions (E.g., methods for + feature importances and directional feature contributions) in one central + place. + + It is not a valid (working) Estimator on its own and should only be used as a + base class. + """ + + def __init__(self, model_fn, model_dir, config, feature_columns, head, + center_bias, is_classification): + """Initializes a `_BoostedTreesBase` instance. + + Args: + model_fn: model_fn: Model function. See base class for more detail. + model_dir: Directory to save model parameters, graph and etc. See base + class for more detail. + config: `estimator.RunConfig` configuration object. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn` + head: A `head_lib._Head` instance. + center_bias: Whether bias centering needs to occur. Bias centering refers + to the first node in the very first tree returning the prediction that + is aligned with the original labels distribution. For example, for + regression problems, the first node will return the mean of the labels. + For binary classification problems, it will return a logit for a prior + probability of label 1. + is_classification: If the estimator is for classification. + """ + super(_BoostedTreesBase, self).__init__( + model_fn=model_fn, model_dir=model_dir, config=config) + self._sorted_feature_columns = sorted( + feature_columns, key=lambda tc: tc.name) + self._head = head + self._n_features = _calculate_num_features(self._sorted_feature_columns) + self._names_for_feature_id = np.array( + _generate_feature_name_mapping(self._sorted_feature_columns)) + self._center_bias = center_bias + self._is_classification = is_classification + + def experimental_feature_importances(self, normalize=False): + """Computes gain-based feature importances. + + The higher the value, the more important the corresponding feature. + + Args: + normalize: If True, normalize the feature importances. + + Returns: + sorted_feature_names: 1-D array of feature name which is sorted + by its feature importance. + feature_importances: 1-D array of the corresponding feature importance. + + Raises: + ValueError: When attempting to normalize on an empty ensemble + or an ensemble of trees which have no splits. Or when attempting + to normalize and feature importances have negative values. + """ + reader = checkpoint_utils.load_checkpoint(self._model_dir) + serialized = reader.get_tensor('boosted_trees:0_serialized') + if not serialized: + raise ValueError('Found empty serialized string for TreeEnsemble.' + 'You should only call this method after training.') + ensemble_proto = boosted_trees_pb2.TreeEnsemble() + ensemble_proto.ParseFromString(serialized) + + sorted_feature_id, importances = _compute_feature_importances( + ensemble_proto, self._n_features, normalize) + return self._names_for_feature_id[sorted_feature_id], importances + + def experimental_predict_with_explanations(self, + input_fn, + predict_keys=None, + hooks=None, + checkpoint_path=None): + """Computes model explainability outputs per example along with predictions. + + Currently supports directional feature contributions (DFCs). For each + instance, DFCs indicate the aggregate contribution of each feature. See + https://arxiv.org/abs/1312.1121 and + http://blog.datadive.net/interpreting-random-forests/ for more details. + Args: + input_fn: A function that provides input data for predicting as + minibatches. See [Premade Estimators]( + https://tensorflow.org/guide/premade_estimators#create_input_functions) + for more information. The function should construct and return one of + the following: * A `tf.data.Dataset` object: Outputs of `Dataset` + object must be a tuple `(features, labels)` with same constraints as + below. * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` + or a dictionary of string feature name to `Tensor` and `labels` is a + `Tensor` or a dictionary of string label name to `Tensor`. Both + `features` and `labels` are consumed by `model_fn`. They should + satisfy the expectation of `model_fn` from inputs. + predict_keys: list of `str`, name of the keys to predict. It is used if + the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If + `predict_keys` is used then rest of the predictions will be filtered + from the dictionary, with the exception of 'bias' and 'dfc', which will + always be in the dictionary. If `None`, returns all keys in prediction + dict, as well as two new keys 'dfc' and 'bias'. + hooks: List of `tf.train.SessionRunHook` subclass instances. Used for + callbacks inside the prediction call. + checkpoint_path: Path of a specific checkpoint to predict. If `None`, the + latest checkpoint in `model_dir` is used. If there are no checkpoints + in `model_dir`, prediction is run with newly initialized `Variables` + instead of ones restored from checkpoint. + + Yields: + Evaluated values of `predictions` tensors. The `predictions` tensors will + contain at least two keys 'dfc' and 'bias' for model explanations. The + `dfc` value corresponds to the contribution of each feature to the overall + prediction for this instance (positive indicating that the feature makes + it more likely to select class 1 and negative less likely). The 'bias' + value will be the same across all the instances, corresponding to the + probability (classification) or prediction (regression) of the training + data distribution. + + Raises: + ValueError: when wrong arguments are given or unsupported functionalities + are requested. + """ + if not self._center_bias: + raise ValueError('center_bias must be enabled during estimator ' + 'instantiation when using ' + 'experimental_predict_with_explanations.') + # pylint: disable=protected-access + if not self._is_classification: + identity_inverse_link_fn = self._head._inverse_link_fn in (None, + tf_identity) + # pylint:enable=protected-access + if not identity_inverse_link_fn: + raise ValueError( + 'For now only identity inverse_link_fn in regression_head is ' + 'supported for experimental_predict_with_explanations.') + + # pylint:disable=unused-argument + def new_model_fn(features, labels, mode): + return _bt_explanations_fn(features, self._head, + self._sorted_feature_columns) + + # pylint:enable=unused-argument + est = estimator.Estimator( + model_fn=new_model_fn, + model_dir=self.model_dir, + config=self.config, + warm_start_from=self._warm_start_settings) + # Make sure bias and dfc will be in prediction dict. + user_supplied_predict_keys = predict_keys is not None + if user_supplied_predict_keys: + predict_keys = set(predict_keys) + predict_keys.add(boosted_trees_utils._DEBUG_PROTO_KEY) + predictions = est.predict( + input_fn, + predict_keys=predict_keys, + hooks=hooks, + checkpoint_path=checkpoint_path, + yield_single_examples=True) + for pred in predictions: + bias, dfcs = boosted_trees_utils._parse_explanations_from_prediction( + pred[boosted_trees_utils._DEBUG_PROTO_KEY], self._n_features, + self._is_classification) + pred['bias'] = bias + pred['dfc'] = dfcs + # Don't need to expose serialized proto to end user. + del pred[boosted_trees_utils._DEBUG_PROTO_KEY] + yield pred + + +# pylint: disable=protected-access @estimator_export('estimator.BoostedTreesClassifier') -class BoostedTreesClassifier(estimator.Estimator): +class BoostedTreesClassifier(_BoostedTreesBase): """A Classifier for Tensorflow Boosted Trees models. @compatibility(eager) @@ -1082,14 +1402,13 @@ class BoostedTreesClassifier(estimator.Estimator): n_classes = 2 head, closed_form = _create_classification_head_and_closed_form( n_classes, weight_column, label_vocabulary=label_vocabulary) - # HParams for the model. tree_hparams = _TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): - return _bt_model_fn( # pylint: disable=protected-access + return _bt_model_fn( features, labels, mode, @@ -1101,11 +1420,17 @@ class BoostedTreesClassifier(estimator.Estimator): closed_form_grad_and_hess_fn=closed_form) super(BoostedTreesClassifier, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_columns=feature_columns, + head=head, + center_bias=center_bias, + is_classification=True) @estimator_export('estimator.BoostedTreesRegressor') -class BoostedTreesRegressor(estimator.Estimator): +class BoostedTreesRegressor(_BoostedTreesBase): """A Regressor for Tensorflow Boosted Trees models. @compatibility(eager) @@ -1223,9 +1548,17 @@ class BoostedTreesRegressor(estimator.Estimator): tree_complexity, min_node_weight, center_bias, pruning_mode) def _model_fn(features, labels, mode, config): - return _bt_model_fn( # pylint: disable=protected-access - features, labels, mode, head, feature_columns, tree_hparams, - n_batches_per_layer, config) + return _bt_model_fn(features, labels, mode, head, feature_columns, + tree_hparams, n_batches_per_layer, config) super(BoostedTreesRegressor, self).__init__( - model_fn=_model_fn, model_dir=model_dir, config=config) + model_fn=_model_fn, + model_dir=model_dir, + config=config, + feature_columns=feature_columns, + head=head, + center_bias=center_bias, + is_classification=False) + + +# pylint: enable=protected-access |