diff options
author | 2018-08-21 14:18:27 +0800 | |
---|---|---|
committer | 2018-08-21 14:18:27 +0800 | |
commit | 73c8cbb413029cf3e540e99b883ae89f4b08fc11 (patch) | |
tree | 525c21e4490afd92c889cf3d788cf8432d10c1c1 /tensorflow/python/estimator/canned | |
parent | 88d722c13418fd177c3e03e954307fdfa86a474b (diff) |
TST: add test case for full tree with leaves
Diffstat (limited to 'tensorflow/python/estimator/canned')
-rw-r--r-- | tensorflow/python/estimator/canned/boosted_trees_test.py | 111 |
1 files changed, 111 insertions, 0 deletions
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py index 54ad052915..13e1d224bc 100644 --- a/tensorflow/python/estimator/canned/boosted_trees_test.py +++ b/tensorflow/python/estimator/canned/boosted_trees_test.py @@ -845,6 +845,117 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(AssertionError, 'empty or contains'): est.experimental_feature_importances(normalize=True) + def testFeatureImportancesWithFullTrees(self): + est = boosted_trees.BoostedTreesClassifier( + feature_columns=self._feature_columns, + n_batches_per_layer=1, + n_trees=2, + max_depth=5) + + tree_ensemble_text = """ + trees { + nodes { + bucketized_split { + feature_id: 2 + left_id: 1 + right_id: 2 + } + metadata { + gain: 2.0 + } + } + nodes { + bucketized_split { + feature_id: 0 + left_id: 3 + right_id: 4 + } + metadata { + gain: 3.0 + } + } + nodes { + bucketized_split { + feature_id: 1 + left_id: 5 + right_id: 6 + } + metadata { + gain: 2.0 + } + } + nodes { + leaf { + scalar: -0.34 + } + } + nodes { + leaf { + scalar: 1.34 + } + } + nodes { + leaf { + scalar: 0.0 + } + } + nodes { + leaf { + scalar: 3.34 + } + } + } + trees { + nodes { + bucketized_split { + feature_id: 0 + left_id: 1 + right_id: 2 + } + metadata { + gain: 2.0 + } + } + nodes { + leaf { + scalar: -0.88 + } + } + nodes { + bucketized_split { + feature_id: 2 + left_id: 3 + right_id: 4 + } + metadata { + gain: 1.0 + } + } + nodes { + leaf { + scalar: 1.88 + } + } + nodes { + leaf { + scalar: -2.88 + } + } + } + tree_weights: 1.0 + tree_weights: 1.0 + """ + self._create_fake_checkpoint_with_tree_ensemble_proto(est, tree_ensemble_text) + + feature_names_expected = ['f_0_bucketized', 'f_2_bucketized', 'f_1_bucketized'] + feature_names, importances = est.experimental_feature_importances(normalize=False) + self.assertAllEqual(feature_names_expected, feature_names) + self.assertAllClose([5.0, 3.0, 2.0], importances) + + feature_names, importances = est.experimental_feature_importances(normalize=True) + self.assertAllEqual(feature_names_expected, feature_names) + self.assertAllClose([0.5, 0.3, 0.2], importances) + def testFeatureImportancesWithMoreTrees(self): est = boosted_trees.BoostedTreesClassifier( feature_columns=self._feature_columns, |