aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 11:08:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 11:12:23 -0700
commita586140da6d0460bbf18384556d5cc449b67b322 (patch)
tree0d645778e79fbed3eb3cb0e3004d949e585e283c /tensorflow/core/kernels/boosted_trees
parent5330ede39fa2f1f7b3302bc316061baf180fab44 (diff)
Python interface for Boosted Trees model explainability (currently includes directional feature contributions); fixed ExampleDebugOutputs bug where it errors with empty trees.
PiperOrigin-RevId: 213658470
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees')
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc38
1 files changed, 21 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06941..4ae26fb95b 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
- example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
- resource->node_value(0, 0));
+ float tree_logit =
+ resource->GetTreeWeight(0) * resource->node_value(0, 0);
+ example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
- float tree_logit;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
- // Populate proto.
+ // Go through each tree and populate proto.
while (tree_id <= last_tree) {
- // Feature id used to split.
- feature_id = resource->feature_id(tree_id, node_id);
- example_debug_info.add_feature_ids(feature_id);
- // Get logit after split.
- node_id = resource->next_node(tree_id, node_id, i,
- batch_bucketized_features);
- tree_logit = resource->GetTreeWeight(tree_id) *
- resource->node_value(tree_id, node_id);
- // Output logit incorporates sum of leaf logits from prior trees.
- example_debug_info.add_logits_path(tree_logit + past_trees_logit);
- if (resource->is_leaf(tree_id, node_id)) {
- // Move onto other trees.
- past_trees_logit += tree_logit;
+ if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
+ // Accumulate tree_logits only if the leaf is non-root, but do so
+ // for bias tree.
+ if (tree_id == 0 || node_id > 0) {
+ past_trees_logit += tree_logit;
+ }
++tree_id;
node_id = 0;
+ } else { // Add to proto.
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.