author    TensorFlower Gardener <gardener@tensorflow.org>  2018-09-25 10:41:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-25 10:41:13 -0700
commit 03cf21a5202b8515d77fdaee3184fd20da2a201c (patch)
tree   3ebbdbfafcd69f0ccf9d3beb619abe8cb278c92f /tensorflow/python/estimator
parent 410905d8e8af12e928031aa026683e43b665c8ae (diff)
parent 046c74c8e7c68aaa726977dd6e8a2523f854f9cc (diff)
Merge pull request #21509 from facaiy:ENH/feature_importances_for_boosted_tree
PiperOrigin-RevId: 214462540
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py      | 127
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees_test.py | 535
2 files changed, 662 insertions, 0 deletions
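
A quick orientation for reviewers: the user-facing addition in this merge is the
experimental_feature_importances(normalize=False) method on the boosted-trees
estimators. Below is a minimal usage sketch. The feature names, boundaries,
synthetic data, and hyperparameters are illustrative assumptions, not part of
this change; only the method and its normalize argument come from the diff.

    import numpy as np
    import tensorflow as tf

    fc = tf.feature_column
    feature_columns = [
        fc.bucketized_column(
            fc.numeric_column('f_%d' % i, dtype=tf.float32),
            boundaries=[-1.0, 0.0, 1.0])
        for i in range(3)
    ]

    est = tf.estimator.BoostedTreesClassifier(
        feature_columns=feature_columns,
        n_batches_per_layer=1,
        n_trees=2,
        max_depth=5)

    def input_fn():
      # One in-memory batch, repeated, in the style of the tests in this PR.
      features = {'f_%d' % i: np.random.randn(16).astype(np.float32)
                  for i in range(3)}
      labels = np.random.randint(0, 2, size=16).astype(np.int32)
      return tf.data.Dataset.from_tensors((features, labels)).repeat()

    est.train(input_fn, steps=100)  # stops early once all trees are grown

    # Feature names come back sorted by descending gain-based importance.
    names, importances = est.experimental_feature_importances(normalize=True)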
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 756d32d03f..0278990cfc 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -21,6 +21,9 @@ import abc
import collections
import functools
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import boosted_trees_utils
@@ -40,6 +43,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
+from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.util.tf_export import estimator_export
@@ -193,6 +197,43 @@ def _calculate_num_features(sorted_feature_columns):
return num_features
+def _generate_feature_name_mapping(sorted_feature_columns):
+ """Return a list of feature name for feature ids.
+
+ Args:
+ sorted_feature_columns: a list/set of tf.feature_column sorted by name.
+
+ Returns:
+ feature_name_mapping: a list of feature names indexed by the feature ids.
+
+ Raises:
+    ValueError: when an unsupported feature column is encountered.
+ """
+ names = []
+ for column in sorted_feature_columns:
+ if isinstance(column, feature_column_lib._IndicatorColumn): # pylint:disable=protected-access
+ categorical_column = column.categorical_column
+ if isinstance(categorical_column,
+ feature_column_lib._VocabularyListCategoricalColumn): # pylint:disable=protected-access
+ for value in categorical_column.vocabulary_list:
+ names.append('{}:{}'.format(column.name, value))
+ elif isinstance(categorical_column,
+ feature_column_lib._BucketizedColumn): # pylint:disable=protected-access
+ boundaries = [-np.inf] + list(categorical_column.boundaries) + [np.inf]
+ for pair in zip(boundaries[:-1], boundaries[1:]):
+ names.append('{}:{}'.format(column.name, pair))
+ else:
+ for num in range(categorical_column._num_buckets): # pylint:disable=protected-access
+ names.append('{}:{}'.format(column.name, num))
+ elif isinstance(column, feature_column_lib._BucketizedColumn):
+ names.append(column.name)
+ else:
+ raise ValueError(
+          'For now, only bucketized_column and indicator_column are supported '
+ 'but got: {}'.format(column))
+ return names
+
+
def _cache_transformed_features(features, sorted_feature_columns, batch_size):
"""Transform features and cache, then returns (cached_features, cache_op)."""
num_features = _calculate_num_features(sorted_feature_columns)
@@ -966,6 +1007,60 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
+def _compute_feature_importances_per_tree(tree, num_features):
+ """Computes the importance of each feature in the tree."""
+ importances = np.zeros(num_features)
+
+ for node in tree.nodes:
+ node_type = node.WhichOneof('node')
+ if node_type == 'bucketized_split':
+ feature_id = node.bucketized_split.feature_id
+ importances[feature_id] += node.metadata.gain
+ elif node_type == 'leaf':
+ assert node.metadata.gain == 0
+ else:
+      raise ValueError('Unexpected split type %s' % node_type)
+
+ return importances
+
+
+def _compute_feature_importances(tree_ensemble, num_features, normalize):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the feature.
+
+ Args:
+ tree_ensemble: a trained tree ensemble, instance of proto
+ boosted_trees.TreeEnsemble.
+ num_features: The total number of feature ids.
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+    sorted_feature_idx: A list of feature ids, sorted in descending order
+      of feature importance.
+ feature_importances: A list of corresponding feature importances.
+
+ Raises:
+    AssertionError: When normalize = True, if feature importances
+      contain a negative value, or if normalization is not possible
+ (e.g. ensemble is empty or trees contain only a root node).
+ """
+ tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
+ for tree in tree_ensemble.trees]
+ tree_importances = np.array(tree_importances)
+ tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
+ feature_importances = np.sum(tree_importances * tree_weights, axis=0)
+ if normalize:
+ assert np.all(feature_importances >= 0), ('feature_importances '
+ 'must be non-negative.')
+ normalizer = np.sum(feature_importances)
+ assert normalizer > 0, 'Trees are all empty or contain only a root node.'
+ feature_importances /= normalizer
+
+ sorted_feature_idx = np.argsort(feature_importances)[::-1]
+ return sorted_feature_idx, feature_importances[sorted_feature_idx]
+
+
def _bt_explanations_fn(features,
head,
sorted_feature_columns,
@@ -1053,9 +1148,41 @@ class _BoostedTreesBase(estimator.Estimator):
        feature_columns, key=lambda tc: tc.name)
self._head = head
self._n_features = _calculate_num_features(self._sorted_feature_columns)
+ self._names_for_feature_id = np.array(
+ _generate_feature_name_mapping(self._sorted_feature_columns))
self._center_bias = center_bias
self._is_classification = is_classification
+ def experimental_feature_importances(self, normalize=False):
+ """Computes gain-based feature importances.
+
+ The higher the value, the more important the corresponding feature.
+
+ Args:
+ normalize: If True, normalize the feature importances.
+
+ Returns:
+      sorted_feature_names: 1-D array of feature names, sorted in
+        descending order of feature importance.
+ feature_importances: 1-D array of the corresponding feature importance.
+
+    Raises:
+      ValueError: When the tree ensemble is empty, i.e. when this method
+        is called before any training has happened.
+      AssertionError: When attempting to normalize and the feature
+        importances contain negative values, or when normalization is not
+        possible because the trees contain no splits.
+ """
+ reader = checkpoint_utils.load_checkpoint(self._model_dir)
+ serialized = reader.get_tensor('boosted_trees:0_serialized')
+ if not serialized:
+      raise ValueError('Found empty serialized string for TreeEnsemble. '
+                       'You should only call this method after training.')
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+
+ sorted_feature_id, importances = _compute_feature_importances(
+ ensemble_proto, self._n_features, normalize)
+ return self._names_for_feature_id[sorted_feature_id], importances
+
  def experimental_predict_with_explanations(self,
                                             input_fn,
                                             predict_keys=None,
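
Reviewer note on the arithmetic above: _compute_feature_importances is a
weighted sum of per-tree gain vectors, optional normalization, then a
descending sort. A standalone numpy sketch follows; the gains mirror the toy
ensemble in testFeatureImportancesOnNonEmptyEnsemble further down, and nothing
else here is code from the diff.

    import numpy as np

    # Per-tree, per-feature gain sums: two trees over three features.
    tree_importances = np.array([[4.0, 2.0, 2.0],   # tree 0 (f_0: 3 + 1)
                                 [1.0, 0.0, 1.0]])  # tree 1
    tree_weights = np.array([1.0, 1.0]).reshape(-1, 1)

    feature_importances = np.sum(tree_importances * tree_weights, axis=0)
    # -> [5., 2., 3.]

    feature_importances /= np.sum(feature_importances)  # normalize=True
    # -> [0.5, 0.2, 0.3]

    sorted_feature_idx = np.argsort(feature_importances)[::-1]
    # -> [0, 2, 1]; importances are reported as [0.5, 0.3, 0.2]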
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index d4cb3e27d0..23687a738b 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -17,9 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
+from google.protobuf import text_format
import numpy as np
from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator import run_config
@@ -31,10 +35,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_boosted_trees_ops
+from tensorflow.python.ops import boosted_trees_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.training import checkpoint_utils
+from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
NUM_FEATURES = 3
@@ -564,6 +570,535 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(1, ensemble.trees[0].nodes[0].bucketized_split.feature_id)
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+ def testFeatureImportancesWithTrainedEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+    # It will stop after 10 steps because of the max depth and num trees.
+ num_steps = 100
+ # Train for a few steps, and validate final checkpoint.
+ est.train(input_fn, steps=num_steps)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.833933, 0.606342, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.579010, 0.420990, 0.0], importances)
+
+ def testFeatureImportancesOnEmptyEnsemble(self):
+ input_fn = _make_train_input_fn(is_classification=True)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+ class BailOutWithoutTraining(session_run_hook.SessionRunHook):
+
+ def before_run(self, run_context):
+ raise StopIteration('to bail out.')
+
+ # The step-0 checkpoint will have only an empty ensemble.
+ est.train(input_fn,
+ steps=100, # must stop at 0 anyway.
+ hooks=[BailOutWithoutTraining()])
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=False)
+
+ with self.assertRaisesRegexp(ValueError, 'empty serialized string'):
+ est.experimental_feature_importances(normalize=True)
+
+ def _create_fake_checkpoint_with_tree_ensemble_proto(self,
+ est,
+ tree_ensemble_text):
+ with ops.Graph().as_default():
+ with ops.name_scope('boosted_trees') as name:
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+ tree_ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ text_format.Merge(tree_ensemble_text, tree_ensemble_proto)
+ stamp_token, _ = tree_ensemble.serialize()
+ restore_op = tree_ensemble.deserialize(
+ stamp_token, tree_ensemble_proto.SerializeToString())
+
+ with session.Session() as sess:
+ resources.initialize_resources(resources.shared_resources()).run()
+ restore_op.run()
+ saver = saver_lib.Saver()
+ save_path = os.path.join(est.model_dir, 'model.ckpt')
+ saver.save(sess, save_path)
+
+ def testFeatureImportancesOnNonEmptyEnsemble(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 3.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 7
+ right_id: 8
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+    # Gain sum for each feature:
+    # = 1.0 * [3 + 1, 2, 2] + 1.0 * [1, 0, 1]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithTreeWeights(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=3,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 12.5
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 0.4
+ tree_weights: 0.6
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized',
+ 'f_2_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+    # Gain sum for each feature:
+    # = 0.4 * [12.5, 5, 0] + 0.6 * [0, 0, 5] + 1.0 * [0, 0, 0]
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
+ def testFeatureImportancesWithAllEmptyTree(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+    # Reverse order because feature importances are sorted by
+    # np.argsort(f)[::-1].
+ feature_names_expected = ['f_2_bucketized',
+ 'f_1_bucketized',
+ 'f_0_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, 0.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError,
+ 'all empty or contain only a root node'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testNegativeFeatureImportances(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5)
+
+    # In order to generate negative feature importances,
+    # we assign an invalid value of -1 to tree_weights here.
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ }
+ tree_weights: -1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ # Github #21509 (nataliaponomareva):
+ # The gains stored in the splits can be negative
+ # if people are using complexity regularization.
+ feature_names_expected = ['f_2_bucketized',
+ 'f_0_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, -5.0], importances)
+
+ with self.assertRaisesRegexp(AssertionError, 'non-negative'):
+ est.experimental_feature_importances(normalize=True)
+
+ def testFeatureImportancesNamesForCategoricalColumn(self):
+ categorical = feature_column.categorical_column_with_vocabulary_list(
+ key='categorical', vocabulary_list=('bad', 'good', 'ok'))
+ feature_indicator = feature_column.indicator_column(categorical)
+ bucketized_col = feature_column.bucketized_column(
+ feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32),
+ BUCKET_BOUNDARIES)
+ bucketized_indicator = feature_column.indicator_column(bucketized_col)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[feature_indicator,
+ bucketized_col,
+ bucketized_indicator],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 5.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 4
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 5
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -2.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 4.34
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(
+ est, tree_ensemble_text)
+
+ feature_names_expected = ['categorical_indicator:ok',
+ 'continuous_bucketized_indicator:(-2.0, 0.5)',
+ 'continuous_bucketized_indicator:(-inf, -2.0)',
+ 'categorical_indicator:bad',
+ # Reverse order because feature importances
+ # are sorted by np.argsort(f)[::-1]
+ 'continuous_bucketized_indicator:(12.0, inf)',
+ 'continuous_bucketized_indicator:(0.5, 12.0)',
+ 'continuous_bucketized',
+ 'categorical_indicator:good']
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+    # Gain sum for each feature:
+    # = 1.0 * [0, 0, 5, 0, 2, 0, 0, 0] + 1.0 * [1, 0, 0, 0, 0, 2, 0, 0]
+ self.assertAllClose([5.0, 2.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.2, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0], importances)
+
+ def testFeatureImportancesNamesForUnsupportedColumn(self):
+ numeric_col = feature_column.numeric_column(
+ 'continuous', dtype=dtypes.float32)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'only bucketized_column and indicator_column'):
+ _ = boosted_trees.BoostedTreesRegressor(
+ feature_columns=[numeric_col],
+ n_batches_per_layer=1,
+ n_trees=2,
+ learning_rate=1.0,
+ max_depth=1)
+
def testTreeComplexityIsSetCorrectly(self):
input_fn = _make_train_input_fn(is_classification=True)
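
Closing note: the feature-id-to-name mapping exercised by
testFeatureImportancesNamesForCategoricalColumn can be reproduced by hand. The
sketch below assumes BUCKET_BOUNDARIES = [-2.0, 0.5, 12.0], as implied by the
bucket names in that test's expected list; the literals here are illustrative,
not code from this diff.

    import numpy as np

    vocab = ('bad', 'good', 'ok')
    boundaries = [-np.inf, -2.0, 0.5, 12.0, np.inf]

    names = []
    # indicator(vocabulary list): one name per vocabulary value.
    names += ['categorical_indicator:%s' % v for v in vocab]        # ids 0-2
    # plain bucketized column: a single name.
    names.append('continuous_bucketized')                           # id 3
    # indicator(bucketized): one name per (lower, upper) bucket pair.
    names += ['continuous_bucketized_indicator:%s' % (pair,)
              for pair in zip(boundaries[:-1], boundaries[1:])]     # ids 4-7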