aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 11:08:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 11:12:23 -0700
commita586140da6d0460bbf18384556d5cc449b67b322 (patch)
tree0d645778e79fbed3eb3cb0e3004d949e585e283c /tensorflow/contrib/estimator/python
parent5330ede39fa2f1f7b3302bc316061baf180fab44 (diff)
Python interface for Boosted Trees model explainability (currently includes directional feature contributions); fixed ExampleDebugOutputs bug where it errors with empty trees.
PiperOrigin-RevId: 213658470
Diffstat (limited to 'tensorflow/contrib/estimator/python')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py28
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py74
2 files changed, 98 insertions, 4 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 7ed77bcce6..11f60c8238 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.canned import head as head_lib
def _validate_input_fn_and_repeat_dataset(train_input_fn):
@@ -33,7 +34,18 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn):
return _input_fn
-class _BoostedTreesEstimator(estimator.Estimator):
+# pylint: disable=protected-access
+def _is_classification_head(head):
+ """Infers if the head is a classification head."""
+ # Check using all classification heads defined in canned/head.py. However, it
+ # is not a complete list - it does not check for other classification heads
+ # not defined in the head library.
+ return isinstance(head,
+ (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
+ head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+
+
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
"""An Estimator for Tensorflow Boosted Trees models."""
def __init__(self,
@@ -96,8 +108,10 @@ class _BoostedTreesEstimator(estimator.Estimator):
negative gain). For pre and post pruning, you MUST provide
tree_complexity >0.
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
"""
- # pylint:disable=protected-access
# HParams for the model.
tree_hparams = canned_boosted_trees._TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
@@ -115,8 +129,14 @@ class _BoostedTreesEstimator(estimator.Estimator):
config=config)
super(_BoostedTreesEstimator, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
- # pylint:enable=protected-access
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=_is_classification_head(head))
+ # pylint: enable=protected-access
def boosted_trees_classifier_train_in_memory(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index b1581f3750..e23d9c0fc4 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -360,5 +360,79 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
[pred['predictions'] for pred in predictions])
+class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testContribEstimatorThatDFCIsInPredictions(self):
+ # pylint:disable=protected-access
+ head = canned_boosted_trees._create_regression_head(label_dimension=1)
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees._BoostedTreesEstimator(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ head=head,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+ # pylint:enable=protected-access
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+
if __name__ == '__main__':
googletest.main()