diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-14 11:32:37 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-14 12:03:05 +0800 |
commit | 30e176f584d80898ebad00d2a2ff226e6c281c50 (patch) | |
tree | bd97352c1e491eda01f16596f3ed817fb9ae7c41 /tensorflow/python/estimator | |
parent | 04ddc2daf4c76bb4c99fdc6b582025e9a4ffba52 (diff) |
CLN: only assert gains >= 0 for normalization
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees.py | 8 | ||||
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 12 |
2 files changed, 14 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py index 812c892363..7c04ff7970 100644 --- a/tensorflow/python/estimator/canned/boosted_trees.py +++ b/tensorflow/python/estimator/canned/boosted_trees.py @@ -1036,8 +1036,8 @@ def _compute_feature_importances(tree_ensemble, num_features, normalize): feature_importances: A list of corresponding feature importances. Raises: - AssertionError: If feature importances contain negative value. - Or if normalize = True and normalization is not possible + AssertionError: When normalize = True, if feature importances + contain negative value, or if normalization is not possible (e.g. ensemble is empty or trees contain only a root node). """ tree_importances = [_compute_feature_importances_per_tree(tree, num_features) @@ -1045,9 +1045,9 @@ def _compute_feature_importances(tree_ensemble, num_features, normalize): tree_importances = np.array(tree_importances) tree_weights = np.array(tree_ensemble.tree_weights).reshape(-1, 1) feature_importances = np.sum(tree_importances * tree_weights, axis=0) - assert np.all(feature_importances >= 0), ('feature_importances ' - 'must be non-negative.') if normalize: + assert np.all(feature_importances >= 0), ('feature_importances ' + 'must be non-negative.') normalizer = np.sum(feature_importances) assert normalizer > 0, 'Trees are all empty or contain only a root node.' feature_importances /= normalizer diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 1ce4f7d765..3158ccca81 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -949,8 +949,16 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): self._create_fake_checkpoint_with_tree_ensemble_proto( est, tree_ensemble_text) - with self.assertRaisesRegexp(AssertionError, 'non-negative'): - est.experimental_feature_importances(normalize=False) + # Github #21509 (nataliaponomareva): + # The gains stored in the splits can be negative + # if people are using complexity regularization. + feature_names_expected = ['f_2_bucketized', + 'f_0_bucketized', + 'f_1_bucketized'] + feature_names, importances = est.experimental_feature_importances( + normalize=False) + self.assertAllEqual(feature_names_expected, feature_names) + self.assertAllClose([0.0, 0.0, -5.0], importances) with self.assertRaisesRegexp(AssertionError, 'non-negative'): est.experimental_feature_importances(normalize=True) |