diff options
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/prediction_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/prediction_ops.cc | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc index 2920132a27..b2efa06941 100644 --- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc @@ -104,8 +104,8 @@ class BoostedTreesTrainingPredictOp : public OpKernel { const int32 latest_tree = resource->num_trees() - 1; if (latest_tree < 0) { - // Ensemble was empty. Nothing changes. - output_node_ids = cached_node_ids; + // Ensemble was empty. Output the very first node. + output_node_ids.setZero(); output_tree_ids = cached_tree_ids; // All the predictions are zeros. output_partial_logits.setZero(); @@ -120,16 +120,20 @@ class BoostedTreesTrainingPredictOp : public OpKernel { int32 node_id = cached_node_ids(i); float partial_tree_logit = 0.0; - // If the tree was pruned, returns the node id into which the - // current_node_id was pruned, as well the correction of the cached - // logit prediction. - resource->GetPostPruneCorrection(tree_id, node_id, &node_id, - &partial_tree_logit); - - // Logic in the loop adds the cached node value again if it is a leaf. - // If it is not a leaf anymore we need to subtract the old node's - // value. The following logic handles both of these cases. - partial_tree_logit -= resource->node_value(tree_id, node_id); + if (node_id >= 0) { + // If the tree was pruned, returns the node id into which the + // current_node_id was pruned, as well the correction of the cached + // logit prediction. + resource->GetPostPruneCorrection(tree_id, node_id, &node_id, + &partial_tree_logit); + // Logic in the loop adds the cached node value again if it is a + // leaf. If it is not a leaf anymore we need to subtract the old + // node's value. The following logic handles both of these cases. + partial_tree_logit -= resource->node_value(tree_id, node_id); + } else { + // No cache exists, start from the very first node. + node_id = 0; + } float partial_all_logit = 0.0; while (true) { if (resource->is_leaf(tree_id, node_id)) { |