aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 11:47:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 11:55:18 -0700
commite28f9da84b51acdbf3234688daa4c55647041219 (patch)
treeae2b76de366292854c8f07a74f6cfa8897ebd9e9 /tensorflow/contrib/boosted_trees
parente787c15ae8e96170135b39338db222e58ee754b4 (diff)
1) Update the proto files for oblivious trees.
2) Grow a new layer of an oblivious tree. PiperOrigin-RevId: 209633300
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc13
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/training_ops.cc197
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py12
-rw-r--r--tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc29
-rw-r--r--tensorflow/contrib/boosted_trees/ops/training_ops.cc2
-rw-r--r--tensorflow/contrib/boosted_trees/proto/split_info.proto4
-rw-r--r--tensorflow/contrib/boosted_trees/proto/tree_config.proto15
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py353
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py6
10 files changed, 583 insertions, 49 deletions
diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD
index 8eac1243ef..f03eab510c 100644
--- a/tensorflow/contrib/boosted_trees/BUILD
+++ b/tensorflow/contrib/boosted_trees/BUILD
@@ -445,6 +445,7 @@ tf_kernel_library(
"//tensorflow/contrib/boosted_trees/proto:learner_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc",
"//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc",
+ "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc",
"//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource",
"//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource",
"//tensorflow/core:framework_headers_lib",
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index d9e7a0f466..64349cfca3 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -383,19 +383,20 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
best_gain -= num_elements * state->tree_complexity_regularization();
ObliviousSplitInfo oblivious_split_info;
- auto* oblivious_dense_split = oblivious_split_info.mutable_split_node()
- ->mutable_dense_float_binary_split();
+ auto* oblivious_dense_split =
+ oblivious_split_info.mutable_split_node()
+ ->mutable_oblivious_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
oblivious_dense_split->set_threshold(
bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
(*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
- auto* left_children = oblivious_split_info.add_children_leaves();
- auto* right_children = oblivious_split_info.add_children_leaves();
+ auto* left_child = oblivious_split_info.add_children();
+ auto* right_child = oblivious_split_info.add_children();
- state->FillLeaf(best_left_node_stats[root_idx], left_children);
- state->FillLeaf(best_right_node_stats[root_idx], right_children);
+ state->FillLeaf(best_left_node_stats[root_idx], left_child);
+ state->FillLeaf(best_right_node_stats[root_idx], right_child);
const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index);
diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
index 6d9a6ee5a0..bb5ae78d9b 100644
--- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
@@ -15,6 +15,7 @@
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
+#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -26,6 +27,7 @@ namespace boosted_trees {
namespace {
+using boosted_trees::learner::LearnerConfig;
using boosted_trees::learner::LearningRateConfig;
using boosted_trees::trees::Leaf;
using boosted_trees::trees::TreeNode;
@@ -42,6 +44,9 @@ struct SplitCandidate {
// Split info.
learner::SplitInfo split_info;
+
+ // Oblivious split info.
+ learner::ObliviousSplitInfo oblivious_split_info;
};
// Checks that the leaf is not empty.
@@ -343,7 +348,12 @@ class GrowTreeEnsembleOp : public OpKernel {
OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
float learning_rate = learning_rate_t->scalar<float>()();
- // Read seed that was used for dropout.
+ // Read the weak learner type to use.
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
const Tensor* seed_t;
OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
// Cast seed to uint64.
@@ -363,9 +373,18 @@ class GrowTreeEnsembleOp : public OpKernel {
// Find best splits for each active partition.
std::map<int32, SplitCandidate> best_splits;
- FindBestSplitsPerPartition(context, partition_ids_list, gains_list,
- splits_list, &best_splits);
-
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ FindBestSplitsPerPartitionNormal(context, partition_ids_list,
+ gains_list, splits_list, &best_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ FindBestSplitsPerPartitionOblivious(context, gains_list, splits_list,
+ &best_splits);
+ break;
+ }
+ }
// No-op if no new splits can be considered.
if (best_splits.empty()) {
LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
@@ -377,25 +396,34 @@ class GrowTreeEnsembleOp : public OpKernel {
OP_REQUIRES_OK(context,
context->input("max_tree_depth", &max_tree_depth_t));
const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
-
// Update and retrieve the growable tree.
// If the tree is fully built and dropout was applied, it also adjusts the
// weights of dropped and the last tree.
boosted_trees::trees::DecisionTreeConfig* const tree_config =
UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
- dropout_seed, max_tree_depth);
-
+ dropout_seed, max_tree_depth,
+ weak_learner_type);
// Split tree nodes.
- for (auto& split_entry : best_splits) {
- SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
- ensemble_resource);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ for (auto& split_entry : best_splits) {
+ SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
+ ensemble_resource);
+ }
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
+ }
}
-
// Post-prune finalized tree if needed.
if (learner_config_.pruning_mode() ==
boosted_trees::learner::LearnerConfig::POST_PRUNE &&
ensemble_resource->LastTreeMetadata()->is_finalized()) {
VLOG(2) << "Post-pruning finalized tree.";
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
+ LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
+ }
PruneTree(tree_config);
// If after post-pruning the whole tree has no gain, remove the tree
@@ -409,10 +437,9 @@ class GrowTreeEnsembleOp : public OpKernel {
private:
// Helper method which effectively does a reduce over all split candidates
// and finds the best split for each partition.
- void FindBestSplitsPerPartition(
- OpKernelContext* const context,
- const OpInputList& partition_ids_list, const OpInputList& gains_list,
- const OpInputList& splits_list,
+ void FindBestSplitsPerPartitionNormal(
+ OpKernelContext* const context, const OpInputList& partition_ids_list,
+ const OpInputList& gains_list, const OpInputList& splits_list,
std::map<int32, SplitCandidate>* best_splits) {
// Find best split per partition going through every feature candidate.
// TODO(salehay): Is this worth parallelizing?
@@ -446,6 +473,90 @@ class GrowTreeEnsembleOp : public OpKernel {
}
}
+ // Reduces over the candidates proposed by all handlers and stores the best
+ // oblivious split under key 0 of `best_splits`; the map is left untouched
+ // when no candidate qualifies. Each handler must supply exactly one gain and
+ // one serialized ObliviousSplitInfo.
+ //
+ // A candidate is skipped when none of its (left, right) child pairs has two
+ // well-formed leaves, or when its gain is negative under PRE_PRUNE. Equal
+ // gains are tie-broken by the smaller node case, then by the higher handler
+ // id. Malformed input is reported through `context` as InvalidArgument and
+ // ends the scan early.
+ void FindBestSplitsPerPartitionOblivious(
+     OpKernelContext* const context, const OpInputList& gains_list,
+     const OpInputList& splits_list,
+     std::map<int32, SplitCandidate>* best_splits) {
+   // Find best split going through every feature candidate.
+   for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
+     const auto& gains = gains_list[handler_id].vec<float>();
+     const auto& splits = splits_list[handler_id].vec<string>();
+     OP_REQUIRES(context, gains.size() == 1,
+                 errors::InvalidArgument(
+                     "Gains size must be one for oblivious weak learner: ",
+                     gains.size(), " != ", 1));
+     OP_REQUIRES(context, splits.size() == 1,
+                 errors::InvalidArgument(
+                     "Splits size must be one for oblivious weak learner: ",
+                     splits.size(), " != ", 1));
+     // Get current split candidate.
+     const auto& gain = gains(0);
+     const auto& serialized_split = splits(0);
+     SplitCandidate split;
+     split.handler_id = handler_id;
+     split.gain = gain;
+     OP_REQUIRES(
+         context, split.oblivious_split_info.ParseFromString(serialized_split),
+         errors::InvalidArgument("Unable to parse oblivious split info."));
+
+     // Alias rather than copy the parsed proto. Every read through the alias
+     // below happens before `split` is moved from.
+     auto& split_info = split.oblivious_split_info;
+     // The serialized split is op input: report a malformed one as a status,
+     // like the checks above, instead of aborting the process with CHECK.
+     OP_REQUIRES(
+         context, split_info.children_size() % 2 == 0,
+         errors::InvalidArgument("The oblivious split should generate an "
+                                 "even number of children: ",
+                                 split_info.children_size()));
+
+     // Children come in (left, right) pairs. The candidate only splits
+     // something if at least one pair has two well-formed leaves.
+     bool only_pure_nodes = true;
+     for (int idx = 0; idx < split_info.children_size(); idx += 2) {
+       if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
+           IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
+         only_pure_nodes = false;
+         break;
+       }
+     }
+     if (only_pure_nodes) {
+       VLOG(1) << "The oblivious split does not actually split anything.";
+       continue;
+     }
+
+     // Don't consider negative splits if we're pre-pruning the tree.
+     if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE &&
+         gain < 0) {
+       continue;
+     }
+
+     // Take the split if we don't have a candidate yet.
+     auto best_split_it = best_splits->find(0);
+     if (best_split_it == best_splits->end()) {
+       best_splits->insert(std::make_pair(0, std::move(split)));
+       continue;
+     }
+
+     // Determine if we should update best split.
+     SplitCandidate& best_split = best_split_it->second;
+     if (TF_PREDICT_FALSE(gain == best_split.gain)) {
+       // Only the node cases are compared, so read them by value instead of
+       // copying both TreeNode messages.
+       const auto current_case = split_info.split_node().node_case();
+       const auto best_case =
+           best_split.oblivious_split_info.split_node().node_case();
+       // Tie break on node case preferring simpler tree node types.
+       VLOG(2) << "Attempting to tie break with smaller node case. "
+               << "(current split: " << current_case
+               << ", best split: " << best_case << ")";
+       if (current_case < best_case) {
+         best_split = std::move(split);
+       } else if (current_case == best_case) {
+         // Tie break on handler Id.
+         VLOG(2) << "Tie breaking with higher handler Id. "
+                 << "(current split: " << handler_id
+                 << ", best split: " << best_split.handler_id << ")";
+         if (handler_id > best_split.handler_id) {
+           best_split = std::move(split);
+         }
+       }
+     } else if (gain > best_split.gain) {
+       best_split = std::move(split);
+     }
+   }
+ }
+
+
void UpdateTreeWeightsIfDropout(
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
@@ -501,7 +612,7 @@ class GrowTreeEnsembleOp : public OpKernel {
boosted_trees::models::DecisionTreeEnsembleResource* const
ensemble_resource,
const float learning_rate, const uint64 dropout_seed,
- const int32 max_tree_depth) {
+ const int32 max_tree_depth, const int32 weak_learner_type) {
const auto num_trees = ensemble_resource->num_trees();
if (num_trees <= 0 ||
ensemble_resource->LastTreeMetadata()->is_finalized()) {
@@ -647,6 +758,60 @@ class GrowTreeEnsembleOp : public OpKernel {
}
}
+ // Grows the oblivious tree in `tree_config` by one layer using `split`.
+ //
+ // The node list is expected to hold one split node per existing layer
+ // followed by the current layer of leaves. The first leaf slot is replaced by
+ // the new layer split, whose metadata records the gain and the leaves being
+ // replaced. The candidate's children, each merged with the old leaf it
+ // descends from, become the new leaf layer. The candidate's children and
+ // split-node metadata are updated in place. `ensemble_resource` is not used.
+ void SplitTreeLayer(
+     SplitCandidate* split,
+     boosted_trees::trees::DecisionTreeConfig* tree_config,
+     boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
+   CHECK(tree_config->nodes_size() > 0)
+       << "A tree must have at least one dummy leaf.";
+   // Every layer grown so far contributes one leading split node, so the
+   // index of the first leaf equals the current depth.
+   int depth = 0;
+   while (depth < tree_config->nodes_size() &&
+          tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
+     depth++;
+   }
+   // The number of new children.
+   const int num_children = 1 << (depth + 1);
+   // Work on the candidate directly instead of deep-copying all its leaves.
+   auto& split_info = split->oblivious_split_info;
+   CHECK(num_children == split_info.children_size())
+       << "Wrong number of new children: " << num_children
+       << " != " << split_info.children_size();
+   // Each pair of children descends from the old leaf at depth + idx / 2.
+   // Make sure all of those leaves exist before indexing into the tree.
+   const int nodes_size = tree_config->nodes_size();
+   CHECK(depth + num_children / 2 <= nodes_size)
+       << "Expected " << num_children / 2 << " leaves after " << depth
+       << " split nodes but the tree has only " << nodes_size << " nodes.";
+   for (int idx = 0; idx < num_children; idx += 2) {
+     // Old leaf is at position depth + idx / 2. Read it through the const
+     // accessor so that looking at the tree can never modify it.
+     const trees::Leaf& old_leaf = tree_config->nodes(depth + idx / 2).leaf();
+     // Update left leaf.
+     *split_info.mutable_children(idx) =
+         *MergeLeafWeights(old_leaf, split_info.mutable_children(idx));
+     // Update right leaf.
+     *split_info.mutable_children(idx + 1) =
+         *MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1));
+   }
+   TreeNode* new_split = split_info.mutable_split_node();
+   auto* split_metadata = new_split->mutable_node_metadata();
+   split_metadata->set_gain(split->gain);
+   // Record the old leaves in the metadata before their slots are reused.
+   for (int idx = depth; idx < nodes_size; idx++) {
+     *split_metadata->add_original_oblivious_leaves() =
+         tree_config->nodes(idx).leaf();
+   }
+   // Add the new split to the tree_config in place before the children start.
+   *tree_config->mutable_nodes(depth) = *new_split;
+   // Add the new children. `nodes_size` is the size before any leaf is
+   // appended, so it separates the slots to overwrite from those to append.
+   for (int idx = 0; idx < num_children; idx++) {
+     if (idx + depth + 1 < nodes_size) {
+       // Update leaves that were already there.
+       *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
+           split_info.children(idx);
+     } else {
+       // Add new leaves.
+       *tree_config->add_nodes()->mutable_leaf() = split_info.children(idx);
+     }
+   }
+ }
void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
// No-op if tree is empty.
if (tree_config->nodes_size() <= 0) {
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 6572f2f414..d9caebb645 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -258,8 +258,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0])
- split_node = oblivious_split_info.split_node.dense_float_binary_split
-
+ split_node = oblivious_split_info.split_node
+ split_node = split_node.oblivious_dense_float_binary_split
self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column)
@@ -279,8 +279,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
expected_bias_gain_0 = 0.46043165467625896
- left_child = oblivious_split_info.children_leaves[0].vector
- right_child = oblivious_split_info.children_leaves[1].vector
+ left_child = oblivious_split_info.children[0].vector
+ right_child = oblivious_split_info.children[1].vector
self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
@@ -296,8 +296,8 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
# (-4 + 0.1) ** 2 / (0.13 + 1)
expected_bias_gain_1 = 13.460176991150442
- left_child = oblivious_split_info.children_leaves[2].vector
- right_child = oblivious_split_info.children_leaves[3].vector
+ left_child = oblivious_split_info.children[2].vector
+ right_child = oblivious_split_info.children[3].vector
self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
index 0e5578693a..3ed6c5c04d 100644
--- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
@@ -12,11 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <algorithm>
+
#include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h"
#include "tensorflow/core/platform/macros.h"
-#include <algorithm>
-
namespace tensorflow {
namespace boosted_trees {
namespace trees {
@@ -28,14 +28,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
if (TF_PREDICT_FALSE(config.nodes_size() <= sub_root_id)) {
return kInvalidLeaf;
}
-
// Traverse tree starting at the provided sub-root.
int32 node_id = sub_root_id;
+ // The index of the leaf that holds this example in the oblivious case.
+ int oblivious_leaf_idx = 0;
while (true) {
const auto& current_node = config.nodes(node_id);
switch (current_node.node_case()) {
case TreeNode::kLeaf: {
- return node_id;
+ return node_id + oblivious_leaf_idx;
}
case TreeNode::kDenseFloatBinarySplit: {
const auto& split = current_node.dense_float_binary_split();
@@ -100,6 +101,16 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
}
break;
}
+ case TreeNode::kObliviousDenseFloatBinarySplit: {
+ const auto& split = current_node.oblivious_dense_float_binary_split();
+ oblivious_leaf_idx <<= 1;
+ if (example.dense_float_features[split.feature_column()] >
+ split.threshold()) {
+ oblivious_leaf_idx++;
+ }
+ node_id++;
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
break;
@@ -165,6 +176,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
split->set_right_id(*++children_it);
break;
}
+ case TreeNode::kObliviousDenseFloatBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "A non-set node cannot have children.";
break;
@@ -199,6 +215,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
const auto& split = node.categorical_id_set_membership_binary_split();
return {split.left_id(), split.right_id()};
}
+ case TreeNode::kObliviousDenseFloatBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousDenseFloatBinarySplit case.";
+ return {};
+ }
case TreeNode::NODE_NOT_SET: {
return {};
}
diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
index 22ac9edb72..604ec8e0bf 100644
--- a/tensorflow/contrib/boosted_trees/ops/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
@@ -57,6 +57,7 @@ REGISTER_OP("GrowTreeEnsemble")
.Input("learning_rate: float")
.Input("dropout_seed: int64")
.Input("max_tree_depth: int32")
+ .Input("weak_learner_type: int32")
.Input("partition_ids: num_handlers * int32")
.Input("gains: num_handlers * float")
.Input("splits: num_handlers * string")
@@ -82,6 +83,7 @@ tree_ensemble_handle: Handle to the ensemble variable.
stamp_token: Stamp token for validating operation consistency.
next_stamp_token: Stamp token to be used for the next iteration.
learning_rate: Scalar learning rate.
+weak_learner_type: The type of weak learner to use.
partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
gains: List of Rank 1 Tensors containing gains per candidate.
splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.
diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto
index 850340f5c2..65448996bf 100644
--- a/tensorflow/contrib/boosted_trees/proto/split_info.proto
+++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto
@@ -19,8 +19,6 @@ message SplitInfo {
}
message ObliviousSplitInfo {
- // The split node with the feature_column and threshold defined.
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
- // The new leaves of the tree.
- repeated tensorflow.boosted_trees.trees.Leaf children_leaves = 2;
+ repeated tensorflow.boosted_trees.trees.Leaf children = 2;
}
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
index 81411aa84a..500909bf2a 100644
--- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -15,6 +15,7 @@ message TreeNode {
CategoricalIdBinarySplit categorical_id_binary_split = 5;
CategoricalIdSetMembershipBinarySplit
categorical_id_set_membership_binary_split = 6;
+ ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
}
TreeNodeMetadata node_metadata = 777;
}
@@ -26,6 +27,9 @@ message TreeNodeMetadata {
// The original leaf node before this node was split.
Leaf original_leaf = 2;
+
+ // The original layer of leaves before that layer was converted to a split.
+ repeated Leaf original_oblivious_leaves = 3;
}
// Leaves can either hold dense or sparse information.
@@ -101,6 +105,17 @@ message CategoricalIdSetMembershipBinarySplit {
int32 right_id = 4;
}
+// Split rule for dense float features in the oblivious case.
+// DecisionTree::Traverse appends a 1 bit to the example's leaf index when
+// feature > threshold and a 0 bit otherwise, then moves on to the next node.
+message ObliviousDenseFloatBinarySplit {
+ // Float feature column and split threshold describing
+ // the rule feature <= threshold.
+ int32 feature_column = 1;
+ float threshold = 2;
+ // We don't store children ids, because either the next node represents the
+ // whole next layer of the tree or starting with the next node we only have
+ // leaves.
+}
+
// DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index e39e1de8d1..572717e216 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -91,6 +91,27 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
return split.SerializeToString()
+def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
+ """Returns a serialized ObliviousSplitInfo for a dense float split.
+
+ The split node is an oblivious_dense_float_binary_split on feature column
+ `fc` with the given `threshold`; one `children` leaf holding a single-value
+ vector is appended for every weight in `leave_weights`, in order.
+ """
+ split_str = """
+ split_node {
+ oblivious_dense_float_binary_split {
+ feature_column: %d
+ threshold: %f
+ }
+ }""" % (fc, threshold)
+ for weight in leave_weights:
+ split_str += """
+ children {
+ vector {
+ value: %f
+ }
+ }""" % (
+ weight)
+ split = split_info_pb2.ObliviousSplitInfo()
+ text_format.Merge(split_str, split)
+ return split.SerializeToString()
+
+
def _gen_categorical_split_info(fc, feat_id, left_weight, right_weight):
split_str = """
split_node {
@@ -324,7 +345,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen.
@@ -383,6 +405,115 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 1)
self.assertProtoEquals(expected_result, tree_ensemble_config)
+ def testGrowEmptyEnsembleObliviousCase(self):
+ """Test growing an empty ensemble in the oblivious case."""
+ with self.test_session() as session:
+ # Create empty ensemble.
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=1,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE)
+
+ # Prepare handler inputs: one gain and one serialized ObliviousSplitInfo
+ # per handler.
+ # Note that handlers 1 & 3 have the same gain but different splits.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([7.62], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143])
+ ]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([0.63], dtype=np.float32)
+ handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([7.62], dtype=np.float32)
+ handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])]
+
+ # Grow tree ensemble.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ # Handlers 1 and 3 tie on gain and on node type, so the tie is broken by
+ # the higher handler id and the split from handler 3 (threshold 7) is
+ # expected to be chosen.
+ # The grown tree should be finalized as max tree depth is 1.
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.143
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ is_finalized: true
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 1)
+ self.assertEqual(stats.num_layers, 1)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 1)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 1)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
+
def testGrowExistingEnsembleTreeNotFinalized(self):
"""Test growing an existing ensemble with the last tree not finalized."""
with self.test_session() as session:
@@ -476,7 +607,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and
@@ -661,7 +793,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect a new tree to be added with the split from handler 1.
@@ -798,7 +931,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the ensemble to be empty.
@@ -869,7 +1003,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the simpler split from handler 1 to be chosen.
@@ -971,7 +1106,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain.
@@ -1053,7 +1189,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the ensemble to be empty as post-pruning will prune
@@ -1120,7 +1257,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split from handler 2 to be chosen despite the negative gain.
@@ -1200,7 +1338,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the negative gain split of partition 1 to be pruned and the
@@ -1371,7 +1510,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect the split for partition 1 to be chosen from handler 1 and
@@ -1470,6 +1610,193 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 2)
self.assertProtoEquals(expected_result, tree_ensemble_config)
+ def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
+ """Test growing a new layer of a not yet finalized oblivious tree."""
+ with self.test_session() as session:
+ # Create existing ensemble with one root split
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.143
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 1
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 1
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=3,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
+
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([1.4], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5])
+ ]
+ handler2_partitions = np.array([0], dtype=np.int32)
+ handler2_gains = np.array([2.7], dtype=np.float32)
+ handler2_split = [
+ _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]),
+ ]
+ handler3_partitions = np.array([0], dtype=np.int32)
+ handler3_gains = np.array([1.7], dtype=np.float32)
+ handler3_split = [
+ _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1])
+ ]
+
+ # Grow tree ensemble layer by layer.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[
+ handler1_partitions, handler2_partitions, handler3_partitions
+ ],
+ gains=[handler1_gains, handler2_gains, handler3_gains],
+ splits=[handler1_split, handler2_split, handler3_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ # Expect the split from handler 2 (gain 2.7, the largest of the three
+ # candidates) to be applied to every node of the new layer.
+ # The grown tree should not be finalized as max tree depth is 3 and
+ # it's only grown 2 layers.
+ # The first two split weights (-0.6, 0.24) get added to original leaf weight
+ # 7.143 and the last two (0.3, 0.4) to original leaf weight -4.375.
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.383
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 0)
+ self.assertEqual(stats.num_layers, 2)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 2)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 2)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
"""Test growing an existing ensemble with the last tree finalized."""
with self.test_session() as session:
@@ -1575,7 +1902,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
# Expect a new tree to be added with the split from handler 1.
@@ -1700,7 +2028,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
dropout_seed=123,
center_bias=True,
- max_tree_depth=learner_config.constraints.max_tree_depth)
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE)
session.run(grow_op)
_, serialized = session.run(
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 2f75d8aa99..97743ba255 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -1076,7 +1076,8 @@ class GradientBoostedDecisionTreeModel(object):
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
center_bias=self._center_bias,
- max_tree_depth=self._max_tree_depth)
+ max_tree_depth=self._max_tree_depth,
+ weak_learner_type=self._learner_config.weak_learner_type)
def _grow_ensemble_not_ready_fn():
# Don't grow the ensemble, just update the stamp.
@@ -1091,7 +1092,8 @@ class GradientBoostedDecisionTreeModel(object):
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
center_bias=self._center_bias,
- max_tree_depth=self._max_tree_depth)
+ max_tree_depth=self._max_tree_depth,
+ weak_learner_type=self._learner_config.weak_learner_type)
def _grow_ensemble_fn():
# Conditionally grow an ensemble depending on whether the splits