diff options
author | 2018-01-18 01:35:01 -0800 | |
---|---|---|
committer | 2018-01-18 01:38:27 -0800 | |
commit | 72a4e62f67496bb7d60fa82cbdcf4ed710a6762f (patch) | |
tree | 88332d98df8880758a34792bbe09c33374aeafc7 /tensorflow/contrib/boosted_trees/resources | |
parent | e19deb4e7af968cba1acdf9df682493ce4052ddd (diff) |
Basic feature selection for boosted trees.
The idea is that we grow the trees normally until we reach the requested number of unique features, and once we reach that limit, we avoid using any new features.
PiperOrigin-RevId: 182336278
Diffstat (limited to 'tensorflow/contrib/boosted_trees/resources')
-rw-r--r-- | tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 284ad5cdb9..ad9c8961aa 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -111,6 +111,35 @@ class DecisionTreeEnsembleResource : public StampedResource { return decision_tree_ensemble_->tree_weights(index); } + void MaybeAddUsedHandler(const int32 handler_id) { + protobuf::RepeatedField<protobuf_int64>* used_ids = + decision_tree_ensemble_->mutable_growing_metadata() + ->mutable_used_handler_ids(); + protobuf::RepeatedField<protobuf_int64>::iterator first = + std::lower_bound(used_ids->begin(), used_ids->end(), handler_id); + if (first == used_ids->end()) { + used_ids->Add(handler_id); + return; + } + if (handler_id == *first) { + // It is a duplicate entry. + return; + } + used_ids->Add(handler_id); + std::rotate(first, used_ids->end() - 1, used_ids->end()); + } + + std::vector<int64> GetUsedHandlers() const { + std::vector<int64> result; + result.reserve( + decision_tree_ensemble_->growing_metadata().used_handler_ids().size()); + for (int64 h : + decision_tree_ensemble_->growing_metadata().used_handler_ids()) { + result.push_back(h); + } + return result; + } + // Sets the weight of i'th tree, and increment num_updates in tree_metadata. void SetTreeWeight(const int32 index, const float weight, const int32 increment_num_updates) { |