diff options
author | 2018-09-19 11:08:53 -0700 | |
---|---|---|
committer | 2018-09-19 11:12:23 -0700 | |
commit | a586140da6d0460bbf18384556d5cc449b67b322 (patch) | |
tree | 0d645778e79fbed3eb3cb0e3004d949e585e283c /tensorflow/core/kernels/boosted_trees | |
parent | 5330ede39fa2f1f7b3302bc316061baf180fab44 (diff) |
Python interface for Boosted Trees model explainability (currently includes directional feature contributions); fixed ExampleDebugOutputs bug where it errors with empty trees.
PiperOrigin-RevId: 213658470
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees')
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/prediction_ops.cc | 38 |
1 files changed, 21 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc index b2efa06941..4ae26fb95b 100644 --- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc @@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel { // Proto to store debug outputs, per example. boosted_trees::DebugOutput example_debug_info; // Initial bias prediction. E.g., prediction based off training mean. - example_debug_info.add_logits_path(resource->GetTreeWeight(0) * - resource->node_value(0, 0)); + float tree_logit = + resource->GetTreeWeight(0) * resource->node_value(0, 0); + example_debug_info.add_logits_path(tree_logit); int32 node_id = 0; int32 tree_id = 0; int32 feature_id; - float tree_logit; float past_trees_logit = 0; // Sum of leaf logits from prior trees. - // Populate proto. + // Go through each tree and populate proto. while (tree_id <= last_tree) { - // Feature id used to split. - feature_id = resource->feature_id(tree_id, node_id); - example_debug_info.add_feature_ids(feature_id); - // Get logit after split. - node_id = resource->next_node(tree_id, node_id, i, - batch_bucketized_features); - tree_logit = resource->GetTreeWeight(tree_id) * - resource->node_value(tree_id, node_id); - // Output logit incorporates sum of leaf logits from prior trees. - example_debug_info.add_logits_path(tree_logit + past_trees_logit); - if (resource->is_leaf(tree_id, node_id)) { - // Move onto other trees. - past_trees_logit += tree_logit; + if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees. + // Accumulate tree_logits only if the leaf is non-root, but do so + // for bias tree. + if (tree_id == 0 || node_id > 0) { + past_trees_logit += tree_logit; + } ++tree_id; node_id = 0; + } else { // Add to proto. + // Feature id used to split. + feature_id = resource->feature_id(tree_id, node_id); + example_debug_info.add_feature_ids(feature_id); + // Get logit after split. + node_id = resource->next_node(tree_id, node_id, i, + batch_bucketized_features); + tree_logit = resource->GetTreeWeight(tree_id) * + resource->node_value(tree_id, node_id); + // Output logit incorporates sum of leaf logits from prior trees. + example_debug_info.add_logits_path(tree_logit + past_trees_logit); } } // Set output as serialized proto containing debug info. |