diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 10:41:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 10:41:13 -0700 |
commit | 03cf21a5202b8515d77fdaee3184fd20da2a201c (patch) | |
tree | 3ebbdbfafcd69f0ccf9d3beb619abe8cb278c92f /tensorflow/contrib/estimator | |
parent | 410905d8e8af12e928031aa026683e43b665c8ae (diff) | |
parent | 046c74c8e7c68aaa726977dd6e8a2523f854f9cc (diff) |
Merge pull request #21509 from facaiy:ENH/feature_importances_for_boosted_tree
PiperOrigin-RevId: 214462540
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/boosted_trees.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py index 11f60c8238..a1f1c5f3d7 100644 --- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py +++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py @@ -34,18 +34,19 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn): return _input_fn -# pylint: disable=protected-access def _is_classification_head(head): """Infers if the head is a classification head.""" # Check using all classification heads defined in canned/head.py. However, it # is not a complete list - it does not check for other classification heads # not defined in the head library. + # pylint: disable=protected-access return isinstance(head, (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss, head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss)) + # pylint: enable=protected-access -class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): +class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access """An Estimator for Tensorflow Boosted Trees models.""" def __init__(self, @@ -113,6 +114,7 @@ class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): are requested. """ # HParams for the model. + # pylint: disable=protected-access tree_hparams = canned_boosted_trees._TreeHParams( n_trees, max_depth, learning_rate, l1_regularization, l2_regularization, tree_complexity, min_node_weight, center_bias, pruning_mode) |