// Extracted from:
//   tensorflow/contrib/decision_trees/proto/generic_tree_model.proto
//   (blob dd80b37f52e76421a16550fb2258c059c140215d)
// Generic representation of tree-based models.

// This proto establishes a shared standard: "fully compatible" projects should
// provide support for all reasonable models expressed through it. Therefore,
// it should be kept as simple as possible, and should never contain
// project-specific design choices.

// Status: work in progress. This proto can change anytime without notice.

syntax = "proto3";

package tensorflow.decision_trees;

import "google/protobuf/any.proto";
import "google/protobuf/wrappers.proto";

// Enables arena allocation support in the generated C++ code.
option cc_enable_arenas = true;

// Top-level container that can hold any supported kind of model.
message Model {
  // The concrete model; at most one of the alternatives below is set.
  oneof model {
    // A single decision tree.
    DecisionTree decision_tree = 1;
    // A sequence of sub-models whose outputs are combined.
    Ensemble ensemble = 2;
    // A model kind that has no dedicated message in this file.
    google.protobuf.Any custom_model = 3;
  }

  // Extra payloads attached to the model. Their interpretation is not defined
  // in this file.
  repeated google.protobuf.Any additional_data = 4;
}

// Pairs a model with per-feature descriptions.
message ModelAndFeatures {
  // Description of a single feature.
  message Feature {
    // TODO(jonasz): Remove this field, as it's confusing. Ctx: cr/153569450.
    FeatureId feature_id = 1 [deprecated = true];

    // Extra payloads attached to the feature. Their interpretation is not
    // defined in this file.
    repeated google.protobuf.Any additional_data = 2;
  }

  // Feature descriptions keyed by id: for a FeatureId feature_id, the
  // matching description is features[feature_id.id.value].
  map<string, Feature> features = 1;

  // The model itself.
  Model model = 2;

  // Extra payloads attached to this message. Their interpretation is not
  // defined in this file.
  repeated google.protobuf.Any additional_data = 3;
}

// An ordered sequence of models. Bagged models, boosted models and custom
// ensembles can all be expressed with this message.
message Ensemble {
  // A single entry of the ensemble.
  message Member {
    // The model held by this entry.
    Model submodel = 1;
    // Wrapped integer id of the entry; the wrapper type lets it be left unset.
    google.protobuf.Int32Value submodel_id = 2;
    // Extra payloads attached to the entry. Their interpretation is not
    // defined in this file.
    repeated google.protobuf.Any additional_data = 3;
  }

  // The entries of the ensemble.
  // A higher id for more readable printing.
  repeated Member members = 100;

  // Whichever combination technique is present tells the reader how member
  // outputs are combined into the output of the ensemble.
  oneof combination_technique {
    Summation summation_combination_technique = 1;
    Averaging averaging_combination_technique = 2;
    google.protobuf.Any custom_combination_technique = 3;
  }

  // Extra payloads attached to the ensemble. Their interpretation is not
  // defined in this file.
  repeated google.protobuf.Any additional_data = 4;
}

// Combination technique marker: when it is the one present in an Ensemble,
// the ensemble's output is the sum of the outputs of its member models.
message Summation {
  // Extra payloads attached to the technique. Their interpretation is not
  // defined in this file.
  repeated google.protobuf.Any additional_data = 1;
}


// Combination technique marker: when it is the one present in an Ensemble,
// the ensemble's output is the average of the outputs of its member models.
message Averaging {
  // Extra payloads attached to the technique. Their interpretation is not
  // defined in this file.
  repeated google.protobuf.Any additional_data = 1;
}


// A decision tree, stored as a flat list of nodes.
message DecisionTree {
  // The nodes of the tree.
  // NOTE(review): this file does not state how BinaryNode child ids map onto
  // the entries of this list (list index vs. TreeNode.node_id) - confirm.
  repeated TreeNode nodes = 1;

  // Extra payloads attached to the tree. Their interpretation is not defined
  // in this file.
  repeated google.protobuf.Any additional_data = 2;
}


// One node of a DecisionTree.
message TreeNode {
  // The three fields below exist only for convenience and readability;
  // producers are not required to fill them in.
  google.protobuf.Int32Value node_id = 1;
  google.protobuf.Int32Value depth = 2;
  google.protobuf.Int32Value subtree_size = 3;

  // The kind of node; at most one of the alternatives below is set.
  oneof node_type {
    // An internal node with a left and a right child.
    BinaryNode binary_node = 4;
    // A terminal node holding output values.
    Leaf leaf = 5;
    // A node kind that has no dedicated message in this file.
    google.protobuf.Any custom_node_type = 6;
  }

  // Extra payloads attached to the node. Their interpretation is not defined
  // in this file.
  repeated google.protobuf.Any additional_data = 7;
}


// An internal node that routes a datapoint to its left or right child.
message BinaryNode {
  // Ids of the two children.
  // NOTE(review): presumably these match TreeNode.node_id or an index into
  // DecisionTree.nodes; this file does not say which - confirm.
  google.protobuf.Int32Value left_child_id = 1;
  google.protobuf.Int32Value right_child_id = 2;

  // The two directions a datapoint can be sent in.
  enum Direction {
    LEFT = 0;
    RIGHT = 1;
  }

  // Direction taken by a datapoint for which left_child_test is undefined
  // (e.g. because the test is not defined when the feature value is missing).
  // LEFT is the zero value, so in proto3 an unset field reads as LEFT.
  Direction default_direction = 3;

  // The routing test: a datapoint that satisfies it is propagated to the left
  // child.
  oneof left_child_test {
    InequalityTest inequality_left_child_test = 4;
    google.protobuf.Any custom_left_child_test = 5;
  }
}

// A vector in which only certain select elements are non-zero. Only those
// elements are stored, as a map from label to value (e.g. class id to
// probability or count).
message SparseVector {
  // Stored elements, keyed by label. Map iteration order is unspecified.
  map<int64, Value> sparse_value = 1;
}

// A dense, ordered list of values.
message Vector {
  // The elements of the vector, in order.
  repeated Value value = 1;
}

// A terminal node of a decision tree and the values it holds.
message Leaf {
  // How the values stored in a leaf are interpreted is application specific.
  // Some common cases:
  // 1) len(vector) = 1, and the floating point value[0] holds the class 0
  //    probability in a two class classification problem.
  // 2) len(vector) = 1, and the integer value[0] holds the class prediction.
  // 3) The floating point value[i] holds the class i probability prediction.
  // 4) The floating point value[i] holds the i-th component of the
  //    vector prediction in a regression problem.
  // 5) sparse_vector holds the sparse class predictions for a classification
  //    problem with a large number of classes.
  oneof leaf {
    Vector vector = 1;
    SparseVector sparse_vector = 2;
  }

  // For non-standard handling of leaves.
  repeated google.protobuf.Any additional_data = 3;
}


// Identifies a feature by a string id.
message FeatureId {
  // The id, wrapped so that it can be left unset.
  google.protobuf.StringValue id = 1;

  // Extra payloads attached to the id. Their interpretation is not defined in
  // this file.
  repeated google.protobuf.Any additional_data = 2;
}

// A weighted combination of features. The total value is
// sum(features[i] * weights[i]), so the two lists are paired by index.
message ObliqueFeatures {
  // The features entering the sum.
  repeated FeatureId features = 1;
  // weights[i] is the weight applied to features[i].
  // NOTE(review): presumably both lists must have the same length - confirm.
  repeated float weights = 2;
}


// A test that relates a feature-derived quantity to a threshold.
message InequalityTest {
  // The quantity being tested: a single feature or a weighted sum of
  // features. When the feature is missing, the test's outcome is undefined.
  // The oneof name is kept as-is because renaming it would change generated
  // code.
  oneof FeatureSum {
    FeatureId feature_id = 1;
    ObliqueFeatures oblique = 4;
  }

  // The available comparison operators.
  enum Type {
    LESS_OR_EQUAL = 0;
    LESS_THAN = 1;
    GREATER_OR_EQUAL = 2;
    GREATER_THAN = 3;
  }

  // The comparison to apply. LESS_OR_EQUAL is the zero value, so in proto3 an
  // unset field reads as LESS_OR_EQUAL.
  Type type = 2;

  // The value to compare against.
  // NOTE(review): presumably the test reads "quantity <type> threshold"; the
  // operand order is not stated in this file - confirm.
  Value threshold = 3;
}


// Represents a single value, e.g. 5 or 2.5. Only numeric types have a
// dedicated field; a value of any other type (for instance a string such as
// "abc") has no field of its own and must be carried in custom_value.
message Value {
  // The stored value; at most one of the alternatives below is set.
  oneof value {
    float float_value = 1;
    double double_value = 2;
    int32 int32_value = 3;
    int64 int64_value = 4;
    // Carrier for any type without a dedicated field above.
    google.protobuf.Any custom_value = 5;
  }
}