author    | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 10:41:13 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 10:41:13 -0700
commit    | 03cf21a5202b8515d77fdaee3184fd20da2a201c (patch)
tree      | 3ebbdbfafcd69f0ccf9d3beb619abe8cb278c92f /tensorflow/python/estimator
parent    | 410905d8e8af12e928031aa026683e43b665c8ae (diff)
parent    | 046c74c8e7c68aaa726977dd6e8a2523f854f9cc (diff)
Merge pull request #21509 from facaiy:ENH/feature_importances_for_boosted_tree
PiperOrigin-RevId: 214462540
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py      | 127
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 535
2 files changed, 662 insertions, 0 deletions
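
Before the diff itself, a brief usage sketch of the API this merge introduces. This is a minimal, hypothetical example: the input pipeline, feature column, and hyperparameters below are invented for illustration and are not part of the commit. It assumes a TF 1.x environment where the canned boosted-trees estimators are exported as tf.estimator.BoostedTreesClassifier.

    import numpy as np
    import tensorflow as tf

    def input_fn():
      # Toy data, invented for this sketch.
      features = {'age': np.array([25., 40., 62., 18.], dtype=np.float32)}
      labels = np.array([0, 1, 1, 0], dtype=np.int32)
      return tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

    # Only bucketized and indicator columns are supported by the new method.
    age_buckets = tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column('age'), boundaries=[30., 50.])

    est = tf.estimator.BoostedTreesClassifier(
        feature_columns=[age_buckets],
        n_batches_per_layer=1,
        n_trees=1,
        max_depth=3)
    est.train(input_fn, steps=10)

    # New in this merge: gain-based importances, most important first.
    names, importances = est.experimental_feature_importances(normalize=True)
    print(list(zip(names, importances)))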
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 756d32d03f..0278990cfc 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -21,6 +21,9 @@ import abc
 import collections
 import functools
 
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator import model_fn as model_fn_lib
 from tensorflow.python.estimator.canned import boosted_trees_utils
@@ -40,6 +43,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops.array_ops import identity as tf_identity
 from tensorflow.python.ops.losses import losses
 from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_utils
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.training import training_util
 from tensorflow.python.util.tf_export import estimator_export
@@ -193,6 +197,43 @@ def _calculate_num_features(sorted_feature_columns):
   return num_features
 
 
+def _generate_feature_name_mapping(sorted_feature_columns):
+  """Returns a list of feature names indexed by feature id.
+
+  Args:
+    sorted_feature_columns: a list/set of tf.feature_column sorted by name.
+
+  Returns:
+    feature_name_mapping: a list of feature names indexed by the feature ids.
+
+  Raises:
+    ValueError: when unsupported features/columns are tried.
+  """
+  names = []
+  for column in sorted_feature_columns:
+    if isinstance(column, feature_column_lib._IndicatorColumn):  # pylint:disable=protected-access
+      categorical_column = column.categorical_column
+      if isinstance(categorical_column,
+                    feature_column_lib._VocabularyListCategoricalColumn):  # pylint:disable=protected-access
+        for value in categorical_column.vocabulary_list:
+          names.append('{}:{}'.format(column.name, value))
+      elif isinstance(categorical_column,
+                      feature_column_lib._BucketizedColumn):  # pylint:disable=protected-access
+        boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf]
+        for pair in zip(boundaries[:-1], boundaries[1:]):
+          names.append('{}:{}'.format(column.name, pair))
+      else:
+        for num in range(categorical_column._num_buckets):  # pylint:disable=protected-access
+          names.append('{}:{}'.format(column.name, num))
+    elif isinstance(column, feature_column_lib._BucketizedColumn):
+      names.append(column.name)
+    else:
+      raise ValueError(
+          'For now, only bucketized_column and indicator_column are supported '
+          'but got: {}'.format(column))
+  return names
+
+
 def _cache_transformed_features(features, sorted_feature_columns, batch_size):
   """Transform features and cache, then returns (cached_features, cache_op)."""
   num_features = _calculate_num_features(sorted_feature_columns)
@@ -966,6 +1007,60 @@ def _create_regression_head(label_dimension, weight_column=None):
   # pylint: enable=protected-access
 
 
+def _compute_feature_importances_per_tree(tree, num_features):
+  """Computes the importance of each feature in the tree."""
+  importances = np.zeros(num_features)
+
+  for node in tree.nodes:
+    node_type = node.WhichOneof('node')
+    if node_type == 'bucketized_split':
+      feature_id = node.bucketized_split.feature_id
+      importances[feature_id] += node.metadata.gain
+    elif node_type == 'leaf':
+      assert node.metadata.gain == 0
+    else:
+      raise ValueError('Unexpected split type %s' % node_type)
+
+  return importances
+
+
+def _compute_feature_importances(tree_ensemble, num_features, normalize):
+  """Computes gain-based feature importances.
+
+  The higher the value, the more important the feature.
+
+  Args:
+    tree_ensemble: a trained tree ensemble, instance of proto
+      boosted_trees.TreeEnsemble.
+    num_features: The total number of feature ids.
+    normalize: If True, normalize the feature importances.
+
+  Returns:
+    sorted_feature_idx: A list of feature ids, sorted by feature importance.
+    feature_importances: A list of corresponding feature importances.
+
+  Raises:
+    AssertionError: When normalize = True, if feature importances
+      contain negative values, or if normalization is not possible
+      (e.g. ensemble is empty or trees contain only a root node).
+  """
+  tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
+                      for tree in tree_ensemble.trees]
+  tree_importances = np.array(tree_importances)
+  tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
+  feature_importances = np.sum(tree_importances * tree_weights, axis=0)
+  if normalize:
+    assert np.all(feature_importances >= 0), ('feature_importances '
+                                              'must be non-negative.')
+    normalizer = np.sum(feature_importances)
+    assert normalizer > 0, 'Trees are all empty or contain only a root node.'
+    feature_importances /= normalizer
+
+  sorted_feature_idx = np.argsort(feature_importances)[::-1]
+  return sorted_feature_idx, feature_importances[sorted_feature_idx]
+
+
 def _bt_explanations_fn(features,
                         head,
                         sorted_feature_columns,
@@ -1053,9 +1148,41 @@ class _BoostedTreesBase(estimator.Estimator):
         feature_columns, key=lambda tc: tc.name)
     self._head = head
     self._n_features = _calculate_num_features(self._sorted_feature_columns)
+    self._names_for_feature_id = np.array(
+        _generate_feature_name_mapping(self._sorted_feature_columns))
     self._center_bias = center_bias
     self._is_classification = is_classification
 
+  def experimental_feature_importances(self, normalize=False):
+    """Computes gain-based feature importances.
+
+    The higher the value, the more important the corresponding feature.
+
+    Args:
+      normalize: If True, normalize the feature importances.
+
+    Returns:
+      sorted_feature_names: 1-D array of feature names, sorted by
+        feature importance.
+      feature_importances: 1-D array of the corresponding feature importances.
+
+    Raises:
+      ValueError: When attempting to normalize on an empty ensemble
+        or an ensemble of trees which have no splits, or when attempting
+        to normalize and feature importances have negative values.
+    """
+    reader = checkpoint_utils.load_checkpoint(self._model_dir)
+    serialized = reader.get_tensor('boosted_trees:0_serialized')
+    if not serialized:
+      raise ValueError('Found empty serialized string for TreeEnsemble. '
+                       'You should only call this method after training.')
+    ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+    ensemble_proto.ParseFromString(serialized)
+
+    sorted_feature_id, importances = _compute_feature_importances(
+        ensemble_proto, self._n_features, normalize)
+    return self._names_for_feature_id[sorted_feature_id], importances
+
   def experimental_predict_with_explanations(self,
                                              input_fn,
                                              predict_keys=None,
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index d4cb3e27d0..23687a738b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -17,9 +17,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
+
+from google.protobuf import text_format
 import numpy as np
 
 from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import model_fn
 from tensorflow.python.estimator import run_config
@@ -31,10 +35,12 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import boosted_trees_ops
 from tensorflow.python.ops import resources
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import googletest
 from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import session_run_hook
 
 NUM_FEATURES = 3
@@ -564,6 +570,535 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
     self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
     self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
 
+  def testFeatureImportancesWithTrainedEnsemble(self):
+    input_fn = _make_train_input_fn(is_classification=True)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=2,
+        max_depth=5)
+
+    # It will stop after 5 steps because of the max depth and num trees.
+    num_steps = 100
+    # Train for a few steps, and validate final checkpoint.
+    est.train(input_fn, steps=num_steps)
+
+    feature_names_expected = ['f_0_bucketized',
+                              'f_2_bucketized',
+                              'f_1_bucketized']
+
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=False)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.833933, 0.606342, 0.0], importances)
+
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=True)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.579010, 0.420990, 0.0], importances)
+
+  def testFeatureImportancesOnEmptyEnsemble(self):
+    input_fn = _make_train_input_fn(is_classification=True)
+
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+
+    class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+      def before_run(self, run_context):
+        raise StopIteration('to bail out.')
+
+    # The step-0 checkpoint will have only an empty ensemble.
+    est.train(input_fn,
+              steps=100,  # must stop at 0 anyway.
+              hooks=[BailOutWithoutTraining()])
+
+    with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+      est.experimental_feature_importances(normalize=False)
+
+    with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+      est.experimental_feature_importances(normalize=True)
+
+  def _create_fake_checkpoint_with_tree_ensemble_proto(self, est,
+                                                       tree_ensemble_text):
+    with ops.Graph().as_default():
+      with ops.name_scope('boosted_trees') as name:
+        tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+        tree_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+        text_format.Merge(tree_ensemble_text, tree_ensemble_proto)
+        stamp_token, _ = tree_ensemble.serialize()
+        restore_op = tree_ensemble.deserialize(
+            stamp_token, tree_ensemble_proto.SerializeToString())
+
+      with session.Session() as sess:
+        resources.initialize_resources(resources.shared_resources()).run()
+        restore_op.run()
+        saver = saver_lib.Saver()
+        save_path = os.path.join(est.model_dir, 'model.ckpt')
+        saver.save(sess, save_path)
+
+  def testFeatureImportancesOnNonEmptyEnsemble(self):
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=2,
+        max_depth=5)
+
+    tree_ensemble_text = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 2.0
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 3.0
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              left_id: 5
+              right_id: 6
+            }
+            metadata {
+              gain: 2.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              left_id: 7
+              right_id: 8
+            }
+            metadata {
+              gain: 1.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 3.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 1.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 3.34
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 1.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 3.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+    """
+    self._create_fake_checkpoint_with_tree_ensemble_proto(
+        est, tree_ensemble_text)
+
+    feature_names_expected = ['f_0_bucketized',
+                              'f_2_bucketized',
+                              'f_1_bucketized']
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=False)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    # Gain sum for each feature:
+    # = 1.0 * [3 + 1, 2, 2] + 1.0 * [1, 1, 0]
+    self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=True)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+  def testFeatureImportancesWithTreeWeights(self):
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=3,
+        max_depth=5)
+
+    tree_ensemble_text = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 12.5
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 5.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 5.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 0.4
+        tree_weights: 0.6
+        tree_weights: 1.0
+    """
+    self._create_fake_checkpoint_with_tree_ensemble_proto(
+        est, tree_ensemble_text)
+
+    feature_names_expected = ['f_0_bucketized',
+                              'f_2_bucketized',
+                              'f_1_bucketized']
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=False)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    # Gain sum for each feature:
+    # = 0.4 * [12.5, 0, 5] + 0.6 * [0, 5, 0] + 1.0 * [0, 0, 0]
+    self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=True)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+  def testFeatureImportancesWithAllEmptyTree(self):
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=2,
+        max_depth=5)
+
+    tree_ensemble_text = """
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        trees {
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+    """
+    self._create_fake_checkpoint_with_tree_ensemble_proto(
+        est, tree_ensemble_text)
+
+    # Reverse order because feature importances
+    # are sorted by np.argsort(f)[::-1].
+    feature_names_expected = ['f_2_bucketized',
+                              'f_1_bucketized',
+                              'f_0_bucketized']
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=False)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.0, 0.0, 0.0], importances)
+
+    with self.assertRaisesRegexp(AssertionError,
+                                 'all empty or contain only a root node'):
+      est.experimental_feature_importances(normalize=True)
+
+  def testNegativeFeatureImportances(self):
+    est = boosted_trees.BoostedTreesClassifier(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        n_trees=1,
+        max_depth=5)
+
+    # In order to generate negative feature importances,
+    # we assign an invalid value -1 to tree_weights here.
+    tree_ensemble_text = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 1
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 5.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+        }
+        tree_weights: -1.0
+    """
+    self._create_fake_checkpoint_with_tree_ensemble_proto(
+        est, tree_ensemble_text)
+
+    # Github #21509 (nataliaponomareva):
+    # The gains stored in the splits can be negative
+    # if people are using complexity regularization.
+    feature_names_expected = ['f_2_bucketized',
+                              'f_0_bucketized',
+                              'f_1_bucketized']
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=False)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.0, 0.0, -5.0], importances)
+
+    with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+      est.experimental_feature_importances(normalize=True)
+
+  def testFeatureImportancesNamesForCategoricalColumn(self):
+    categorical = feature_column.categorical_column_with_vocabulary_list(
+        key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+    feature_indicator = feature_column.indicator_column(categorical)
+    bucketized_col = feature_column.bucketized_column(
+        feature_column.numeric_column(
+            'continuous', dtype=dtypes.float32),
+        BUCKET_BOUNDARIES)
+    bucketized_indicator = feature_column.indicator_column(bucketized_col)
+
+    est = boosted_trees.BoostedTreesRegressor(
+        feature_columns=[feature_indicator,
+                         bucketized_col,
+                         bucketized_indicator],
+        n_batches_per_layer=1,
+        n_trees=2,
+        learning_rate=1.0,
+        max_depth=1)
+
+    tree_ensemble_text = """
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 2
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 5.0
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 4
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 2.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -0.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 1.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 0.0
+            }
+          }
+        }
+        trees {
+          nodes {
+            bucketized_split {
+              feature_id: 0
+              left_id: 1
+              right_id: 2
+            }
+            metadata {
+              gain: 1.0
+            }
+          }
+          nodes {
+            bucketized_split {
+              feature_id: 5
+              left_id: 3
+              right_id: 4
+            }
+            metadata {
+              gain: 2.0
+            }
+          }
+          nodes {
+            leaf {
+              scalar: -2.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 3.34
+            }
+          }
+          nodes {
+            leaf {
+              scalar: 4.34
+            }
+          }
+        }
+        tree_weights: 1.0
+        tree_weights: 1.0
+    """
+    self._create_fake_checkpoint_with_tree_ensemble_proto(
+        est, tree_ensemble_text)
+
+    feature_names_expected = ['categorical_indicator:ok',
+                              'continuous_bucketized_indicator:(-2.0, 0.5)',
+                              'continuous_bucketized_indicator:(-inf, -2.0)',
+                              'categorical_indicator:bad',
+                              # Reverse order because feature importances
+                              # are sorted by np.argsort(f)[::-1].
+                              'continuous_bucketized_indicator:(12.0, inf)',
+                              'continuous_bucketized_indicator:(0.5, 12.0)',
+                              'continuous_bucketized',
+                              'categorical_indicator:good']
+
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=False)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    # Gain sum for each feature:
+    # = 1.0 * [5, 0, 2, 0, 0, 0, 0, 0] + 1.0 * [0, 2, 0, 1, 0, 0, 0, 0]
+    self.assertAllClose([5.0, 2.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0], importances)
+
+    feature_names, importances = est.experimental_feature_importances(
+        normalize=True)
+    self.assertAllEqual(feature_names_expected, feature_names)
+    self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances)
+
+  def testFeatureImportancesNamesForUnsupportedColumn(self):
+    numeric_col = feature_column.numeric_column(
+        'continuous', dtype=dtypes.float32)
+
+    with self.assertRaisesRegexp(ValueError,
+                                 'only bucketized_column and indicator_column'):
+      _ = boosted_trees.BoostedTreesRegressor(
+          feature_columns=[numeric_col],
+          n_batches_per_layer=1,
+          n_trees=2,
+          learning_rate=1.0,
+          max_depth=1)
+
   def testTreeComplexityIsSetCorrectly(self):
     input_fn = _make_train_input_fn(is_classification=True)
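
For reference, the aggregation that _compute_feature_importances performs can be reproduced with plain NumPy. The sketch below uses the per-tree gain vectors from testFeatureImportancesOnNonEmptyEnsemble above (feature ids 0..2); it illustrates the math and is not code from the commit:

    import numpy as np

    # One row per tree: summed split gains per feature id, as produced by
    # _compute_feature_importances_per_tree for the two trees in the test.
    tree_importances = np.array([[4.0, 2.0, 2.0],
                                 [1.0, 0.0, 1.0]])
    tree_weights = np.array([1.0, 1.0]).reshape(-1, 1)

    # Weighted sum across trees, then optional normalization.
    importances = np.sum(tree_importances * tree_weights, axis=0)  # [5. 2. 3.]
    normalized = importances / np.sum(importances)                 # [.5 .2 .3]

    sorted_idx = np.argsort(importances)[::-1]   # most important first
    print(sorted_idx, normalized[sorted_idx])    # [0 2 1] [0.5 0.3 0.2]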