diff options
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/resources.cc | 138 | ||||
-rw-r--r-- | tensorflow/core/kernels/boosted_trees/resources.h | 154 |
2 files changed, 178 insertions, 114 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc index 2ea12c522c..c410748c27 100644 --- a/tensorflow/core/kernels/boosted_trees/resources.cc +++ b/tensorflow/core/kernels/boosted_trees/resources.cc @@ -21,6 +21,35 @@ limitations under the License. namespace tensorflow { +// Constructor. +BoostedTreesEnsembleResource::BoostedTreesEnsembleResource() + : tree_ensemble_( + protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>( + &arena_)) {} + +string BoostedTreesEnsembleResource::DebugString() { + return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(), + "]"); +} + +bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized, + const int64 stamp_token) { + CHECK_EQ(stamp(), -1) << "Must Reset before Init."; + if (ParseProtoUnlimited(tree_ensemble_, serialized)) { + set_stamp(stamp_token); + return true; + } + return false; +} + +string BoostedTreesEnsembleResource::SerializeAsString() const { + return tree_ensemble_->SerializeAsString(); +} + +int32 BoostedTreesEnsembleResource::num_trees() const { + return tree_ensemble_->trees_size(); +} + int32 BoostedTreesEnsembleResource::next_node( const int32 tree_id, const int32 node_id, const int32 index_in_batch, const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const { @@ -49,6 +78,115 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id, } } +int32 BoostedTreesEnsembleResource::GetNumLayersGrown( + const int32 tree_id) const { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + return tree_ensemble_->tree_metadata(tree_id).num_layers_grown(); +} + +void BoostedTreesEnsembleResource::SetNumLayersGrown( + const int32 tree_id, int32 new_num_layers) const { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown( + new_num_layers); +} + +void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange( + const int32 node_range_start, int32 node_range_end) const { + tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start( + node_range_start); + tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end( + node_range_end); +} + +void BoostedTreesEnsembleResource::GetLastLayerNodesRange( + int32* node_range_start, int32* node_range_end) const { + *node_range_start = + tree_ensemble_->growing_metadata().last_layer_node_start(); + *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end(); +} + +int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + return tree_ensemble_->trees(tree_id).nodes_size(); +} + +int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() { + return tree_ensemble_->growing_metadata().num_layers_attempted(); +} + +bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id, + const int32 node_id) const { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size()); + const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id); + return node.node_case() == boosted_trees::Node::kLeaf; +} + +int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id, + const int32 node_id) const { + const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); + DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); + return node.bucketized_split().feature_id(); +} + +int32 BoostedTreesEnsembleResource::bucket_threshold( + const int32 tree_id, const int32 node_id) const { + const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); + DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); + return node.bucketized_split().threshold(); +} + +int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id, + const int32 node_id) const { + const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); + DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); + return node.bucketized_split().left_id(); +} + +int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id, + const int32 node_id) const { + const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); + DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); + return node.bucketized_split().right_id(); +} + +std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const { + return {tree_ensemble_->tree_weights().begin(), + tree_ensemble_->tree_weights().end()}; +} + +float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const { + return tree_ensemble_->tree_weights(tree_id); +} + +float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + return tree_ensemble_->tree_metadata(tree_id).is_finalized(); +} + +float BoostedTreesEnsembleResource::IsTreePostPruned( + const int32 tree_id) const { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() > + 0; +} + +void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id, + const bool is_finalized) { + DCHECK_LT(tree_id, tree_ensemble_->trees_size()); + return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized( + is_finalized); +} + +// Sets the weight of i'th tree. +void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id, + const float weight) { + DCHECK_GE(tree_id, 0); + DCHECK_LT(tree_id, num_trees()); + tree_ensemble_->set_tree_weights(tree_id, weight); +} + void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const { tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted( tree_ensemble_->growing_metadata().num_layers_attempted() + 1); diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h index 561ca3a18a..df78d3f275 100644 --- a/tensorflow/core/kernels/boosted_trees/resources.h +++ b/tensorflow/core/kernels/boosted_trees/resources.h @@ -17,12 +17,16 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_ #include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { +// Forward declaration for proto class TreeEnsemble +namespace boosted_trees { +class TreeEnsemble; +} // namespace boosted_trees + // A StampedResource is a resource that has a stamp token associated with it. // Before reading from or applying updates to the resource, the stamp should // be checked to verify that the update is not stale. @@ -42,31 +46,15 @@ class StampedResource : public ResourceBase { // Keep a tree ensemble in memory for efficient evaluation and mutation. class BoostedTreesEnsembleResource : public StampedResource { public: - // Constructor. - BoostedTreesEnsembleResource() - : tree_ensemble_( - protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>( - &arena_)) {} - - string DebugString() override { - return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(), - "]"); - } - - bool InitFromSerialized(const string& serialized, const int64 stamp_token) { - CHECK_EQ(stamp(), -1) << "Must Reset before Init."; - if (ParseProtoUnlimited(tree_ensemble_, serialized)) { - set_stamp(stamp_token); - return true; - } - return false; - } - - string SerializeAsString() const { - return tree_ensemble_->SerializeAsString(); - } - - int32 num_trees() const { return tree_ensemble_->trees_size(); } + BoostedTreesEnsembleResource(); + + string DebugString() override; + + bool InitFromSerialized(const string& serialized, const int64 stamp_token); + + string SerializeAsString() const; + + int32 num_trees() const; // Find the next node to which the example (specified by index_in_batch) // traverses down from the current node indicated by tree_id and node_id. @@ -82,73 +70,31 @@ class BoostedTreesEnsembleResource : public StampedResource { float node_value(const int32 tree_id, const int32 node_id) const; - int32 GetNumLayersGrown(const int32 tree_id) const { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - return tree_ensemble_->tree_metadata(tree_id).num_layers_grown(); - } + int32 GetNumLayersGrown(const int32 tree_id) const; - void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown( - new_num_layers); - } + void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const; void UpdateLastLayerNodesRange(const int32 node_range_start, - int32 node_range_end) const { - tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start( - node_range_start); - tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end( - node_range_end); - } + int32 node_range_end) const; void GetLastLayerNodesRange(int32* node_range_start, - int32* node_range_end) const { - *node_range_start = - tree_ensemble_->growing_metadata().last_layer_node_start(); - *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end(); - } + int32* node_range_end) const; - int64 GetNumNodes(const int32 tree_id) { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - return tree_ensemble_->trees(tree_id).nodes_size(); - } + int64 GetNumNodes(const int32 tree_id); void UpdateGrowingMetadata() const; - int32 GetNumLayersAttempted() { - return tree_ensemble_->growing_metadata().num_layers_attempted(); - } - - bool is_leaf(const int32 tree_id, const int32 node_id) const { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size()); - const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id); - return node.node_case() == boosted_trees::Node::kLeaf; - } - - int32 feature_id(const int32 tree_id, const int32 node_id) const { - const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); - DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); - return node.bucketized_split().feature_id(); - } - - int32 bucket_threshold(const int32 tree_id, const int32 node_id) const { - const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); - DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); - return node.bucketized_split().threshold(); - } - - int32 left_id(const int32 tree_id, const int32 node_id) const { - const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); - DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); - return node.bucketized_split().left_id(); - } - - int32 right_id(const int32 tree_id, const int32 node_id) const { - const auto node = tree_ensemble_->trees(tree_id).nodes(node_id); - DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); - return node.bucketized_split().right_id(); - } + int32 GetNumLayersAttempted(); + + bool is_leaf(const int32 tree_id, const int32 node_id) const; + + int32 feature_id(const int32 tree_id, const int32 node_id) const; + + int32 bucket_threshold(const int32 tree_id, const int32 node_id) const; + + int32 left_id(const int32 tree_id, const int32 node_id) const; + + int32 right_id(const int32 tree_id, const int32 node_id) const; // Add a tree to the ensemble and returns a new tree_id. int32 AddNewTree(const float weight); @@ -163,38 +109,18 @@ class BoostedTreesEnsembleResource : public StampedResource { // Retrieves tree weights and returns as a vector. // It involves a copy, so should be called only sparingly (like once per // iteration, not per example). - std::vector<float> GetTreeWeights() const { - return {tree_ensemble_->tree_weights().begin(), - tree_ensemble_->tree_weights().end()}; - } - - float GetTreeWeight(const int32 tree_id) const { - return tree_ensemble_->tree_weights(tree_id); - } - - float IsTreeFinalized(const int32 tree_id) const { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - return tree_ensemble_->tree_metadata(tree_id).is_finalized(); - } - - float IsTreePostPruned(const int32 tree_id) const { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - return tree_ensemble_->tree_metadata(tree_id) - .post_pruned_nodes_meta_size() > 0; - } - - void SetIsFinalized(const int32 tree_id, const bool is_finalized) { - DCHECK_LT(tree_id, tree_ensemble_->trees_size()); - return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized( - is_finalized); - } + std::vector<float> GetTreeWeights() const; + + float GetTreeWeight(const int32 tree_id) const; + + float IsTreeFinalized(const int32 tree_id) const; + + float IsTreePostPruned(const int32 tree_id) const; + + void SetIsFinalized(const int32 tree_id, const bool is_finalized); // Sets the weight of i'th tree. - void SetTreeWeight(const int32 tree_id, const float weight) { - DCHECK_GE(tree_id, 0); - DCHECK_LT(tree_id, num_trees()); - tree_ensemble_->set_tree_weights(tree_id, weight); - } + void SetTreeWeight(const int32 tree_id, const float weight); // Resets the resource and frees the protos in arena. // Caller needs to hold the mutex lock while calling this. |