aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc138
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.h154
2 files changed, 178 insertions, 114 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index 2ea12c522c..c410748c27 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -21,6 +21,35 @@ limitations under the License.
namespace tensorflow {
+// Constructor.
+BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
+ : tree_ensemble_(
+ protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
+ &arena_)) {}
+
+string BoostedTreesEnsembleResource::DebugString() {
+ return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
+ "]");
+}
+
+bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
+ const int64 stamp_token) {
+ CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
+ if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
+ set_stamp(stamp_token);
+ return true;
+ }
+ return false;
+}
+
+string BoostedTreesEnsembleResource::SerializeAsString() const {
+ return tree_ensemble_->SerializeAsString();
+}
+
+int32 BoostedTreesEnsembleResource::num_trees() const {
+ return tree_ensemble_->trees_size();
+}
+
int32 BoostedTreesEnsembleResource::next_node(
const int32 tree_id, const int32 node_id, const int32 index_in_batch,
const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
@@ -49,6 +78,115 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
}
}
+int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
+ const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
+}
+
+void BoostedTreesEnsembleResource::SetNumLayersGrown(
+ const int32 tree_id, int32 new_num_layers) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
+ new_num_layers);
+}
+
+void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
+ const int32 node_range_start, int32 node_range_end) const {
+ tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
+ node_range_start);
+ tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
+ node_range_end);
+}
+
+void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
+ int32* node_range_start, int32* node_range_end) const {
+ *node_range_start =
+ tree_ensemble_->growing_metadata().last_layer_node_start();
+ *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
+}
+
+int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->trees(tree_id).nodes_size();
+}
+
+int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
+ return tree_ensemble_->growing_metadata().num_layers_attempted();
+}
+
+bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
+ const int32 node_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ return node.node_case() == boosted_trees::Node::kLeaf;
+}
+
+int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
+ const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().feature_id();
+}
+
+int32 BoostedTreesEnsembleResource::bucket_threshold(
+ const int32 tree_id, const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().threshold();
+}
+
+int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
+ const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().left_id();
+}
+
+int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
+ const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().right_id();
+}
+
+std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
+ return {tree_ensemble_->tree_weights().begin(),
+ tree_ensemble_->tree_weights().end()};
+}
+
+float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
+ return tree_ensemble_->tree_weights(tree_id);
+}
+
+float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).is_finalized();
+}
+
+float BoostedTreesEnsembleResource::IsTreePostPruned(
+ const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
+ 0;
+}
+
+void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
+ const bool is_finalized) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
+ is_finalized);
+}
+
+// Sets the weight of i'th tree.
+void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
+ const float weight) {
+ DCHECK_GE(tree_id, 0);
+ DCHECK_LT(tree_id, num_trees());
+ tree_ensemble_->set_tree_weights(tree_id, weight);
+}
+
void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
index 561ca3a18a..df78d3f275 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.h
+++ b/tensorflow/core/kernels/boosted_trees/resources.h
@@ -17,12 +17,16 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
+// Forward declaration for proto class TreeEnsemble
+namespace boosted_trees {
+class TreeEnsemble;
+} // namespace boosted_trees
+
// A StampedResource is a resource that has a stamp token associated with it.
// Before reading from or applying updates to the resource, the stamp should
// be checked to verify that the update is not stale.
@@ -42,31 +46,15 @@ class StampedResource : public ResourceBase {
// Keep a tree ensemble in memory for efficient evaluation and mutation.
class BoostedTreesEnsembleResource : public StampedResource {
public:
- // Constructor.
- BoostedTreesEnsembleResource()
- : tree_ensemble_(
- protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
- &arena_)) {}
-
- string DebugString() override {
- return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
- "]");
- }
-
- bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
- CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
- if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
- set_stamp(stamp_token);
- return true;
- }
- return false;
- }
-
- string SerializeAsString() const {
- return tree_ensemble_->SerializeAsString();
- }
-
- int32 num_trees() const { return tree_ensemble_->trees_size(); }
+ BoostedTreesEnsembleResource();
+
+ string DebugString() override;
+
+ bool InitFromSerialized(const string& serialized, const int64 stamp_token);
+
+ string SerializeAsString() const;
+
+ int32 num_trees() const;
// Find the next node to which the example (specified by index_in_batch)
// traverses down from the current node indicated by tree_id and node_id.
@@ -82,73 +70,31 @@ class BoostedTreesEnsembleResource : public StampedResource {
float node_value(const int32 tree_id, const int32 node_id) const;
- int32 GetNumLayersGrown(const int32 tree_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
- }
+ int32 GetNumLayersGrown(const int32 tree_id) const;
- void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
- new_num_layers);
- }
+ void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
void UpdateLastLayerNodesRange(const int32 node_range_start,
- int32 node_range_end) const {
- tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
- node_range_start);
- tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
- node_range_end);
- }
+ int32 node_range_end) const;
void GetLastLayerNodesRange(int32* node_range_start,
- int32* node_range_end) const {
- *node_range_start =
- tree_ensemble_->growing_metadata().last_layer_node_start();
- *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
- }
+ int32* node_range_end) const;
- int64 GetNumNodes(const int32 tree_id) {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->trees(tree_id).nodes_size();
- }
+ int64 GetNumNodes(const int32 tree_id);
void UpdateGrowingMetadata() const;
- int32 GetNumLayersAttempted() {
- return tree_ensemble_->growing_metadata().num_layers_attempted();
- }
-
- bool is_leaf(const int32 tree_id, const int32 node_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
- const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- return node.node_case() == boosted_trees::Node::kLeaf;
- }
-
- int32 feature_id(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().feature_id();
- }
-
- int32 bucket_threshold(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().threshold();
- }
-
- int32 left_id(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().left_id();
- }
-
- int32 right_id(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().right_id();
- }
+ int32 GetNumLayersAttempted();
+
+ bool is_leaf(const int32 tree_id, const int32 node_id) const;
+
+ int32 feature_id(const int32 tree_id, const int32 node_id) const;
+
+ int32 bucket_threshold(const int32 tree_id, const int32 node_id) const;
+
+ int32 left_id(const int32 tree_id, const int32 node_id) const;
+
+ int32 right_id(const int32 tree_id, const int32 node_id) const;
// Add a tree to the ensemble and returns a new tree_id.
int32 AddNewTree(const float weight);
@@ -163,38 +109,18 @@ class BoostedTreesEnsembleResource : public StampedResource {
// Retrieves tree weights and returns as a vector.
// It involves a copy, so should be called only sparingly (like once per
// iteration, not per example).
- std::vector<float> GetTreeWeights() const {
- return {tree_ensemble_->tree_weights().begin(),
- tree_ensemble_->tree_weights().end()};
- }
-
- float GetTreeWeight(const int32 tree_id) const {
- return tree_ensemble_->tree_weights(tree_id);
- }
-
- float IsTreeFinalized(const int32 tree_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->tree_metadata(tree_id).is_finalized();
- }
-
- float IsTreePostPruned(const int32 tree_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->tree_metadata(tree_id)
- .post_pruned_nodes_meta_size() > 0;
- }
-
- void SetIsFinalized(const int32 tree_id, const bool is_finalized) {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
- is_finalized);
- }
+ std::vector<float> GetTreeWeights() const;
+
+ float GetTreeWeight(const int32 tree_id) const;
+
+ float IsTreeFinalized(const int32 tree_id) const;
+
+ float IsTreePostPruned(const int32 tree_id) const;
+
+ void SetIsFinalized(const int32 tree_id, const bool is_finalized);
// Sets the weight of i'th tree.
- void SetTreeWeight(const int32 tree_id, const float weight) {
- DCHECK_GE(tree_id, 0);
- DCHECK_LT(tree_id, num_trees());
- tree_ensemble_->set_tree_weights(tree_id, weight);
- }
+ void SetTreeWeight(const int32 tree_id, const float weight);
// Resets the resource and frees the protos in arena.
// Caller needs to hold the mutex lock while calling this.