aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/prediction_ops.cc')
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc28
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index 2920132a27..b2efa06941 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -104,8 +104,8 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
const int32 latest_tree = resource->num_trees() - 1;
if (latest_tree < 0) {
- // Ensemble was empty. Nothing changes.
- output_node_ids = cached_node_ids;
+ // Ensemble was empty. Output the very first node.
+ output_node_ids.setZero();
output_tree_ids = cached_tree_ids;
// All the predictions are zeros.
output_partial_logits.setZero();
@@ -120,16 +120,20 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
int32 node_id = cached_node_ids(i);
float partial_tree_logit = 0.0;
- // If the tree was pruned, returns the node id into which the
- // current_node_id was pruned, as well the correction of the cached
- // logit prediction.
- resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
- &partial_tree_logit);
-
- // Logic in the loop adds the cached node value again if it is a leaf.
- // If it is not a leaf anymore we need to subtract the old node's
- // value. The following logic handles both of these cases.
- partial_tree_logit -= resource->node_value(tree_id, node_id);
+ if (node_id >= 0) {
+ // If the tree was pruned, returns the node id into which the
+ // current_node_id was pruned, as well the correction of the cached
+ // logit prediction.
+ resource->GetPostPruneCorrection(tree_id, node_id, &node_id,
+ &partial_tree_logit);
+ // Logic in the loop adds the cached node value again if it is a
+ // leaf. If it is not a leaf anymore we need to subtract the old
+ // node's value. The following logic handles both of these cases.
+ partial_tree_logit -= resource->node_value(tree_id, node_id);
+ } else {
+ // No cache exists, start from the very first node.
+ node_id = 0;
+ }
float partial_all_logit = 0.0;
while (true) {
if (resource->is_leaf(tree_id, node_id)) {