aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 10:41:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 10:41:13 -0700
commit03cf21a5202b8515d77fdaee3184fd20da2a201c (patch)
tree3ebbdbfafcd69f0ccf9d3beb619abe8cb278c92f /tensorflow/contrib/estimator
parent410905d8e8af12e928031aa026683e43b665c8ae (diff)
parent046c74c8e7c68aaa726977dd6e8a2523f854f9cc (diff)
Merge pull request #21509 from facaiy:ENH/feature_importances_for_boosted_tree
PiperOrigin-RevId: 214462540
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 11f60c8238..a1f1c5f3d7 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -34,18 +34,19 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn):
return _input_fn
-# pylint: disable=protected-access
def _is_classification_head(head):
"""Infers if the head is a classification head."""
# Check using all classification heads defined in canned/head.py. However, it
# is not a complete list - it does not check for other classification heads
# not defined in the head library.
+ # pylint: disable=protected-access
return isinstance(head,
(head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+ # pylint: enable=protected-access
-class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint: disable=protected-access
"""An Estimator for Tensorflow Boosted Trees models."""
def __init__(self,
@@ -113,6 +114,7 @@ class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
are requested.
"""
# HParams for the model.
+ # pylint: disable=protected-access
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)