aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/proto
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-22 09:48:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-22 11:05:29 -0700
commit2bb0a625cd684f587daadae0252a78da3f14f4f9 (patch)
treefd8247fcf63535542224d369f38b062ad432073a /tensorflow/contrib/boosted_trees/proto
parent039c74293abfc8e9263ca80ab4facdde658dfb03 (diff)
Migrate utils and protos to contrib/boosted_trees.
Change: 150897748
Diffstat (limited to 'tensorflow/contrib/boosted_trees/proto')
-rw-r--r--tensorflow/contrib/boosted_trees/proto/BUILD32
-rw-r--r--tensorflow/contrib/boosted_trees/proto/learner.proto136
-rw-r--r--tensorflow/contrib/boosted_trees/proto/tree_config.proto109
3 files changed, 277 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD
new file mode 100644
index 0000000000..3b6b0339d2
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/BUILD
@@ -0,0 +1,32 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+# Convenience target bundling every file in this package (except OWNERS)
+# so parent packages can reference all sources, e.g. for build/CI globs.
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+# Proto library for the boosted-trees learner configuration
+# (learner.proto: regularization, constraints, learning rate, etc.).
+tf_proto_library(
+    name = "learner_proto",
+    srcs = [
+        "learner.proto",
+    ],
+    cc_api_version = 2,
+    visibility = ["//visibility:public"],
+)
+
+# Proto library for the decision-tree model representation
+# (tree_config.proto: nodes, splits, leaves, ensembles).
+tf_proto_library(
+    name = "tree_config_proto",
+    srcs = ["tree_config.proto"],
+    cc_api_version = 2,
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
new file mode 100644
index 0000000000..06ee223467
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -0,0 +1,136 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+
+package tensorflow.boosted_trees.learner;
+
+// Tree regularization config.
+message TreeRegularizationConfig {
+  // Classic L1/L2.
+  // L1 (lasso) regularization coefficient.
+  float l1 = 1;
+  // L2 (ridge) regularization coefficient.
+  float l2 = 2;
+
+  // Tree complexity penalizes overall model complexity effectively
+  // limiting how deep the tree can grow in regions with small gain.
+  float tree_complexity = 3;
+}
+
+// Tree constraints config.
+// Hard limits applied while growing an individual tree.
+message TreeConstraintsConfig {
+  // Maximum depth of the trees.
+  // NOTE(review): proto3 default 0 presumably means "no explicit
+  // limit" or "use learner default" -- confirm in the learner code.
+  uint32 max_tree_depth = 1;
+
+  // Min hessian weight per node.
+  float min_node_weight = 2;
+}
+
+// LearningRateConfig describes all supported learning rate tuners.
+// At most one tuner may be set (proto3 oneof semantics).
+message LearningRateConfig {
+  oneof tuner {
+    // Constant learning rate.
+    LearningRateFixedConfig fixed = 1;
+    // Dropout-driven learning rate.
+    LearningRateDropoutDrivenConfig dropout = 2;
+    // Learning rate chosen by line search.
+    LearningRateLineSearchConfig line_search = 3;
+  }
+}
+
+// Config for a fixed learning rate.
+message LearningRateFixedConfig {
+  // The constant learning rate.
+  float learning_rate = 1;
+}
+
+// Config for a tuned learning rate.
+// The learner considers num_steps candidate rates in
+// [0, max_learning_rate); presumably the best-scoring candidate is
+// kept -- confirm against the learner implementation.
+message LearningRateLineSearchConfig {
+  // Max learning rate. Must be strictly positive.
+  float max_learning_rate = 1;
+
+  // Number of learning rate values to consider between [0, max_learning_rate).
+  int32 num_steps = 2;
+}
+
+// When we have a sequence of trees 1, 2, 3 ... n, these essentially represent
+// weights updates in functional space, and thus we can use averaging of weight
+// updates to achieve better performance. For example, we can say that our final
+// ensemble will be an average of ensembles of tree 1, and ensemble of tree 1
+// and tree 2, etc., up to the ensemble of all trees.
+// Note that this averaging will apply ONLY DURING PREDICTION. The training
+// stays the same.
+message AveragingConfig {
+  // At most one of the two ways of specifying how many trailing
+  // ensembles to average may be set.
+  oneof config {
+    // Absolute count of trailing trees whose partial ensembles are averaged.
+    float average_last_n_trees = 1;
+    // Between 0 and 1. If set to 1.0, we are averaging ensembles of tree 1,
+    // ensemble of tree 1 and tree 2, etc., up to the ensemble of all trees.
+    // If set to 0.5, the last half of the trees are averaged, etc.
+    float average_last_percent_trees = 2;
+  }
+}
+
+// Dropout-driven learning rate tuner (DART-style): each step, existing
+// trees are randomly dropped before the new tree is fit.
+message LearningRateDropoutDrivenConfig {
+  // Probability of dropping each tree in an existing so far ensemble.
+  float dropout_probability = 1;
+
+  // When trees are built after dropout happen, they don't "advance" to the
+  // optimal solution, they just rearrange the path. However you can still
+  // choose to skip dropout periodically, to allow a new tree that "advances"
+  // to be added.
+  // For example, if running for 200 steps with probability of dropout 1/100,
+  // you would expect the dropout to start happening for sure for all iterations
+  // after 100. However you can add probability_of_skipping_dropout of 0.1, this
+  // way iterations 100-200 will include approx 90 iterations of dropout and 10
+  // iterations of normal steps. Set it to 0 if you want to just keep building
+  // the refinement trees after dropout kicks in.
+  float probability_of_skipping_dropout = 2;
+
+  // Between 0 and 1.
+  float learning_rate = 3;
+}
+
+// Top-level configuration for the boosted-trees learner.
+message LearnerConfig {
+  // When candidate splits are discarded.
+  // NOTE(review): no *_UNSPECIFIED zero value; PRE_PRUNE doubles as the
+  // proto3 default. Values are unprefixed, so they must stay unique
+  // within this message's scope.
+  enum PruningMode {
+    PRE_PRUNE = 0;
+    POST_PRUNE = 1;
+  }
+
+  // How a tree is grown per training step.
+  enum GrowingMode {
+    WHOLE_TREE = 0;
+    // Layer by layer is only supported by the batch learner.
+    LAYER_BY_LAYER = 1;
+  }
+
+  // How multi-class problems are handled.
+  enum MultiClassStrategy {
+    TREE_PER_CLASS = 0;
+    FULL_HESSIAN = 1;
+    DIAGONAL_HESSIAN = 2;
+  }
+
+  // Number of classes.
+  uint32 num_classes = 1;
+
+  // Fraction of features to consider in each tree sampled randomly
+  // from all available features. At most one sampling granularity may
+  // be set (per tree, or per tree level).
+  oneof feature_fraction {
+    float feature_fraction_per_tree = 2;
+    float feature_fraction_per_level = 3;
+  }
+
+  // Regularization.
+  TreeRegularizationConfig regularization = 4;
+
+  // Constraints.
+  TreeConstraintsConfig constraints = 5;
+
+  // Pruning.
+  // NOTE(review): field number 7 is unused in this message.
+  PruningMode pruning_mode = 8;
+
+  // Growing Mode.
+  GrowingMode growing_mode = 9;
+
+  // Learning rate.
+  LearningRateConfig learning_rate_tuner = 6;
+
+  // Multi-class strategy.
+  MultiClassStrategy multi_class_strategy = 10;
+
+  // If you want to average the ensembles (for regularization), provide the
+  // config below.
+  AveragingConfig averaging_config = 11;
+}
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
new file mode 100644
index 0000000000..3daa613b5d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -0,0 +1,109 @@
+syntax = "proto3";
+option cc_enable_arenas = true;
+
+package tensorflow.boosted_trees.trees;
+
+// TreeNode describes a node in a tree.
+message TreeNode {
+  // The payload stored at this node: either a leaf or one of the
+  // supported split rules. Exactly one may be set.
+  oneof node {
+    Leaf leaf = 1;
+    DenseFloatBinarySplit dense_float_binary_split = 2;
+    SparseFloatBinarySplitDefaultLeft sparse_float_binary_split_default_left =
+        3;
+    SparseFloatBinarySplitDefaultRight sparse_float_binary_split_default_right =
+        4;
+    CategoricalIdBinarySplit categorical_id_binary_split = 5;
+  }
+  // Bookkeeping attached to every node (gain, pre-split leaf).
+  // NOTE(review): field number 777 is far above the oneof range,
+  // presumably to leave room for future split types -- confirm.
+  TreeNodeMetadata node_metadata = 777;
+}
+
+// TreeNodeMetadata encodes metadata associated with each node in a tree.
+message TreeNodeMetadata {
+  // The gain associated with this node.
+  float gain = 1;
+
+  // The original leaf node before this node was split.
+  // NOTE(review): presumably kept so post-pruning can revert the
+  // split -- confirm against the pruning implementation.
+  Leaf original_leaf = 2;
+}
+
+// Leaves can either hold dense or sparse information.
+message Leaf {
+  oneof leaf {
+    // See learning/decision_trees/proto/generic_tree_model.proto?l=133
+    // for a description of how vector and sparse_vector might be used.
+    // Dense per-dimension leaf values.
+    Vector vector = 1;
+    // Sparse (index, value) leaf values.
+    SparseVector sparse_vector = 2;
+  }
+}
+
+// Dense vector of leaf values, one entry per output dimension.
+message Vector {
+  repeated float value = 1;
+}
+
+// Sparse vector of leaf values as parallel (index, value) arrays.
+// index[i] pairs with value[i]; both lists must have the same length.
+message SparseVector {
+  repeated int32 index = 1;
+  repeated float value = 2;
+}
+
+// Split rule for dense float features.
+message DenseFloatBinarySplit {
+  // Float feature column and split threshold describing
+  // the rule feature <= threshold.
+  int32 feature_column = 1;
+  float threshold = 2;
+
+  // Node children indexing into a contiguous
+  // vector of nodes starting from the root.
+  // left_id is taken when the rule holds; right_id otherwise.
+  int32 left_id = 3;
+  int32 right_id = 4;
+}
+
+// Split rule for sparse float features defaulting left for missing features.
+message SparseFloatBinarySplitDefaultLeft {
+  // Underlying threshold rule; examples missing the feature go left.
+  DenseFloatBinarySplit split = 1;
+}
+
+// Split rule for sparse float features defaulting right for missing features.
+message SparseFloatBinarySplitDefaultRight {
+  // Underlying threshold rule; examples missing the feature go right.
+  DenseFloatBinarySplit split = 1;
+}
+
+// Split rule for categorical features with a single feature Id.
+message CategoricalIdBinarySplit {
+  // Categorical feature column and Id describing
+  // the rule feature == Id.
+  int32 feature_column = 1;
+  int64 feature_id = 2;
+
+  // Node children indexing into a contiguous
+  // vector of nodes starting from the root.
+  // left_id is taken when the rule holds; right_id otherwise.
+  int32 left_id = 3;
+  int32 right_id = 4;
+}
+
+// DecisionTreeConfig describes a list of connected nodes.
+// Node 0 must be the root and can carry any payload including a leaf
+// in the case of representing the bias.
+// Note that each node id is implicitly its index in the list of nodes.
+message DecisionTreeConfig {
+  // Flattened tree; children reference siblings via left_id/right_id.
+  repeated TreeNode nodes = 1;
+}
+
+// Per-tree bookkeeping kept alongside each tree in an ensemble.
+message DecisionTreeMetadata {
+  // How many times tree weight was updated (due to reweighting of the final
+  // ensemble, dropout, shrinkage etc).
+  int32 num_tree_weight_updates = 1;
+
+  // Number of layers grown for this tree.
+  int32 num_layers_grown = 2;
+
+  // Whether the tree is finalized in that no more layers can be grown.
+  bool is_finalized = 3;
+}
+
+// DecisionTreeEnsembleConfig describes an ensemble of decision trees.
+// The three repeated fields are parallel arrays zipped by index:
+// trees[i] has weight tree_weights[i] and metadata tree_metadata[i].
+message DecisionTreeEnsembleConfig {
+  repeated DecisionTreeConfig trees = 1;
+  repeated float tree_weights = 2;
+  repeated DecisionTreeMetadata tree_metadata = 3;
+}