aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/resources
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-18 01:35:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-18 01:38:27 -0800
commit72a4e62f67496bb7d60fa82cbdcf4ed710a6762f (patch)
tree88332d98df8880758a34792bbe09c33374aeafc7 /tensorflow/contrib/boosted_trees/resources
parente19deb4e7af968cba1acdf9df682493ce4052ddd (diff)
Basic feature selection for boosted trees.
The idea is that we grow the trees normally until we reach the requested number of unique features, and once we reach that limit, we avoid using any new features. PiperOrigin-RevId: 182336278
Diffstat (limited to 'tensorflow/contrib/boosted_trees/resources')
-rw-r--r--tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
index 284ad5cdb9..ad9c8961aa 100644
--- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
+++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h
@@ -111,6 +111,35 @@ class DecisionTreeEnsembleResource : public StampedResource {
return decision_tree_ensemble_->tree_weights(index);
}
+ // Records handler_id in growing_metadata.used_handler_ids unless it is
+ // already there. The insertion point is found with a binary search, so the
+ // list is assumed to be sorted in ascending order on entry; inserting at
+ // that point keeps it sorted and free of duplicates.
+ void MaybeAddUsedHandler(const int32 handler_id) {
+   protobuf::RepeatedField<protobuf_int64>* used_ids =
+       decision_tree_ensemble_->mutable_growing_metadata()
+           ->mutable_used_handler_ids();
+   protobuf::RepeatedField<protobuf_int64>::iterator first =
+       std::lower_bound(used_ids->begin(), used_ids->end(), handler_id);
+   if (first != used_ids->end() && handler_id == *first) {
+     // It is a duplicate entry.
+     return;
+   }
+   // Remember the insertion point as an index rather than an iterator: Add()
+   // may grow the underlying storage, which would leave `first` dangling.
+   const int insert_pos = static_cast<int>(first - used_ids->begin());
+   used_ids->Add(handler_id);
+   // Move the freshly appended element back to its sorted position. When the
+   // new id is the largest, insert_pos already points at it and the rotate is
+   // a no-op.
+   std::rotate(used_ids->begin() + insert_pos, used_ids->end() - 1,
+               used_ids->end());
+ }
+
+ // Returns a copy of the ids stored in growing_metadata.used_handler_ids, in
+ // the order in which they are stored.
+ std::vector<int64> GetUsedHandlers() const {
+   const auto& used_ids =
+       decision_tree_ensemble_->growing_metadata().used_handler_ids();
+   return std::vector<int64>(used_ids.begin(), used_ids.end());
+ }
+
// Sets the weight of i'th tree, and increment num_updates in tree_metadata.
void SetTreeWeight(const int32 index, const float weight,
const int32 increment_num_updates) {