aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/boosted_trees/resources.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/boosted_trees/resources.cc')
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc23
1 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index c410748c27..cc90bb2f45 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -21,6 +21,10 @@ limitations under the License.
namespace tensorflow {
+namespace {
+constexpr float kLayerByLayerTreeWeight = 1.0;
+} // namespace
+
// Constructor.
BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
: tree_ensemble_(
@@ -78,6 +82,16 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
}
}
+void BoostedTreesEnsembleResource::set_node_value(const int32 tree_id,
+ const int32 node_id,
+ const float logits) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ auto* node = tree_ensemble_->mutable_trees(tree_id)->mutable_nodes(node_id);
+ DCHECK(node->node_case() == boosted_trees::Node::kLeaf);
+ node->mutable_leaf()->set_scalar(logits);
+}
+
int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
const int32 tree_id) const {
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
@@ -204,9 +218,14 @@ void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
// Add a tree to the ensemble and returns a new tree_id.
int32 BoostedTreesEnsembleResource::AddNewTree(const float weight) {
+ return AddNewTreeWithLogits(weight, 0.0);
+}
+
+int32 BoostedTreesEnsembleResource::AddNewTreeWithLogits(const float weight,
+ const float logits) {
const int32 new_tree_id = tree_ensemble_->trees_size();
auto* node = tree_ensemble_->add_trees()->add_nodes();
- node->mutable_leaf()->set_scalar(0.0);
+ node->mutable_leaf()->set_scalar(logits);
tree_ensemble_->add_tree_weights(weight);
tree_ensemble_->add_tree_metadata();
@@ -225,7 +244,7 @@ void BoostedTreesEnsembleResource::AddBucketizedSplitNode(
*right_node_id = *left_node_id + 1;
auto* left_node = tree->add_nodes();
auto* right_node = tree->add_nodes();
- if (node_id != 0) {
+ if (node_id != 0 || (node->has_leaf() && node->leaf().scalar() != 0)) {
// Save previous leaf value if it is not the first leaf in the tree.
node->mutable_metadata()->mutable_original_leaf()->Swap(
node->mutable_leaf());