about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
author Yan Facai (颜发才) <facai.yan@gmail.com> 2018-09-14 11:32:37 +0800
committer Yan Facai (颜发才) <facai.yan@gmail.com> 2018-09-14 12:03:05 +0800
commit 30e176f584d80898ebad00d2a2ff226e6c281c50 (patch)
tree bd97352c1e491eda01f16596f3ed817fb9ae7c41 /tensorflow/python/estimator
parent 04ddc2daf4c76bb4c99fdc6b582025e9a4ffba52 (diff)
CLN: only assert gains >= 0 for normalization
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees.py 8
-rw-r--r-- tensorflow/python/estimator/canned/boosted_trees_test.py 12
2 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 812c892363..7c04ff7970 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -1036,8 +1036,8 @@ def _compute_feature_importances(tree_ensemble, num_features, normalize):
feature_importances: A list of corresponding feature importances.
Raises:
- AssertionError: If feature importances contain negative value.
- Or if normalize = True and normalization is not possible
+ AssertionError: When normalize = True, if feature importances
+ contain negative value, or if normalization is not possible
(e.g. ensemble is empty or trees contain only a root node).
"""
tree_importances = [_compute_feature_importances_per_tree(tree, num_features)
@@ -1045,9 +1045,9 @@ def _compute_feature_importances(tree_ensemble, num_features, normalize):
tree_importances = np.array(tree_importances)
tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1)
feature_importances = np.sum(tree_importances * tree_weights, axis=0)
- assert np.all(feature_importances >= 0), ('feature_importances '
- 'must be non-negative.')
if normalize:
+ assert np.all(feature_importances >= 0), ('feature_importances '
+ 'must be non-negative.')
normalizer = np.sum(feature_importances)
assert normalizer > 0, 'Trees are all empty or contain only a root node.'
feature_importances /= normalizer
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 1ce4f7d765..3158ccca81 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -949,8 +949,16 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self._create_fake_checkpoint_with_tree_ensemble_proto(
est, tree_ensemble_text)
- with self.assertRaisesRegexp(AssertionError, 'non-negative'):
- est.experimental_feature_importances(normalize=False)
+ # Github #21509 (nataliaponomareva):
+ # The gains stored in the splits can be negative
+ # if people are using complexity regularization.
+ feature_names_expected = ['f_2_bucketized',
+ 'f_0_bucketized',
+ 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(
+ normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.0, 0.0, -5.0], importances)
with self.assertRaisesRegexp(AssertionError, 'non-negative'):
est.experimental_feature_importances(normalize=True)