diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-22 09:48:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-22 11:05:29 -0700 |
commit | 2bb0a625cd684f587daadae0252a78da3f14f4f9 (patch) | |
tree | fd8247fcf63535542224d369f38b062ad432073a /tensorflow/contrib/boosted_trees/proto | |
parent | 039c74293abfc8e9263ca80ab4facdde658dfb03 (diff) |
Migrate utils and protos to contrib/boosted_trees.
Change: 150897748
Diffstat (limited to 'tensorflow/contrib/boosted_trees/proto')
-rw-r--r-- | tensorflow/contrib/boosted_trees/proto/BUILD | 32 | ||||
-rw-r--r-- | tensorflow/contrib/boosted_trees/proto/learner.proto | 136 | ||||
-rw-r--r-- | tensorflow/contrib/boosted_trees/proto/tree_config.proto | 109 |
3 files changed, 277 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/proto/BUILD b/tensorflow/contrib/boosted_trees/proto/BUILD
new file mode 100644
index 0000000000..3b6b0339d2
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/BUILD
@@ -0,0 +1,32 @@
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_proto_library(
+    name = "learner_proto",
+    srcs = [
+        "learner.proto",
+    ],
+    cc_api_version = 2,
+    visibility = ["//visibility:public"],
+)
+
+tf_proto_library(
+    name = "tree_config_proto",
+    srcs = ["tree_config.proto"],
+    cc_api_version = 2,
+    visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto
new file mode 100644
index 0000000000..06ee223467
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/learner.proto
@@ -0,0 +1,136 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+
+package tensorflow.boosted_trees.learner;
+
+// Tree regularization config.
+message TreeRegularizationConfig {
+  // Classic L1/L2 regularization penalties.
+  float l1 = 1;
+  float l2 = 2;
+
+  // Tree complexity penalizes overall model complexity, effectively
+  // limiting how deep the tree can grow in regions with small gain.
+  float tree_complexity = 3;
+}
+
+// Tree constraints config.
+message TreeConstraintsConfig {
+  // Maximum depth of the trees.
+  uint32 max_tree_depth = 1;
+
+  // Minimum hessian weight required per node.
+  float min_node_weight = 2;
+}
+
+// LearningRateConfig describes all supported learning rate tuners.
+message LearningRateConfig {
+  oneof tuner {
+    LearningRateFixedConfig fixed = 1;
+    LearningRateDropoutDrivenConfig dropout = 2;
+    LearningRateLineSearchConfig line_search = 3;
+  }
+}
+
+// Config for a fixed learning rate.
+message LearningRateFixedConfig {
+  float learning_rate = 1;
+}
+
+// Config for a tuned learning rate.
+message LearningRateLineSearchConfig {
+  // Max learning rate. Must be strictly positive.
+  float max_learning_rate = 1;
+
+  // Number of learning rate values to consider between [0, max_learning_rate).
+  int32 num_steps = 2;
+}
+
+// When we have a sequence of trees 1, 2, 3 ... n, these essentially represent
+// weight updates in functional space, and thus we can use averaging of weight
+// updates to achieve better performance. For example, we can say that our final
+// ensemble will be an average of the ensemble of tree 1, the ensemble of trees
+// 1 and 2, etc., up to the ensemble of all trees.
+// Note that this averaging will apply ONLY DURING PREDICTION. The training
+// stays the same.
+message AveragingConfig {
+  oneof config {
+    float average_last_n_trees = 1;
+    // Between 0 and 1. If set to 1.0, we are averaging the ensemble of tree 1,
+    // the ensemble of trees 1 and 2, etc., up to the ensemble of all trees. If
+    // set to 0.5, only the last half of the trees are averaged, etc.
+    float average_last_percent_trees = 2;
+  }
+}
+
+message LearningRateDropoutDrivenConfig {
+  // Probability of dropping each tree in the ensemble built so far.
+  float dropout_probability = 1;
+
+  // When trees are built after dropout happens, they don't "advance" to the
+  // optimal solution, they just rearrange the path. However you can still
+  // choose to skip dropout periodically, to allow a new tree that "advances"
+  // to be added.
+  // For example, if running for 200 steps with probability of dropout 1/100,
+  // you would expect the dropout to start happening for sure for all iterations
+  // after 100. However you can add probability_of_skipping_dropout of 0.1; this
+  // way iterations 100-200 will include approx 90 iterations of dropout and 10
+  // iterations of normal steps. Set it to 0 if you want to just keep building
+  // the refinement trees after dropout kicks in.
+  float probability_of_skipping_dropout = 2;
+
+  // Between 0 and 1.
+  float learning_rate = 3;
+}
+
+message LearnerConfig {
+  enum PruningMode {
+    PRE_PRUNE = 0;
+    POST_PRUNE = 1;
+  }
+
+  enum GrowingMode {
+    WHOLE_TREE = 0;
+    // Layer by layer is only supported by the batch learner.
+    LAYER_BY_LAYER = 1;
+  }
+
+  enum MultiClassStrategy {
+    TREE_PER_CLASS = 0;
+    FULL_HESSIAN = 1;
+    DIAGONAL_HESSIAN = 2;
+  }
+
+  // Number of classes.
+  uint32 num_classes = 1;
+
+  // Fraction of features to consider in each tree, sampled randomly
+  // from all available features.
+  oneof feature_fraction {
+    float feature_fraction_per_tree = 2;
+    float feature_fraction_per_level = 3;
+  };
+
+  // Regularization.
+  TreeRegularizationConfig regularization = 4;
+
+  // Constraints.
+  TreeConstraintsConfig constraints = 5;
+
+  // Pruning.
+  PruningMode pruning_mode = 8;
+
+  // Growing Mode.
+  GrowingMode growing_mode = 9;
+
+  // Learning rate.
+  LearningRateConfig learning_rate_tuner = 6;
+
+  // Multi-class strategy.
+  MultiClassStrategy multi_class_strategy = 10;
+
+  // If you want to average the ensembles (for regularization), provide the
+  // config below.
+  AveragingConfig averaging_config = 11;
+}
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
new file mode 100644
index 0000000000..3daa613b5d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -0,0 +1,109 @@
+syntax = "proto3";
+option cc_enable_arenas = true;
+
+package tensorflow.boosted_trees.trees;
+
+// TreeNode describes a node in a tree.
+message TreeNode {
+  oneof node {
+    Leaf leaf = 1;
+    DenseFloatBinarySplit dense_float_binary_split = 2;
+    SparseFloatBinarySplitDefaultLeft sparse_float_binary_split_default_left =
+        3;
+    SparseFloatBinarySplitDefaultRight sparse_float_binary_split_default_right =
+        4;
+    CategoricalIdBinarySplit categorical_id_binary_split = 5;
+  }
+  TreeNodeMetadata node_metadata = 777;
+}
+
+// TreeNodeMetadata encodes metadata associated with each node in a tree.
+message TreeNodeMetadata {
+  // The gain associated with this node.
+  float gain = 1;
+
+  // The original leaf node before this node was split.
+  Leaf original_leaf = 2;
+}
+
+// Leaves can either hold dense or sparse information.
+message Leaf {
+  oneof leaf {
+    // See learning/decision_trees/proto/generic_tree_model.proto?l=133
+    // for a description of how vector and sparse_vector might be used.
+    Vector vector = 1;
+    SparseVector sparse_vector = 2;
+  }
+}
+
+message Vector {
+  repeated float value = 1;
+}
+
+message SparseVector {
+  repeated int32 index = 1;
+  repeated float value = 2;
+}
+
+// Split rule for dense float features.
+message DenseFloatBinarySplit {
+  // Float feature column and split threshold describing
+  // the rule feature <= threshold.
+  int32 feature_column = 1;
+  float threshold = 2;
+
+  // Node children indexing into a contiguous
+  // vector of nodes starting from the root.
+  int32 left_id = 3;
+  int32 right_id = 4;
+}
+
+// Split rule for sparse float features defaulting left for missing features.
+message SparseFloatBinarySplitDefaultLeft {
+  DenseFloatBinarySplit split = 1;
+}
+
+// Split rule for sparse float features defaulting right for missing features.
+message SparseFloatBinarySplitDefaultRight {
+  DenseFloatBinarySplit split = 1;
+}
+
+// Split rule for categorical features with a single feature Id.
+message CategoricalIdBinarySplit {
+  // Categorical feature column and Id describing
+  // the rule feature == Id.
+  int32 feature_column = 1;
+  int64 feature_id = 2;
+
+  // Node children indexing into a contiguous
+  // vector of nodes starting from the root.
+  int32 left_id = 3;
+  int32 right_id = 4;
+}
+
+// DecisionTreeConfig describes a list of connected nodes.
+// Node 0 must be the root and can carry any payload including a leaf
+// in the case of representing the bias.
+// Note that each node id is implicitly its index in the list of nodes.
+message DecisionTreeConfig {
+  repeated TreeNode nodes = 1;
+}
+
+message DecisionTreeMetadata {
+  // How many times the tree weight was updated (due to reweighting of the
+  // final ensemble, dropout, shrinkage etc).
+  int32 num_tree_weight_updates = 1;
+
+  // Number of layers grown for this tree.
+  int32 num_layers_grown = 2;
+
+  // Whether the tree is finalized in that no more layers can be grown.
+  bool is_finalized = 3;
+}
+
+// DecisionTreeEnsembleConfig describes an ensemble of decision trees.
+message DecisionTreeEnsembleConfig {
+  repeated DecisionTreeConfig trees = 1;
+  repeated float tree_weights = 2;
+  repeated DecisionTreeMetadata tree_metadata = 3;
+}