diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-13 06:30:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-13 06:34:55 -0700 |
commit | fb489b6d1117ad3d6dcea888d696cbbf8ca5307f (patch) | |
tree | df5744b1e9b58a0615cdd26925f4e33362ab1e62 /tensorflow/contrib/decision_trees | |
parent | 30bea6a1eb7cfc68fa926a96a48f22d8fabb350f (diff) |
Move generic tree representation proto to tensorflow/contrib.
PiperOrigin-RevId: 158838465
Diffstat (limited to 'tensorflow/contrib/decision_trees')
5 files changed, 260 insertions, 0 deletions
diff --git a/tensorflow/contrib/decision_trees/BUILD b/tensorflow/contrib/decision_trees/BUILD new file mode 100644 index 0000000000..4045b92f10 --- /dev/null +++ b/tensorflow/contrib/decision_trees/BUILD @@ -0,0 +1,19 @@ +# Files common to decision-tree algorithms. +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD new file mode 100644 index 0000000000..86174c5865 --- /dev/null +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -0,0 +1,26 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files([ + "LICENSE", + "generic_tree_model_proto.swig", +]) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +tf_proto_library( + name = "generic_tree_model", + srcs = ["generic_tree_model.proto"], + cc_api_version = 2, + go_api_version = 2, + java_api_version = 2, +) + +tf_proto_library( + name = "generic_tree_model_extensions", + srcs = ["generic_tree_model_extensions.proto"], + cc_api_version = 2, + go_api_version = 2, + protodeps = [":generic_tree_model"], +) diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model.proto b/tensorflow/contrib/decision_trees/proto/generic_tree_model.proto new file mode 100644 index 0000000000..dd80b37f52 --- /dev/null +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model.proto @@ -0,0 +1,183 @@ +// Generic representation of tree-based models. + +// This proto establishes a shared standard: "fully compatible" projects should +// provide support for all reasonable models expressed through it. Therefore, +// it should be kept as simple as possible, and should never contain +// project-specific design choices. + +// Status: work in progress. This proto can change anytime without notice. + +syntax = "proto3"; +option cc_enable_arenas = true; + +package tensorflow.decision_trees; + +import "google/protobuf/any.proto"; +import "google/protobuf/wrappers.proto"; + +// A generic handle for any type of model. +message Model { + oneof model { + DecisionTree decision_tree = 1; + Ensemble ensemble = 2; + google.protobuf.Any custom_model = 3; + } + repeated google.protobuf.Any additional_data = 4; +} + +message ModelAndFeatures { + message Feature { + // TODO(jonasz): Remove this field, as it's confusing. Ctx: cr/153569450. + FeatureId feature_id = 1 [deprecated = true]; + repeated google.protobuf.Any additional_data = 2; + }; + // Given a FeatureId feature_id, the feature's description is in + // features[feature_id.id.value]. + map<string, Feature> features = 1; + Model model = 2; + repeated google.protobuf.Any additional_data = 3; +} + +// An ordered sequence of models. This message can be used to express bagged or +// boosted models, as well as custom ensembles. +message Ensemble { + message Member { + Model submodel = 1; + google.protobuf.Int32Value submodel_id = 2; + repeated google.protobuf.Any additional_data = 3; + } + repeated Member members = 100; // A higher id for more readable printing. + + // The presence of a certain combination_technique indicates how to combine + // the outputs of member models in order to compute the ensemble's output. + oneof combination_technique { + Summation summation_combination_technique = 1; + Averaging averaging_combination_technique = 2; + google.protobuf.Any custom_combination_technique = 3; + } + repeated google.protobuf.Any additional_data = 4; +} + +// When present, the Ensemble's output is the sum of member models' outputs. +message Summation { + repeated google.protobuf.Any additional_data = 1; +}; + + +// When present, the Ensemble's output is the average of member models' outputs. +message Averaging { + repeated google.protobuf.Any additional_data = 1; +}; + + +message DecisionTree { + repeated TreeNode nodes = 1; + repeated google.protobuf.Any additional_data = 2; +}; + + +message TreeNode { + // Following fields are provided for convenience and better readability. + // Filling them in is not required. + google.protobuf.Int32Value node_id = 1; + google.protobuf.Int32Value depth = 2; + google.protobuf.Int32Value subtree_size = 3; + + oneof node_type { + BinaryNode binary_node = 4; + Leaf leaf = 5; + google.protobuf.Any custom_node_type = 6; + } + + repeated google.protobuf.Any additional_data = 7; +} + + +message BinaryNode { + google.protobuf.Int32Value left_child_id = 1; + google.protobuf.Int32Value right_child_id = 2; + enum Direction { + LEFT = 0; + RIGHT = 1; + } + // When left_child_test is undefined for a particular datapoint (e.g. because + // it's not defined when feature value is missing), the datapoint should go + // in this direction. + Direction default_direction = 3; + // When a datapoint satisfies the test, it should be propagated to the left + // child. + oneof left_child_test { + InequalityTest inequality_left_child_test = 4; + google.protobuf.Any custom_left_child_test = 5; + } +}; + +// A SparseVector represents a vector in which only certain select elements +// are non-zero. Maps labels to values (e.g. class id to probability or count). +message SparseVector { + map<int64, Value> sparse_value = 1; +} + +message Vector { + repeated Value value = 1; +} + +message Leaf { + oneof leaf { + // The interpretation of the values held in the leaves of a decision tree + // is application specific, but some common cases are: + // 1) len(vector) = 1, and the floating point value[0] holds the class 0 + // probability in a two class classification problem. + // 2) len(vector) = 1, and the integer value[0] holds the class prediction. + // 3) The floating point value[i] holds the class i probability prediction. + // 4) The floating point value[i] holds the i-th component of the + // vector prediction in a regression problem. + // 5) sparse_vector holds the sparse class predictions for a classification + // problem with a large number of classes. + Vector vector = 1; + SparseVector sparse_vector = 2; + } + // For non-standard handling of leaves. + repeated google.protobuf.Any additional_data = 3; +}; + + +message FeatureId { + google.protobuf.StringValue id = 1; + repeated google.protobuf.Any additional_data = 2; +}; + +message ObliqueFeatures { + // total value is sum(features[i] * weights[i]). + repeated FeatureId features = 1; + repeated float weights = 2; +} + + +message InequalityTest { + // When the feature is missing, the test's outcome is undefined. + oneof FeatureSum { + FeatureId feature_id = 1; + ObliqueFeatures oblique = 4; + } + enum Type { + LESS_OR_EQUAL = 0; + LESS_THAN = 1; + GREATER_OR_EQUAL = 2; + GREATER_THAN = 3; + }; + Type type = 2; + Value threshold = 3; +}; + + +// Represents a single value of any type, e.g. 5 or "abc". +message Value { + oneof value { + float float_value = 1; + double double_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + google.protobuf.Any custom_value = 5; + } +}; diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto b/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto new file mode 100644 index 0000000000..4c0cceaddc --- /dev/null +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.proto @@ -0,0 +1,18 @@ +// Messages in this file are not part of the basic standard established by +// generic_tree_model.proto (see the toplevel comment in that file). + +syntax = "proto3"; + +package tensorflow.decision_trees; + +import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto"; + +// Used in generic_tree_model.BinaryNode.left_child_test. +// Tests whether the feature's value belongs to the specified list, +// (or does not belong if inverse=True). +message MatchingValuesTest { + // When the feature is missing, the test's outcome is undefined. + FeatureId feature_id = 1; + repeated Value value = 2; + bool inverse = 3; +} diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig new file mode 100644 index 0000000000..d3d201afd5 --- /dev/null +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig @@ -0,0 +1,14 @@ +////////// SWIG INCLUDE ////////// + +%include "net/proto/swig/protofunc.swig" + +#ifndef MUST_USE_RESULT +#error Use this file only as a %include or %import after google.swig. +#endif + +%{ +#include "third_party/tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +%} + +PROTO_INPUT(tensorflow::decision_trees::DecisionTree, decision_tree); +PROTO_IN_OUT(tensorflow::decision_trees::DecisionTree, decision_tree); |