diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/core/ops/tree_utils.h')
-rw-r--r-- | tensorflow/contrib/tensor_forest/core/ops/tree_utils.h | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h index 19b02e379e..067f0768d3 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h @@ -19,6 +19,7 @@ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -26,6 +27,7 @@ namespace tensorflow { namespace tensorforest { +// TODO(gilberth): Put these in protos so they can be shared by C++ and python. // Indexes in the tree representation's 2nd dimension for children and features. const int32 CHILDREN_INDEX = 0; const int32 FEATURE_INDEX = 1; @@ -34,6 +36,14 @@ const int32 FEATURE_INDEX = 1; const int32 LEAF_NODE = -1; const int32 FREE_NODE = -2; +// Used to indicate column types, e.g. categorical vs. float +enum DataColumnTypes { + kDataFloat = 0, + kDataCategorical = 1 +}; + + + // Calculates the sum of a tensor. template<typename T> T Sum(Tensor counts) { @@ -80,6 +90,20 @@ int32 BestFeatureRegression(const Tensor& total_sums, const Tensor& split_sums, const Tensor& split_squares, int32 accumulator); +// Returns true if the best split's variance is sufficiently smaller than +// that of the next best split. +bool BestSplitDominatesRegression( + const Tensor& total_sums, const Tensor& total_squares, + const Tensor& split_sums, const Tensor& split_squares, + int32 accumulator); + +// Returns true if the best split's Gini impurity is sufficiently smaller than +// that of the next best split. +bool BestSplitDominatesClassification( + const Tensor& total_counts, + const Tensor& split_counts, int32 accumulator, + float dominate_fraction); + // Initializes everything in the given tensor to the given value. 
template <typename T>
void Initialize(Tensor counts, T val = 0) {
@@ -90,7 +114,74 @@ void Initialize(Tensor counts, T val = 0) {
// Returns true if the point falls to the right (i.e., the selected feature
// of the input point is greater than the bias threshold), and false if it
// falls to the left.
-bool DecideNode(const Tensor& point, int32 feature, float bias);
+// Even though our input data is forced into float Tensors, it could have
+// originally been something else (e.g. categorical string data) which
+// we treat differently.
+bool DecideNode(const Tensor& point, int32 feature, float bias,
+                DataColumnTypes type = kDataFloat);
+
+// Applies Decide() to the dense entry input_data(i, feature); for kDataFloat
+// columns this is input_data(i, feature) > bias.
+//
+// input_data must be indexable as input_data(row, column) and expose
+// dimensions().  Both indices are CHECKed against [0, dimension) before the
+// element is read, so an out-of-range index aborts instead of reading
+// outside the tensor.
+//
+// Decide() is declared further down in this header.  The call below is only
+// resolved when the template is instantiated, because its first argument
+// depends on T; keep that argument inline rather than hoisting it into a
+// non-dependent local.
+template <typename T>
+bool DecideDenseNode(const T& input_data,
+                     int32 i, int32 feature, float bias,
+                     DataColumnTypes type = kDataFloat) {
+  // A negative index satisfies the upper-bound checks, so the lower bound
+  // has to be verified separately.
+  CHECK_GE(i, 0);
+  CHECK_GE(feature, 0);
+  CHECK_LT(i, input_data.dimensions()[0]);
+  CHECK_LT(feature, input_data.dimensions()[1]);
+  return Decide(input_data(i, feature), bias, type);
+}
+
+// If M is the sparse float matrix represented by sparse_input_indices and
+// sparse_input_values, FindSparseValue returns M(i,j), or 0.0 if (i,j)
+// isn't present in sparse_input_indices.  sparse_input_indices is assumed
+// to be sorted by row and then by column, since the lookup is a binary
+// search that compares the row index first and the column index second.
+template <typename T1, typename T2>
+float FindSparseValue(
+    const T1& sparse_input_indices,
+    const T2& sparse_input_values,
+    int32 i, int32 j) {
+  // Binary search over the half-open entry range [low, high).  The bounds
+  // are 64-bit so an entry count beyond the int32 range is not truncated,
+  // and the midpoint is formed without adding low and high so it cannot
+  // overflow.
+  int64 low = 0;
+  int64 high = sparse_input_values.dimension(0);
+  while (low < high) {
+    const int64 mid = low + (high - low) / 2;
+    const int64 midi = internal::SubtleMustCopy(sparse_input_indices(mid, 0));
+    const int64 midj = internal::SubtleMustCopy(sparse_input_indices(mid, 1));
+    if (midi == i && midj == j) {
+      return sparse_input_values(mid);
+    }
+    // Entries are ordered by (row, column): drop the half that cannot
+    // contain (i, j).
+    if (midi < i || (midi == i && midj < j)) {
+      low = mid + 1;
+    } else {
+      high = mid;
+    }
+  }
+  // (i, j) is not stored, which in a sparse matrix means the value is zero.
+  return 0.0f;
+}
+
+// Returns t(i, feature) > bias, where t is the sparse tensor represented by
+// sparse_input_indices and sparse_input_values.  An entry that is absent
+// from the sparse representation is looked up as 0.0 by FindSparseValue.
+//
+// Decide() is declared after this template.  The call is only resolved at
+// instantiation because its first argument depends on T1/T2, so that
+// argument must stay inline rather than being hoisted into a float local.
+template <typename T1, typename T2>
+bool DecideSparseNode(
+    const T1& sparse_input_indices,
+    const T2& sparse_input_values,
+    int32 i, int32 feature, float bias,
+    DataColumnTypes type = kDataFloat) {
+  return Decide(
+      FindSparseValue(sparse_input_indices, sparse_input_values, i, feature),
+      bias, type);
+}
+
+// Returns left/right decision between the input value and the threshold bias.
+// For floating point types, the decision is value > bias, but for
+// categorical data, it is value != bias.
+bool Decide(float value, float bias, DataColumnTypes type = kDataFloat);

// Returns true if all the splits are initialized. Since they get initialized
// in order, we can simply infer this from the last split.
|