diff options
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/resources.cc')
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/resources.cc | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc index c410748c27..cc90bb2f45 100644 --- a/tensorflow/core/kernels/boosted_trees/resources.cc +++ b/tensorflow/core/kernels/boosted_trees/resources.cc @@ -21,6 +21,10 @@ limitations under the License. namespace tensorflow { +namespace { +constexpr float kLayerByLayerTreeWeight = 1.0; +} // namespace + // Constructor. BoostedTreesEnsembleResource::BoostedTreesEnsembleResource() : tree_ensemble_( @@ -78,6 +82,16 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id, } } +void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id, + const int32 node_id, + const float logits) { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size()); + auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id); + DCHECK(node->node_case() == boosted_trees::Node::kLeaf); + node->mutable_leaf()->set_scalar(logits); +} + int32 BoostedTreesEnsembleResource::GetNumLayersGrown( const int32 tree_id) const { DCHECK_LT(tree_id, tree_ensemble_->trees_size()); @@ -204,9 +218,14 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const { // Add a tree to the ensemble and returns a new tree_id. int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) { + return AddNewTreeWithLogits(weight, 0.0); +} + +int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight, + const float logits) { const int32 new_tree_id = tree_ensemble_->trees_size(); auto* node = tree_ensemble_->add_trees()->add_nodes(); - node->mutable_leaf()->set_scalar(0.0); + node->mutable_leaf()->set_scalar(logits); tree_ensemble_->add_tree_weights(weight); tree_ensemble_->add_tree_metadata(); @@ -225,7 +244,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode( *right_node_id = *left_node_id + 1; auto* left_node = tree->add_nodes(); auto* right_node = tree->add_nodes(); - if (node_id != 0) { + if (node_id != 0 || (node->has_leaf() && node->leaf().scalar() != 0)) { // Save previous leaf value if it is not the first leaf in the tree. node->mutable_metadata()->mutable_original_leaf()->Swap( node->mutable_leaf()); |