aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/canned
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-21 14:18:27 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-21 14:18:27 +0800
commit73c8cbb413029cf3e540e99b883ae89f4b08fc11 (patch)
tree525c21e4490afd92c889cf3d788cf8432d10c1c1 /tensorflow/python/estimator/canned
parent88d722c13418fd177c3e03e954307fdfa86a474b (diff)
TST: add test case for full tree with leaves
Diffstat (limited to 'tensorflow/python/estimator/canned')
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py111
1 files changed, 111 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 54ad052915..13e1d224bc 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -845,6 +845,117 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(AssertionError, 'empty or contains'):
est.experimental_feature_importances(normalize=True)
+ def testFeatureImportancesWithFullTrees(self):
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=2,
+ max_depth=5)
+
+ tree_ensemble_text = """
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 3.0
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 1
+ left_id: 5
+ right_id: 6
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.34
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 0.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 3.34
+ }
+ }
+ }
+ trees {
+ nodes {
+ bucketized_split {
+ feature_id: 0
+ left_id: 1
+ right_id: 2
+ }
+ metadata {
+ gain: 2.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -0.88
+ }
+ }
+ nodes {
+ bucketized_split {
+ feature_id: 2
+ left_id: 3
+ right_id: 4
+ }
+ metadata {
+ gain: 1.0
+ }
+ }
+ nodes {
+ leaf {
+ scalar: 1.88
+ }
+ }
+ nodes {
+ leaf {
+ scalar: -2.88
+ }
+ }
+ }
+ tree_weights: 1.0
+ tree_weights: 1.0
+ """
+ self._create_fake_checkpoint_with_tree_ensemble_proto(est, tree_ensemble_text)
+
+ feature_names_expected = ['f_0_bucketized', 'f_2_bucketized', 'f_1_bucketized']
+ feature_names, importances = est.experimental_feature_importances(normalize=False)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([5.0, 3.0, 2.0], importances)
+
+ feature_names, importances = est.experimental_feature_importances(normalize=True)
+ self.assertAllEqual(feature_names_expected, feature_names)
+ self.assertAllClose([0.5, 0.3, 0.2], importances)
+
def testFeatureImportancesWithMoreTrees(self):
est = boosted_trees.BoostedTreesClassifier(
feature_columns=self._feature_columns,