aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-04-19 01:26:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 01:29:24 -0700
commitd218339e6a05a984ef7b9a49d66db219d862936e (patch)
tree7781816b6a53b7e8de9154b3c8df65d5c640eb0b
parenta4b0b02ef66586ac98d558099a37662a892f14f1 (diff)
Remove proto import in header files for core/kernels/boosted_trees.
Move implementations that requires declaration of TreeEnsemble to .cc files. The goal is to make kernels mostly independent of proto headers, which will let us lock down our .so import PiperOrigin-RevId: 193478404
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc138
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.h154
2 files changed, 178 insertions, 114 deletions
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index 2ea12c522c..c410748c27 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -21,6 +21,35 @@ limitations under the License.
namespace tensorflow {
+// Constructor.
+BoostedTreesEnsembleResource::BoostedTreesEnsembleResource()
+ : tree_ensemble_(
+ protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
+ &arena_)) {}
+
+string BoostedTreesEnsembleResource::DebugString() {
+ return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
+ "]");
+}
+
+bool BoostedTreesEnsembleResource::InitFromSerialized(const string& serialized,
+ const int64 stamp_token) {
+ CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
+ if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
+ set_stamp(stamp_token);
+ return true;
+ }
+ return false;
+}
+
+string BoostedTreesEnsembleResource::SerializeAsString() const {
+ return tree_ensemble_->SerializeAsString();
+}
+
+int32 BoostedTreesEnsembleResource::num_trees() const {
+ return tree_ensemble_->trees_size();
+}
+
int32 BoostedTreesEnsembleResource::next_node(
const int32 tree_id, const int32 node_id, const int32 index_in_batch,
const std::vector<TTypes<int32>::ConstVec>& bucketized_features) const {
@@ -49,6 +78,115 @@ float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
}
}
+int32 BoostedTreesEnsembleResource::GetNumLayersGrown(
+ const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
+}
+
+void BoostedTreesEnsembleResource::SetNumLayersGrown(
+ const int32 tree_id, int32 new_num_layers) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
+ new_num_layers);
+}
+
+void BoostedTreesEnsembleResource::UpdateLastLayerNodesRange(
+ const int32 node_range_start, int32 node_range_end) const {
+ tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
+ node_range_start);
+ tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
+ node_range_end);
+}
+
+void BoostedTreesEnsembleResource::GetLastLayerNodesRange(
+ int32* node_range_start, int32* node_range_end) const {
+ *node_range_start =
+ tree_ensemble_->growing_metadata().last_layer_node_start();
+ *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
+}
+
+int64 BoostedTreesEnsembleResource::GetNumNodes(const int32 tree_id) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->trees(tree_id).nodes_size();
+}
+
+int32 BoostedTreesEnsembleResource::GetNumLayersAttempted() {
+ return tree_ensemble_->growing_metadata().num_layers_attempted();
+}
+
+bool BoostedTreesEnsembleResource::is_leaf(const int32 tree_id,
+ const int32 node_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
+ const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ return node.node_case() == boosted_trees::Node::kLeaf;
+}
+
+int32 BoostedTreesEnsembleResource::feature_id(const int32 tree_id,
+ const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().feature_id();
+}
+
+int32 BoostedTreesEnsembleResource::bucket_threshold(
+ const int32 tree_id, const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().threshold();
+}
+
+int32 BoostedTreesEnsembleResource::left_id(const int32 tree_id,
+ const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().left_id();
+}
+
+int32 BoostedTreesEnsembleResource::right_id(const int32 tree_id,
+ const int32 node_id) const {
+ const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
+ DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
+ return node.bucketized_split().right_id();
+}
+
+std::vector<float> BoostedTreesEnsembleResource::GetTreeWeights() const {
+ return {tree_ensemble_->tree_weights().begin(),
+ tree_ensemble_->tree_weights().end()};
+}
+
+float BoostedTreesEnsembleResource::GetTreeWeight(const int32 tree_id) const {
+ return tree_ensemble_->tree_weights(tree_id);
+}
+
+float BoostedTreesEnsembleResource::IsTreeFinalized(const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).is_finalized();
+}
+
+float BoostedTreesEnsembleResource::IsTreePostPruned(
+ const int32 tree_id) const {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->tree_metadata(tree_id).post_pruned_nodes_meta_size() >
+ 0;
+}
+
+void BoostedTreesEnsembleResource::SetIsFinalized(const int32 tree_id,
+ const bool is_finalized) {
+ DCHECK_LT(tree_id, tree_ensemble_->trees_size());
+ return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
+ is_finalized);
+}
+
+// Sets the weight of i'th tree.
+void BoostedTreesEnsembleResource::SetTreeWeight(const int32 tree_id,
+ const float weight) {
+ DCHECK_GE(tree_id, 0);
+ DCHECK_LT(tree_id, num_trees());
+ tree_ensemble_->set_tree_weights(tree_id, weight);
+}
+
void BoostedTreesEnsembleResource::UpdateGrowingMetadata() const {
tree_ensemble_->mutable_growing_metadata()->set_num_layers_attempted(
tree_ensemble_->growing_metadata().num_layers_attempted() + 1);
diff --git a/tensorflow/core/kernels/boosted_trees/resources.h b/tensorflow/core/kernels/boosted_trees/resources.h
index 561ca3a18a..df78d3f275 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.h
+++ b/tensorflow/core/kernels/boosted_trees/resources.h
@@ -17,12 +17,16 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_RESOURCES_H_
#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/kernels/boosted_trees/boosted_trees.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
+// Forward declaration for proto class TreeEnsemble
+namespace boosted_trees {
+class TreeEnsemble;
+} // namespace boosted_trees
+
// A StampedResource is a resource that has a stamp token associated with it.
// Before reading from or applying updates to the resource, the stamp should
// be checked to verify that the update is not stale.
@@ -42,31 +46,15 @@ class StampedResource : public ResourceBase {
// Keep a tree ensemble in memory for efficient evaluation and mutation.
class BoostedTreesEnsembleResource : public StampedResource {
public:
- // Constructor.
- BoostedTreesEnsembleResource()
- : tree_ensemble_(
- protobuf::Arena::CreateMessage<boosted_trees::TreeEnsemble>(
- &arena_)) {}
-
- string DebugString() override {
- return strings::StrCat("TreeEnsemble[size=", tree_ensemble_->trees_size(),
- "]");
- }
-
- bool InitFromSerialized(const string& serialized, const int64 stamp_token) {
- CHECK_EQ(stamp(), -1) << "Must Reset before Init.";
- if (ParseProtoUnlimited(tree_ensemble_, serialized)) {
- set_stamp(stamp_token);
- return true;
- }
- return false;
- }
-
- string SerializeAsString() const {
- return tree_ensemble_->SerializeAsString();
- }
-
- int32 num_trees() const { return tree_ensemble_->trees_size(); }
+ BoostedTreesEnsembleResource();
+
+ string DebugString() override;
+
+ bool InitFromSerialized(const string& serialized, const int64 stamp_token);
+
+ string SerializeAsString() const;
+
+ int32 num_trees() const;
// Find the next node to which the example (specified by index_in_batch)
// traverses down from the current node indicated by tree_id and node_id.
@@ -82,73 +70,31 @@ class BoostedTreesEnsembleResource : public StampedResource {
float node_value(const int32 tree_id, const int32 node_id) const;
- int32 GetNumLayersGrown(const int32 tree_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->tree_metadata(tree_id).num_layers_grown();
- }
+ int32 GetNumLayersGrown(const int32 tree_id) const;
- void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- tree_ensemble_->mutable_tree_metadata(tree_id)->set_num_layers_grown(
- new_num_layers);
- }
+ void SetNumLayersGrown(const int32 tree_id, int32 new_num_layers) const;
void UpdateLastLayerNodesRange(const int32 node_range_start,
- int32 node_range_end) const {
- tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_start(
- node_range_start);
- tree_ensemble_->mutable_growing_metadata()->set_last_layer_node_end(
- node_range_end);
- }
+ int32 node_range_end) const;
void GetLastLayerNodesRange(int32* node_range_start,
- int32* node_range_end) const {
- *node_range_start =
- tree_ensemble_->growing_metadata().last_layer_node_start();
- *node_range_end = tree_ensemble_->growing_metadata().last_layer_node_end();
- }
+ int32* node_range_end) const;
- int64 GetNumNodes(const int32 tree_id) {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->trees(tree_id).nodes_size();
- }
+ int64 GetNumNodes(const int32 tree_id);
void UpdateGrowingMetadata() const;
- int32 GetNumLayersAttempted() {
- return tree_ensemble_->growing_metadata().num_layers_attempted();
- }
-
- bool is_leaf(const int32 tree_id, const int32 node_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
- const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- return node.node_case() == boosted_trees::Node::kLeaf;
- }
-
- int32 feature_id(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().feature_id();
- }
-
- int32 bucket_threshold(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().threshold();
- }
-
- int32 left_id(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().left_id();
- }
-
- int32 right_id(const int32 tree_id, const int32 node_id) const {
- const auto node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- return node.bucketized_split().right_id();
- }
+ int32 GetNumLayersAttempted();
+
+ bool is_leaf(const int32 tree_id, const int32 node_id) const;
+
+ int32 feature_id(const int32 tree_id, const int32 node_id) const;
+
+ int32 bucket_threshold(const int32 tree_id, const int32 node_id) const;
+
+ int32 left_id(const int32 tree_id, const int32 node_id) const;
+
+ int32 right_id(const int32 tree_id, const int32 node_id) const;
// Add a tree to the ensemble and returns a new tree_id.
int32 AddNewTree(const float weight);
@@ -163,38 +109,18 @@ class BoostedTreesEnsembleResource : public StampedResource {
// Retrieves tree weights and returns as a vector.
// It involves a copy, so should be called only sparingly (like once per
// iteration, not per example).
- std::vector<float> GetTreeWeights() const {
- return {tree_ensemble_->tree_weights().begin(),
- tree_ensemble_->tree_weights().end()};
- }
-
- float GetTreeWeight(const int32 tree_id) const {
- return tree_ensemble_->tree_weights(tree_id);
- }
-
- float IsTreeFinalized(const int32 tree_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->tree_metadata(tree_id).is_finalized();
- }
-
- float IsTreePostPruned(const int32 tree_id) const {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->tree_metadata(tree_id)
- .post_pruned_nodes_meta_size() > 0;
- }
-
- void SetIsFinalized(const int32 tree_id, const bool is_finalized) {
- DCHECK_LT(tree_id, tree_ensemble_->trees_size());
- return tree_ensemble_->mutable_tree_metadata(tree_id)->set_is_finalized(
- is_finalized);
- }
+ std::vector<float> GetTreeWeights() const;
+
+ float GetTreeWeight(const int32 tree_id) const;
+
+ float IsTreeFinalized(const int32 tree_id) const;
+
+ float IsTreePostPruned(const int32 tree_id) const;
+
+ void SetIsFinalized(const int32 tree_id, const bool is_finalized);
// Sets the weight of i'th tree.
- void SetTreeWeight(const int32 tree_id, const float weight) {
- DCHECK_GE(tree_id, 0);
- DCHECK_LT(tree_id, num_trees());
- tree_ensemble_->set_tree_weights(tree_id, weight);
- }
+ void SetTreeWeight(const int32 tree_id, const float weight);
// Resets the resource and frees the protos in arena.
// Caller needs to hold the mutex lock while calling this.