author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-19 11:08:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-19 11:12:23 -0700
commit    a586140da6d0460bbf18384556d5cc449b67b322 (patch)
tree      0d645778e79fbed3eb3cb0e3004d949e585e283c /tensorflow/python/estimator
parent    5330ede39fa2f1f7b3302bc316061baf180fab44 (diff)
Python interface for Boosted Trees model explainability (currently includes directional feature contributions); fixed an ExampleDebugOutputs bug where it raised an error on empty trees.
PiperOrigin-RevId: 213658470
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--  tensorflow/python/estimator/BUILD                              |  30
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees.py            | 246
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees_test.py       | 134
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees_utils.py      |  80
-rw-r--r--  tensorflow/python/estimator/canned/boosted_trees_utils_test.py | 187
5 files changed, 655 insertions(+), 22 deletions(-)
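
For orientation, a minimal end-to-end sketch of the API this change introduces. The toy data, feature columns, and step counts below are hypothetical placeholders, not part of this commit; the TF 1.x Estimator API is assumed:

import numpy as np
import tensorflow as tf

# Hypothetical toy data: three numeric features, binary labels.
features = {'f_%d' % i: np.random.uniform(size=100).astype(np.float32)
            for i in range(3)}
labels = np.random.randint(0, 2, size=100)

# Boosted trees estimators consume bucketized columns.
feature_columns = [
    tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column('f_%d' % i),
        boundaries=[0.25, 0.5, 0.75]) for i in range(3)
]
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=features, y=labels, batch_size=100, num_epochs=None, shuffle=True)
predict_input_fn = tf.estimator.inputs.numpy_input_fn(
    x=features, y=None, batch_size=100, num_epochs=1, shuffle=False)

# center_bias=True is required by experimental_predict_with_explanations.
est = tf.estimator.BoostedTreesClassifier(
    feature_columns=feature_columns,
    n_batches_per_layer=1,
    n_trees=1,
    max_depth=5,
    center_bias=True)
est.train(train_input_fn, steps=100)

# Each yielded dict gains 'bias' and 'dfc' keys; per the tests below,
# bias + sum(dfc.values()) should match the predicted class-1 probability.
for pred in est.experimental_predict_with_explanations(predict_input_fn):
  print(pred['bias'], pred['dfc'], pred['probabilities'])
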
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index bfcc019dd5..7f2349954d 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -197,6 +197,7 @@ py_library(
srcs = ["canned/boosted_trees.py"],
srcs_version = "PY2AND3",
deps = [
+ ":boosted_trees_utils",
":estimator",
":head",
":model_fn",
@@ -224,6 +225,35 @@ py_test(
)
py_library(
+ name = "boosted_trees_utils",
+ srcs = ["canned/boosted_trees_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":estimator",
+ ":head",
+ ":model_fn",
+ "//tensorflow:tensorflow_py_no_contrib",
+ ],
+)
+
+py_test(
+ name = "boosted_trees_utils_test",
+ size = "medium",
+ srcs = ["canned/boosted_trees_utils_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "optonly",
+ ],
+ deps = [
+ ":boosted_trees",
+ ":inputs",
+ "//tensorflow:tensorflow_py_no_contrib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "dnn",
srcs = ["canned/dnn.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 19f18015e4..36048a2bfd 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -22,7 +22,8 @@ import collections
import functools
from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.canned import boosted_trees_utils
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import dtypes
@@ -36,6 +37,7 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.array_ops import identity as tf_identity
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
@@ -197,8 +199,7 @@ def _cache_transformed_features(features, sorted_feature_columns, batch_size):
cached_features = [
_local_variable(
array_ops.zeros([batch_size], dtype=dtypes.int32),
- name='cached_feature_{}'.format(i))
- for i in range(num_features)
+ name='cached_feature_{}'.format(i)) for i in range(num_features)
]
are_features_cached = _local_variable(False, name='are_features_cached')
@@ -228,8 +229,7 @@ def _cache_transformed_features(features, sorted_feature_columns, batch_size):
return cached, cache_flip_op
input_feature_list, cache_flip_op = control_flow_ops.cond(
- are_features_cached,
- lambda: (cached_features, control_flow_ops.no_op()),
+ are_features_cached, lambda: (cached_features, control_flow_ops.no_op()),
cache_features_and_return)
return input_feature_list, cache_flip_op
@@ -263,8 +263,8 @@ class _CacheTrainingStatesUsingHashTable(object):
elif dtypes.as_dtype(dtypes.string).is_compatible_with(example_ids.dtype):
empty_key = ''
else:
- raise ValueError('Unsupported example_id_feature dtype %s.' %
- example_ids.dtype)
+ raise ValueError(
+ 'Unsupported example_id_feature dtype %s.' % example_ids.dtype)
# Cache holds latest <tree_id, node_id, logits> for each example.
# tree_id and node_id are both int32 but logits is a float32.
# To reduce the overhead, we store all of them together as float32 and
@@ -273,8 +273,8 @@ class _CacheTrainingStatesUsingHashTable(object):
empty_key=empty_key, value_dtype=dtypes.float32, value_shape=[3])
self._example_ids = ops.convert_to_tensor(example_ids)
if self._example_ids.shape.ndims not in (None, 1):
- raise ValueError('example_id should have rank 1, but got %s' %
- self._example_ids)
+ raise ValueError(
+ 'example_id should have rank 1, but got %s' % self._example_ids)
self._logits_dimension = logits_dimension
def lookup(self):
@@ -334,7 +334,7 @@ class _CacheTrainingStatesUsingVariables(object):
array_ops.zeros([batch_size], dtype=dtypes.int32),
name='tree_ids_cache')
self._node_ids = _local_variable(
- _DUMMY_NODE_ID*array_ops.ones([batch_size], dtype=dtypes.int32),
+ _DUMMY_NODE_ID * array_ops.ones([batch_size], dtype=dtypes.int32),
name='node_ids_cache')
self._logits = _local_variable(
array_ops.zeros([batch_size, logits_dimension], dtype=dtypes.float32),
@@ -719,7 +719,7 @@ def _bt_model_fn(
tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
# Create logits.
- if mode != model_fn.ModeKeys.TRAIN:
+ if mode != model_fn_lib.ModeKeys.TRAIN:
input_feature_list = _get_transformed_features(features,
sorted_feature_columns)
logits = boosted_trees_ops.predict(
@@ -886,6 +886,7 @@ def _bt_model_fn(
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
+
# Add an early stop hook.
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
@@ -927,8 +928,8 @@ def _create_classification_head_and_closed_form(n_classes, weight_column,
label_vocabulary):
"""Creates a head for classifier and the closed form gradients/hessians."""
head = _create_classification_head(n_classes, weight_column, label_vocabulary)
- if (n_classes == 2 and head.logits_dimension == 1 and weight_column is None
- and label_vocabulary is None):
+ if (n_classes == 2 and head.logits_dimension == 1 and
+ weight_column is None and label_vocabulary is None):
# Use the closed-form gradients/hessians for 2 class.
def _grad_and_hess_for_logloss(logits, labels):
"""A closed form gradient and hessian for logistic loss."""
@@ -961,8 +962,196 @@ def _create_regression_head(label_dimension, weight_column=None):
# pylint: enable=protected-access
+def _bt_explanations_fn(features,
+ head,
+ sorted_feature_columns,
+ name='boosted_trees'):
+ """Gradient Boosted Trees predict with explanations model_fn.
+
+ Args:
+ features: dict of `Tensor`.
+ head: A `head_lib._Head` instance.
+ sorted_feature_columns: Sorted iterable of `feature_column._FeatureColumn`
+ model inputs.
+ name: Name used for the model.
+
+ Returns:
+ An `EstimatorSpec` instance.
+
+ Raises:
+ ValueError: mode or params are invalid, or features has the wrong type.
+ """
+ mode = model_fn_lib.ModeKeys.PREDICT
+ with ops.name_scope(name) as name:
+ # Create Ensemble resources.
+ tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
+
+ input_feature_list = _get_transformed_features(features,
+ sorted_feature_columns)
+
+ logits = boosted_trees_ops.predict(
+ # For non-TRAIN mode, ensemble doesn't change after initialization,
+ # so no local copy is needed; using tree_ensemble directly.
+ tree_ensemble_handle=tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=None,
+ train_op_fn=control_flow_ops.no_op,
+ logits=logits)
+
+ debug_op = boosted_trees_ops.example_debug_outputs(
+ tree_ensemble.resource_handle,
+ bucketized_features=input_feature_list,
+ logits_dimension=head.logits_dimension)
+ estimator_spec.predictions[boosted_trees_utils._DEBUG_PROTO_KEY] = debug_op # pylint: disable=protected-access
+ return estimator_spec
+
+
+class _BoostedTreesBase(estimator.Estimator):
+ """Base class for boosted trees estimators.
+
+  This class is intended to keep tree-specific functions (e.g., methods for
+ feature importances and directional feature contributions) in one central
+ place.
+
+ It is not a valid (working) Estimator on its own and should only be used as a
+ base class.
+ """
+
+ def __init__(self, model_fn, model_dir, config, feature_columns, head,
+ center_bias, is_classification):
+ """Initializes a `_BoostedTreesBase` instance.
+
+ Args:
+      model_fn: Model function. See base class for more detail.
+      model_dir: Directory to save model parameters, graph, etc. See base
+        class for more detail.
+ config: `estimator.RunConfig` configuration object.
+ feature_columns: An iterable containing all the feature columns used by
+ the model. All items in the set should be instances of classes derived
+        from `FeatureColumn`.
+ head: A `head_lib._Head` instance.
+ center_bias: Whether bias centering needs to occur. Bias centering refers
+ to the first node in the very first tree returning the prediction that
+ is aligned with the original labels distribution. For example, for
+ regression problems, the first node will return the mean of the labels.
+ For binary classification problems, it will return a logit for a prior
+ probability of label 1.
+ is_classification: If the estimator is for classification.
+ """
+ super(_BoostedTreesBase, self).__init__(
+ model_fn=model_fn, model_dir=model_dir, config=config)
+ self._sorted_feature_columns = sorted(
+ feature_columns, key=lambda tc: tc.name)
+ self._head = head
+ self._n_features = _calculate_num_features(self._sorted_feature_columns)
+ self._center_bias = center_bias
+ self._is_classification = is_classification
+
+ def experimental_predict_with_explanations(self,
+ input_fn,
+ predict_keys=None,
+ hooks=None,
+ checkpoint_path=None):
+ """Computes model explainability outputs per example along with predictions.
+
+ Currently supports directional feature contributions (DFCs). For each
+ instance, DFCs indicate the aggregate contribution of each feature. See
+ https://arxiv.org/abs/1312.1121 and
+ http://blog.datadive.net/interpreting-random-forests/ for more details.
+ Args:
+ input_fn: A function that provides input data for predicting as
+ minibatches. See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ for more information. The function should construct and return one of
+ the following: * A `tf.data.Dataset` object: Outputs of `Dataset`
+ object must be a tuple `(features, labels)` with same constraints as
+ below. * A tuple `(features, labels)`: Where `features` is a `tf.Tensor`
+ or a dictionary of string feature name to `Tensor` and `labels` is a
+ `Tensor` or a dictionary of string label name to `Tensor`. Both
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
+ predict_keys: list of `str`, name of the keys to predict. It is used if
+ the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
+ `predict_keys` is used then rest of the predictions will be filtered
+ from the dictionary, with the exception of 'bias' and 'dfc', which will
+ always be in the dictionary. If `None`, returns all keys in prediction
+ dict, as well as two new keys 'dfc' and 'bias'.
+ hooks: List of `tf.train.SessionRunHook` subclass instances. Used for
+ callbacks inside the prediction call.
+ checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
+ latest checkpoint in `model_dir` is used. If there are no checkpoints
+ in `model_dir`, prediction is run with newly initialized `Variables`
+ instead of ones restored from checkpoint.
+
+ Yields:
+ Evaluated values of `predictions` tensors. The `predictions` tensors will
+      contain at least two keys, 'dfc' and 'bias', for model explanations. The
+      'dfc' value is the contribution of each feature to the overall
+      prediction for this instance (positive values make class 1 more
+      likely; negative values make it less likely). The 'bias'
+      value is the same across all instances, corresponding to the
+ probability (classification) or prediction (regression) of the training
+ data distribution.
+
+ Raises:
+ ValueError: when wrong arguments are given or unsupported functionalities
+ are requested.
+ """
+ if not self._center_bias:
+ raise ValueError('center_bias must be enabled during estimator '
+ 'instantiation when using '
+ 'experimental_predict_with_explanations.')
+ # pylint: disable=protected-access
+ if not self._is_classification:
+ identity_inverse_link_fn = self._head._inverse_link_fn in (None,
+ tf_identity)
+ # pylint:enable=protected-access
+ if not identity_inverse_link_fn:
+ raise ValueError(
+ 'For now only identity inverse_link_fn in regression_head is '
+ 'supported for experimental_predict_with_explanations.')
+
+ # pylint:disable=unused-argument
+ def new_model_fn(features, labels, mode):
+ return _bt_explanations_fn(features, self._head,
+ self._sorted_feature_columns)
+
+ # pylint:enable=unused-argument
+ est = estimator.Estimator(
+ model_fn=new_model_fn,
+ model_dir=self.model_dir,
+ config=self.config,
+ warm_start_from=self._warm_start_settings)
+ # Make sure bias and dfc will be in prediction dict.
+ user_supplied_predict_keys = predict_keys is not None
+ if user_supplied_predict_keys:
+ predict_keys = set(predict_keys)
+ predict_keys.add(boosted_trees_utils._DEBUG_PROTO_KEY)
+ predictions = est.predict(
+ input_fn,
+ predict_keys=predict_keys,
+ hooks=hooks,
+ checkpoint_path=checkpoint_path,
+ yield_single_examples=True)
+ for pred in predictions:
+ bias, dfcs = boosted_trees_utils._parse_explanations_from_prediction(
+ pred[boosted_trees_utils._DEBUG_PROTO_KEY], self._n_features,
+ self._is_classification)
+ pred['bias'] = bias
+ pred['dfc'] = dfcs
+ # Don't need to expose serialized proto to end user.
+ del pred[boosted_trees_utils._DEBUG_PROTO_KEY]
+ yield pred
+
+
+# pylint: disable=protected-access
@estimator_export('estimator.BoostedTreesClassifier')
-class BoostedTreesClassifier(estimator.Estimator):
+class BoostedTreesClassifier(_BoostedTreesBase):
"""A Classifier for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1082,14 +1271,13 @@ class BoostedTreesClassifier(estimator.Estimator):
n_classes = 2
head, closed_form = _create_classification_head_and_closed_form(
n_classes, weight_column, label_vocabulary=label_vocabulary)
-
# HParams for the model.
tree_hparams = _TreeHParams(
n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
+ return _bt_model_fn(
features,
labels,
mode,
@@ -1101,11 +1289,17 @@ class BoostedTreesClassifier(estimator.Estimator):
closed_form_grad_and_hess_fn=closed_form)
super(BoostedTreesClassifier, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=True)
@estimator_export('estimator.BoostedTreesRegressor')
-class BoostedTreesRegressor(estimator.Estimator):
+class BoostedTreesRegressor(_BoostedTreesBase):
"""A Regressor for Tensorflow Boosted Trees models.
@compatibility(eager)
@@ -1223,9 +1417,17 @@ class BoostedTreesRegressor(estimator.Estimator):
tree_complexity, min_node_weight, center_bias, pruning_mode)
def _model_fn(features, labels, mode, config):
- return _bt_model_fn( # pylint: disable=protected-access
- features, labels, mode, head, feature_columns, tree_hparams,
- n_batches_per_layer, config)
+ return _bt_model_fn(features, labels, mode, head, feature_columns,
+ tree_hparams, n_batches_per_layer, config)
super(BoostedTreesRegressor, self).__init__(
- model_fn=_model_fn, model_dir=model_dir, config=config)
+ model_fn=_model_fn,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns,
+ head=head,
+ center_bias=center_bias,
+ is_classification=False)
+
+
+# pylint: enable=protected-access
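
To make the prediction-dict contract above concrete, a short post-processing sketch. The values are hypothetical, loosely mirroring the classifier test below; `rank_features_by_contribution` is an illustrative helper, not part of this commit:

def rank_features_by_contribution(pred):
  """Sorts feature ids by absolute directional contribution, descending."""
  return sorted(pred['dfc'].items(), key=lambda kv: abs(kv[1]), reverse=True)

# One dict as yielded by experimental_predict_with_explanations:
#   'bias' is the training-set prior for class 1, 'dfc' the per-feature
#   contributions keyed by feature id.
example_pred = {'bias': 0.4, 'dfc': {0: 0.196, 1: 0.0, 2: 0.027}}
print(rank_features_by_contribution(example_pred))
# [(0, 0.196), (2, 0.027), (1, 0.0)]
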
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 6e28c72151..9409cb5cc7 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -565,6 +565,140 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
self.assertEqual(0, ensemble.trees[0].nodes[0].bucketized_split.threshold)
+class BoostedTreesDebugOutputsTest(test_util.TensorFlowTestCase):
+ """Test debug/model explainability outputs for individual predictions.
+
+ Includes directional feature contributions (DFC).
+ """
+
+ def setUp(self):
+ self._feature_columns = {
+ feature_column.bucketized_column(
+ feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+ BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+ }
+
+ def testBinaryClassifierThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=True)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=3, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesClassifier(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([0.4] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: 0.19650601422250574,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: 0.16057487356133376,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }, {
+ 0: -0.12108613453574479,
+ 1: 0.0,
+ 2: -0.039254929814481143
+ }, {
+ 0: -0.10832468554550384,
+ 1: 0.0,
+ 2: 0.02693827052766018
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == probabilities.
+ expected_probabilities = [
+ 0.23965894, 0.62344426, 0.58751315, 0.23965894, 0.31861359
+ ]
+ probabilities = [
+ sum(dfc.values()) + bias for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_probabilities, probabilities)
+
+    # Even if the user doesn't include 'bias' or 'dfc' in predict_keys, they
+    # should still be included in the prediction dicts.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['probabilities'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('probabilities' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+ def testRegressorThatDFCIsInPredictions(self):
+ train_input_fn = _make_train_input_fn(is_classification=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+ est = boosted_trees.BoostedTreesRegressor(
+ feature_columns=self._feature_columns,
+ n_batches_per_layer=1,
+ n_trees=1,
+ max_depth=5,
+ center_bias=True)
+
+ num_steps = 100
+ # Train for a few steps. Validate debug outputs in prediction dicts.
+ est.train(train_input_fn, steps=num_steps)
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn)
+ biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+ for pred in debug_predictions])
+ self.assertAllClose([1.8] * 5, biases)
+ self.assertAllClose(({
+ 0: -0.070499420166015625,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: -0.53763031959533691,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: -0.51756942272186279,
+ 1: -0.095000028610229492,
+ 2: 0.0
+ }, {
+ 0: 0.1563495397567749,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }, {
+ 0: 0.96934974193572998,
+ 1: 0.063333392143249512,
+ 2: 0.0
+ }), dfcs)
+
+ # Assert sum(dfcs) + bias == predictions.
+ expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+ [2.01968288], [2.83268309]]
+ predictions = [
+ [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+ ]
+ self.assertAllClose(expected_predictions, predictions)
+
+ # Test when user doesn't include bias or dfc in predict_keys.
+ debug_predictions = est.experimental_predict_with_explanations(
+ predict_input_fn, predict_keys=['predictions'])
+ for prediction_dict in debug_predictions:
+ self.assertTrue('bias' in prediction_dict)
+ self.assertTrue('dfc' in prediction_dict)
+ self.assertTrue('predictions' in prediction_dict)
+ self.assertEqual(len(prediction_dict), 3)
+
+
class ModelFnTests(test_util.TensorFlowTestCase):
"""Tests bt_model_fn including unexposed internal functionalities."""
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils.py b/tensorflow/python/estimator/canned/boosted_trees_utils.py
new file mode 100644
index 0000000000..85efc2304a
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Debug and model explainability logic for boosted trees."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
+
+# For directional feature contributions.
+_DEBUG_PROTO_KEY = '_serialized_debug_outputs_proto'
+_BIAS_ID = 0
+
+
+def _parse_debug_proto_string(example_proto_serialized):
+ example_debug_outputs = boosted_trees_pb2.DebugOutput()
+ example_debug_outputs.ParseFromString(example_proto_serialized)
+ feature_ids = example_debug_outputs.feature_ids
+ logits_path = example_debug_outputs.logits_path
+ return feature_ids, logits_path
+
+
+def _compute_directional_feature_contributions(example_feature_ids,
+ example_logits_paths, activation,
+ num_bucketized_features):
+ """Directional feature contributions and bias, per example."""
+ # Initialize contributions to 0.
+ dfcs = {k: 0 for k in range(num_bucketized_features)}
+
+  # Traverse the tree, subtracting each child prediction from its parent
+  # prediction and associating the change with the feature id used to split.
+ predictions = np.array(activation(example_logits_paths))
+ delta_pred = predictions[_BIAS_ID + 1:] - predictions[:-1]
+ # Group by feature id, then sum delta_pred.
+ contribs = np.bincount(
+ example_feature_ids,
+ weights=delta_pred,
+ minlength=num_bucketized_features)
+ for f, dfc in zip(range(num_bucketized_features), contribs):
+ dfcs[f] = dfc
+ return predictions[_BIAS_ID], dfcs
+
+
+def _identity(logits):
+ return logits
+
+
+def _sigmoid(logits):
+ # TODO(crawles): Change to softmax once multiclass support is available.
+ return 1 / (1 + np.exp(-np.array(logits)))
+
+
+def _parse_explanations_from_prediction(serialized_debug_proto,
+ n_features,
+ classification=False):
+ """Parse serialized explanability proto, compute dfc, and return bias, dfc."""
+ feature_ids, logits_path = _parse_debug_proto_string(serialized_debug_proto)
+ if classification:
+ activation = _sigmoid
+ else:
+ activation = _identity
+ bias, dfcs = _compute_directional_feature_contributions(
+ feature_ids, logits_path, activation, n_features)
+ # TODO(crawles): Prediction path and leaf IDs.
+ return bias, dfcs
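
As a sanity check on the np.bincount grouping inside _compute_directional_feature_contributions above, here is a standalone numpy sketch of the same computation, reusing the center_bias=True, identity-activation path from the unit tests below:

import numpy as np

# Prediction path: root (bias) logit followed by one logit per split,
# plus the feature id used at each split.
logits_path = np.array([0.21, 0.114, 1.214, 1.114, 6.114])
feature_ids = np.array([2, 2, 0, 0])
num_features = 3  # Feature 1 is never used, so its contribution stays 0.

# Change in prediction introduced by each split.
delta_pred = logits_path[1:] - logits_path[:-1]
# Group the per-split deltas by feature id.
contribs = np.bincount(feature_ids, weights=delta_pred,
                       minlength=num_features)
print(dict(enumerate(contribs)))  # approx. {0: 4.9, 1: 0.0, 2: 1.004}
# The bias plus all contributions recovers the final prediction.
assert np.isclose(logits_path[0] + contribs.sum(), logits_path[-1])
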
diff --git a/tensorflow/python/estimator/canned/boosted_trees_utils_test.py b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
new file mode 100644
index 0000000000..506d4ea6fb
--- /dev/null
+++ b/tensorflow/python/estimator/canned/boosted_trees_utils_test.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests boosted_trees estimators and model_fn."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator.canned import boosted_trees_utils
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class BoostedTreesDFCTest(test_util.TensorFlowTestCase):
+ """Test directional feature contributions (DFC) helper functions. """
+
+ def testDirectionalFeatureContributionsCompute(self):
+ """Tests logic to compute DFCs given feature ids and logits paths."""
+ num_bucketized_features = 3 # Includes one unused feature.
+ examples_feature_ids = ((2, 2, 0, 0), (2, 2, 0))
+ e1_feature_ids, e2_feature_ids = examples_feature_ids
+
+ # DFCs are computed by traversing the prediction path and subtracting each
+ # child prediction from its parent prediction and associating the change in
+ # prediction with the respective feature id used for the split.
+    # For each activation function f (currently identity or sigmoid), DFCs are
+ # calculated for the two examples as:
+ # example 1:
+ # feature_0 = (f(1.114) - f(1.214)) + (f(6.114) - f(1.114))
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.214) - f(0.114))
+ # example 2:
+ # feature_0 = f(-5.486) - f(1.514)
+ # feature_1 = 0 # Feature not in ensemble, thus zero contrib.
+ # feature_2 = (f(0.114) - bias_pred) + (f(1.514) - f(0.114))
+    # where bias_pred = f(0) or f(0.21), with center_bias = {False, True},
+    # respectively.
+ # Keys are center_bias.
+ expected_dfcs_identity = {
+ False: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.214
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.514
+ }),
+ True: ({
+ 0: 4.9,
+ 1: 0,
+ 2: 1.0039999999999998
+ }, {
+ 0: -7.0,
+ 1: 0,
+ 2: 1.3039999999999998
+ })
+ }
+ expected_dfcs_sigmoid = {
+ False: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2710059376234506
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.319653250251275
+ }),
+ True: ({
+ 0: 0.22678725678805578,
+ 1: 0,
+ 2: 0.2186980280491253
+ }, {
+ 0: -0.81552596670046507,
+ 1: 0,
+ 2: 0.26734534067694971
+ })
+ }
+ # pylint: disable=protected-access
+ for f, expected_dfcs in zip(
+ (boosted_trees_utils._identity, boosted_trees_utils._sigmoid),
+ (expected_dfcs_identity, expected_dfcs_sigmoid)):
+ for center_bias in [False, True]:
+        # If not center_bias, the bias logit is 0.
+ if center_bias:
+ bias_logit = 0.21 # Root node of tree_0.
+ else:
+ bias_logit = 0 # 0 is default value when there is no original_leaf.
+ f_bias = f(bias_logit)
+
+        # Logits before and after each split, as output from
+        # boosted_trees_ops.example_debug_outputs.
+ examples_logits_paths = ((bias_logit, 0.114, 1.214, 1.114, 6.114),
+ (bias_logit, 0.114, 1.514, -5.486))
+ e1_logits_path, e2_logits_path = examples_logits_paths
+ e1_expected_dfcs, e2_expected_dfcs = expected_dfcs[center_bias]
+ # Check feature contributions are correct for both examples.
+ # Example 1.
+ # pylint:disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, f, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, f, num_bucketized_features)
+ # pylint:enable=line-too-long
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Check if contributions sum to final prediction.
+        # The last logit on each path is the final leaf prediction.
+ expected_logits = (e1_logits_path[-1], e2_logits_path[-1])
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [f(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ preds = [e1_pred, e2_pred]
+ self.assertAllClose(preds, expected_preds)
+ # pylint: enable=protected-access
+
+ def testDFCComputeComparedToExternalExample(self):
+ """Tests `compute_dfc` compared to external example (regression).
+
+ Example from http://blog.datadive.net/interpreting-random-forests.
+ """
+    # Feature ids: DIS: 3, RM: 2, LSTAT: 1, NOX: 0.
+ num_bucketized_features = 4
+ e1_feature_ids = (2, 1, 0)
+ e2_feature_ids = (2, 2, 2)
+ e3_feature_ids = (2, 2, 0)
+
+ bias_logit = 22.60 # Root node of tree_0.
+ activation = boosted_trees_utils._identity
+ f_bias = activation(bias_logit)
+    # Logits before and after each split, as output from
+    # boosted_trees_ops.example_debug_outputs.
+ e1_logits_path = (bias_logit, 19.96, 14.91, 18.11)
+ e2_logits_path = (bias_logit, 37.42, 45.10, 45.90)
+ e3_logits_path = (bias_logit, 37.42, 32.30, 33.58)
+ e1_expected_dfcs = {0: 3.20, 1: -5.05, 2: -2.64, 3: 0}
+ e2_expected_dfcs = {0: 0, 1: 0, 2: 23.3, 3: 0}
+ e3_expected_dfcs = {0: 1.28, 1: 0, 2: 9.7, 3: 0}
+    # Check feature contributions are correct for all three examples.
+ # Example 1.
+ # pylint: disable=protected-access
+ # pylint: disable=line-too-long
+ e1_bias, e1_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e1_feature_ids, e1_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e1_bias, f_bias)
+ self.assertAllClose(e1_dfc, e1_expected_dfcs)
+ # Example 2.
+ e2_bias, e2_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e2_feature_ids, e2_logits_path, activation, num_bucketized_features)
+ self.assertAllClose(e2_bias, f_bias)
+ self.assertAllClose(e2_dfc, e2_expected_dfcs)
+ # Example 3.
+ e3_bias, e3_dfc = boosted_trees_utils._compute_directional_feature_contributions(
+ e3_feature_ids, e3_logits_path, activation, num_bucketized_features)
+ # pylint: enable=line-too-long
+ self.assertAllClose(e3_bias, f_bias)
+ self.assertAllClose(e3_dfc, e3_expected_dfcs)
+ # pylint: enable=protected-access
+ # Check if contributions sum to final prediction.
+    # The last logit on each path is the final leaf prediction.
+ expected_logits = (18.11, 45.90, 33.58)
+ # Predictions should be the sum of contributions + bias.
+ expected_preds = [activation(logit) for logit in expected_logits]
+ e1_pred = e1_bias + sum(e1_dfc.values())
+ e2_pred = e2_bias + sum(e2_dfc.values())
+ e3_pred = e3_bias + sum(e3_dfc.values())
+ preds = [e1_pred, e2_pred, e3_pred]
+ self.assertAllClose(preds, expected_preds)
+
+
+if __name__ == '__main__':
+ googletest.main()
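
A closing note on the invariant these tests rely on: the activation f is applied to every logit on the path before differencing, so the per-split deltas telescope and bias + sum(dfcs) always equals the final activated prediction. A minimal standalone check in plain numpy, reusing the path from the first test above:

import numpy as np

def sigmoid(x):
  return 1 / (1 + np.exp(-np.asarray(x)))

# Path from testDirectionalFeatureContributionsCompute (center_bias=True).
logits_path = [0.21, 0.114, 1.214, 1.114, 6.114]
preds = sigmoid(logits_path)
# Per-split deltas in probability space telescope back to the final value.
total_dfc = (preds[1:] - preds[:-1]).sum()
assert np.isclose(preds[0] + total_dfc, sigmoid(logits_path[-1]))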