// path: root/tensorflow/contrib/boosted_trees/proto/learner.proto
// blob: d84ba7438e7f03685d5bafca52ff8283f0fce898
syntax = "proto3";

package tensorflow.boosted_trees.learner;

// Generate arena-enabled C++ message classes.
option cc_enable_arenas = true;

// Tree regularization config.
message TreeRegularizationConfig {
  // Classic L1/L2 regularization terms.
  // L1 regularization strength.
  float l1 = 1;
  // L2 regularization strength.
  float l2 = 2;

  // Tree complexity penalizes overall model complexity, effectively
  // limiting how deep the tree can grow in regions with small gain.
  float tree_complexity = 3;
}

// Tree constraints config.
message TreeConstraintsConfig {
  // Maximum depth of the trees. The default value is 6 if not specified.
  // This field has implicit (proto3) presence, so "not specified" cannot be
  // told apart from an explicit 0.
  uint32 max_tree_depth = 1;

  // Minimum hessian weight per node.
  float min_node_weight = 2;

  // Maximum number of unique feature columns used in a tree. Zero means
  // there is no limit.
  int64 max_number_of_unique_feature_columns = 3;
}

// LearningRateConfig describes all supported learning rate tuners.
// Exactly one tuner can be set at a time; setting one clears the others.
message LearningRateConfig {
  oneof tuner {
    // Constant learning rate.
    LearningRateFixedConfig fixed = 1;
    // Learning rate driven by dropout of existing trees.
    LearningRateDropoutDrivenConfig dropout = 2;
    // Learning rate chosen by line search.
    LearningRateLineSearchConfig line_search = 3;
  }
}

// Config for a fixed (constant) learning rate.
message LearningRateFixedConfig {
  // The constant learning rate to use.
  float learning_rate = 1;
}

// Config for a learning rate tuned by line search.
message LearningRateLineSearchConfig {
  // Max learning rate. Must be strictly positive.
  float max_learning_rate = 1;

  // Number of learning rate values to consider between [0, max_learning_rate).
  int32 num_steps = 2;
}

// When we have a sequence of trees 1, 2, 3 ... n, these essentially represent
// weight updates in functional space, and thus we can average the weight
// updates to achieve better performance. For example, the final ensemble can
// be the average of the ensemble of tree 1, the ensemble of trees 1 and 2,
// etc., up to the ensemble of all trees.
// Note that this averaging applies ONLY DURING PREDICTION. Training stays
// the same.
message AveragingConfig {
  // At most one averaging strategy can be set.
  oneof config {
    // Average over the last N trees (the count-based counterpart of
    // average_last_percent_trees).
    // NOTE(review): this is a tree count but is declared as float. Changing
    // the type to an integer would change the wire type (fixed32 -> varint)
    // and break existing serialized configs, so it is left as-is.
    float average_last_n_trees = 1;
    // Between 0 and 1. If set to 1.0, we average the ensemble of tree 1, the
    // ensemble of trees 1 and 2, etc., up to the ensemble of all trees. If set
    // to 0.5, the last half of the trees are averaged, etc.
    float average_last_percent_trees = 2;
  }
}

// Config for the dropout-driven learning rate tuner.
message LearningRateDropoutDrivenConfig {
  // Probability of dropping each tree of the ensemble built so far.
  float dropout_probability = 1;

  // Trees built after dropout happens don't "advance" toward the optimal
  // solution; they just rearrange the path. However, you can still choose to
  // skip dropout periodically, allowing a new tree that "advances" to be
  // added.
  // For example, if running for 200 steps with a dropout probability of
  // 1/100, you would expect dropout to happen for sure on all iterations
  // after 100. However, if you set probability_of_skipping_dropout to 0.1,
  // iterations 100-200 will include approximately 90 iterations of dropout
  // and 10 iterations of normal steps. Set it to 0 if you just want to keep
  // building refinement trees after dropout kicks in.
  float probability_of_skipping_dropout = 2;

  // Learning rate. Between 0 and 1.
  float learning_rate = 3;
}

// Top-level configuration for the boosted trees learner.
message LearnerConfig {
  // Pruning strategy for trees.
  enum PruningMode {
    // Not set; the pruning_mode field documents POST_PRUNE as the default.
    PRUNING_MODE_UNSPECIFIED = 0;
    PRE_PRUNE = 1;
    POST_PRUNE = 2;
  }

  // Strategy for growing each tree.
  enum GrowingMode {
    // Not set; the growing_mode field documents LAYER_BY_LAYER as the default.
    GROWING_MODE_UNSPECIFIED = 0;
    WHOLE_TREE = 1;
    LAYER_BY_LAYER = 2;
  }

  // Strategy for multi-class problems.
  enum MultiClassStrategy {
    // Not set; see the multi_class_strategy field for the defaults applied.
    MULTI_CLASS_STRATEGY_UNSPECIFIED = 0;
    TREE_PER_CLASS = 1;
    FULL_HESSIAN = 2;
    DIAGONAL_HESSIAN = 3;
  }

  // Number of classes.
  uint32 num_classes = 1;

  // Fraction of features to consider, sampled randomly from all available
  // features, either per tree or per tree level. At most one of the two can
  // be set.
  oneof feature_fraction {
    float feature_fraction_per_tree = 2;
    float feature_fraction_per_level = 3;
  }

  // Regularization.
  TreeRegularizationConfig regularization = 4;

  // Constraints.
  TreeConstraintsConfig constraints = 5;

  // Pruning. POST_PRUNE is the default pruning mode.
  PruningMode pruning_mode = 8;

  // Growing mode. LAYER_BY_LAYER is the default growing mode.
  GrowingMode growing_mode = 9;

  // Learning rate. By default we use a fixed learning rate of 0.1.
  LearningRateConfig learning_rate_tuner = 6;

  // Multi-class strategy. By default we use TREE_PER_CLASS for binary
  // classification and linear regression. For other cases, we use
  // DIAGONAL_HESSIAN as the default.
  MultiClassStrategy multi_class_strategy = 10;

  // If you want to average the ensembles (for regularization), provide the
  // config below.
  AveragingConfig averaging_config = 11;

  // NOTE(review): field number 7 is not used in this message. If it belonged
  // to a removed field, add `reserved 7;` so it cannot be reused with a
  // different meaning.
}