aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/decision_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-13 06:30:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-13 06:34:55 -0700
commitfb489b6d1117ad3d6dcea888d696cbbf8ca5307f (patch)
treedf5744b1e9b58a0615cdd26925f4e33362ab1e62 /tensorflow/contrib/decision_trees
parent30bea6a1eb7cfc68fa926a96a48f22d8fabb350f (diff)
Move generic tree representation proto to tensorflow/contrib.
PiperOrigin-RevId: 158838465
Diffstat (limited to 'tensorflow/contrib/decision_trees')
-rw-r--r--tensorflow/contrib/decision_trees/BUILD19
-rw-r--r--tensorflow/contrib/decision_trees/proto/BUILD26
-rw-r--r--tensorflow/contrib/decision_trees/proto/generic_tree_model.proto183
-rw-r--r--tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto18
-rw-r--r--tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig14
5 files changed, 260 insertions, 0 deletions
diff --git a/tensorflow/contrib/decision_trees/BUILD b/tensorflow/contrib/decision_trees/BUILD
new file mode 100644
index 0000000000..4045b92f10
--- /dev/null
+++ b/tensorflow/contrib/decision_trees/BUILD
@@ -0,0 +1,19 @@
+# Files common to decision-tree algorithms.
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
new file mode 100644
index 0000000000..86174c5865
--- /dev/null
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -0,0 +1,26 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files([
+ "LICENSE",
+ "generic_tree_model_proto.swig",
+])
+
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+tf_proto_library(
+ name = "generic_tree_model",
+ srcs = ["generic_tree_model.proto"],
+ cc_api_version = 2,
+ go_api_version = 2,
+ java_api_version = 2,
+)
+
+tf_proto_library(
+ name = "generic_tree_model_extensions",
+ srcs = ["generic_tree_model_extensions.proto"],
+ cc_api_version = 2,
+ go_api_version = 2,
+ protodeps = [":generic_tree_model"],
+)
diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model.proto b/tensorflow/contrib/decision_trees/proto/generic_tree_model.proto
new file mode 100644
index 0000000000..dd80b37f52
--- /dev/null
+++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model.proto
@@ -0,0 +1,183 @@
+// Generic representation of tree-based models.
+
+// This proto establishes a shared standard: "fully compatible" projects should
+// provide support for all reasonable models expressed through it. Therefore,
+// it should be kept as simple as possible, and should never contain
+// project-specific design choices.
+
+// Status: work in progress. This proto can change anytime without notice.
+
+syntax = "proto3";
+option cc_enable_arenas = true;
+
+package tensorflow.decision_trees;
+
+import "google/protobuf/any.proto";
+import "google/protobuf/wrappers.proto";
+
+// A generic handle for any type of model.
+message Model {
+ oneof model {
+ DecisionTree decision_tree = 1;
+ Ensemble ensemble = 2;
+ google.protobuf.Any custom_model = 3;
+ }
+ repeated google.protobuf.Any additional_data = 4;
+}
+
+message ModelAndFeatures {
+ message Feature {
+ // TODO(jonasz): Remove this field, as it's confusing. Ctx: cr/153569450.
+ FeatureId feature_id = 1 [deprecated = true];
+ repeated google.protobuf.Any additional_data = 2;
+ };
+ // Given a FeatureId feature_id, the feature's description is in
+ // features[feature_id.id.value].
+ map<string, Feature> features = 1;
+ Model model = 2;
+ repeated google.protobuf.Any additional_data = 3;
+}
+
+// An ordered sequence of models. This message can be used to express bagged or
+// boosted models, as well as custom ensembles.
+message Ensemble {
+ message Member {
+ Model submodel = 1;
+ google.protobuf.Int32Value submodel_id = 2;
+ repeated google.protobuf.Any additional_data = 3;
+ }
+ repeated Member members = 100; // A higher id for more readable printing.
+
+ // The presence of a certain combination_technique indicates how to combine
+ // the outputs of member models in order to compute the ensemble's output.
+ oneof combination_technique {
+ Summation summation_combination_technique = 1;
+ Averaging averaging_combination_technique = 2;
+ google.protobuf.Any custom_combination_technique = 3;
+ }
+ repeated google.protobuf.Any additional_data = 4;
+}
+
+// When present, the Ensemble's output is the sum of member models' outputs.
+message Summation {
+ repeated google.protobuf.Any additional_data = 1;
+};
+
+
+// When present, the Ensemble's output is the average of member models' outputs.
+message Averaging {
+ repeated google.protobuf.Any additional_data = 1;
+};
+
+
+message DecisionTree {
+ repeated TreeNode nodes = 1;
+ repeated google.protobuf.Any additional_data = 2;
+};
+
+
+message TreeNode {
+ // Following fields are provided for convenience and better readability.
+ // Filling them in is not required.
+ google.protobuf.Int32Value node_id = 1;
+ google.protobuf.Int32Value depth = 2;
+ google.protobuf.Int32Value subtree_size = 3;
+
+ oneof node_type {
+ BinaryNode binary_node = 4;
+ Leaf leaf = 5;
+ google.protobuf.Any custom_node_type = 6;
+ }
+
+ repeated google.protobuf.Any additional_data = 7;
+}
+
+
+message BinaryNode {
+ google.protobuf.Int32Value left_child_id = 1;
+ google.protobuf.Int32Value right_child_id = 2;
+ enum Direction {
+ LEFT = 0;
+ RIGHT = 1;
+ }
+ // When left_child_test is undefined for a particular datapoint (e.g. because
+ // it's not defined when feature value is missing), the datapoint should go
+ // in this direction.
+ Direction default_direction = 3;
+ // When a datapoint satisfies the test, it should be propagated to the left
+ // child.
+ oneof left_child_test {
+ InequalityTest inequality_left_child_test = 4;
+ google.protobuf.Any custom_left_child_test = 5;
+ }
+};
+
+// A SparseVector represents a vector in which only certain select elements
+// are non-zero. Maps labels to values (e.g. class id to probability or count).
+message SparseVector {
+ map<int64, Value> sparse_value = 1;
+}
+
+message Vector {
+ repeated Value value = 1;
+}
+
+message Leaf {
+ oneof leaf {
+ // The interpretation of the values held in the leaves of a decision tree
+ // is application specific, but some common cases are:
+ // 1) len(vector) = 1, and the floating point value[0] holds the class 0
+ // probability in a two class classification problem.
+ // 2) len(vector) = 1, and the integer value[0] holds the class prediction.
+ // 3) The floating point value[i] holds the class i probability prediction.
+ // 4) The floating point value[i] holds the i-th component of the
+ // vector prediction in a regression problem.
+ // 5) sparse_vector holds the sparse class predictions for a classification
+ // problem with a large number of classes.
+ Vector vector = 1;
+ SparseVector sparse_vector = 2;
+ }
+ // For non-standard handling of leaves.
+ repeated google.protobuf.Any additional_data = 3;
+};
+
+
+message FeatureId {
+ google.protobuf.StringValue id = 1;
+ repeated google.protobuf.Any additional_data = 2;
+};
+
+message ObliqueFeatures {
+ // total value is sum(features[i] * weights[i]).
+ repeated FeatureId features = 1;
+ repeated float weights = 2;
+}
+
+
+message InequalityTest {
+ // When the feature is missing, the test's outcome is undefined.
+ oneof FeatureSum {
+ FeatureId feature_id = 1;
+ ObliqueFeatures oblique = 4;
+ }
+ enum Type {
+ LESS_OR_EQUAL = 0;
+ LESS_THAN = 1;
+ GREATER_OR_EQUAL = 2;
+ GREATER_THAN = 3;
+ };
+ Type type = 2;
+ Value threshold = 3;
+};
+
+
+// Represents a single value of any type, e.g. 5 or "abc".
+message Value {
+ oneof value {
+ float float_value = 1;
+ double double_value = 2;
+ int32 int32_value = 3;
+ int64 int64_value = 4;
+ google.protobuf.Any custom_value = 5;
+ }
+};
diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto b/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto
new file mode 100644
index 0000000000..4c0cceaddc
--- /dev/null
+++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto
@@ -0,0 +1,18 @@
+// Messages in this file are not part of the basic standard established by
+// generic_tree_model.proto (see the toplevel comment in that file).
+
+syntax = "proto3";
+
+package tensorflow.decision_trees;
+
+import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto";
+
+// Used in generic_tree_model.BinaryNode.left_child_test.
+// Tests whether the feature's value belongs to the specified list,
+// (or does not belong if inverse=True).
+message MatchingValuesTest {
+ // When the feature is missing, the test's outcome is undefined.
+ FeatureId feature_id = 1;
+ repeated Value value = 2;
+ bool inverse = 3;
+}
diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig
new file mode 100644
index 0000000000..d3d201afd5
--- /dev/null
+++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig
@@ -0,0 +1,14 @@
+////////// SWIG INCLUDE //////////
+
+%include "net/proto/swig/protofunc.swig"
+
+#ifndef MUST_USE_RESULT
+#error Use this file only as a %include or %import after google.swig.
+#endif
+
+%{
+#include "third_party/tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
+%}
+
+PROTO_INPUT(tensorflow::decision_trees::DecisionTree, decision_tree);
+PROTO_IN_OUT(tensorflow::decision_trees::DecisionTree, decision_tree);