Diffstat (limited to 'tensorflow/contrib/tensor_forest/core/ops/tree_utils.h')
 tensorflow/contrib/tensor_forest/core/ops/tree_utils.h | 93 ++++++++++++++-
 1 file changed, 92 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
index 19b02e379e..067f0768d3 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -26,6 +27,7 @@
namespace tensorflow {
namespace tensorforest {
+// TODO(gilberth): Put these in protos so they can be shared by C++ and python.
// Indexes in the tree representation's 2nd dimension for children and features.
const int32 CHILDREN_INDEX = 0;
const int32 FEATURE_INDEX = 1;
@@ -34,6 +36,14 @@ const int32 FEATURE_INDEX = 1;
const int32 LEAF_NODE = -1;
const int32 FREE_NODE = -2;
+// Used to indicate column types, e.g. categorical vs. float.
+enum DataColumnTypes {
+ kDataFloat = 0,
+ kDataCategorical = 1
+};
+
+
+
// Calculates the sum of a tensor.
template<typename T>
T Sum(Tensor counts) {
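The new DataColumnTypes enum is what lets the decision helpers added later in this diff treat float-encoded categorical columns differently from genuinely numeric ones. As a purely illustrative sketch (the column layout and comments below are invented for this note, not taken from the patch), a caller might tag its columns like this:

// Illustrative only -- not part of this patch. The column layout is made up
// to show how mixed float/categorical data might be tagged.
#include <vector>

enum DataColumnTypes { kDataFloat = 0, kDataCategorical = 1 };

int main() {
  std::vector<DataColumnTypes> column_types = {
      kDataFloat,        // a numeric measurement
      kDataFloat,        // another numeric measurement
      kDataCategorical,  // a string feature mapped to float-encoded ids
  };
  // Later in this diff, DecideNode(point, feature, bias, column_types[feature])
  // selects the comparison appropriate for the column's original type.
  return column_types.size() == 3 ? 0 : 1;
}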
@@ -80,6 +90,20 @@ int32 BestFeatureRegression(const Tensor& total_sums,
const Tensor& split_sums,
const Tensor& split_squares, int32 accumulator);
+// Returns true if the best split's variance is sufficiently smaller than
+// that of the next best split.
+bool BestSplitDominatesRegression(
+ const Tensor& total_sums, const Tensor& total_squares,
+ const Tensor& split_sums, const Tensor& split_squares,
+ int32 accumulator);
+
+// Returns true if the best split's Gini impurity is sufficiently smaller than
+// that of the next best split.
+bool BestSplitDominatesClassification(
+ const Tensor& total_counts,
+ const Tensor& split_counts, int32 accumulator,
+ float dominate_fraction);
+
// Initializes everything in the given tensor to the given value.
template <typename T>
void Initialize(Tensor counts, T val = 0) {
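BestSplitDominatesRegression and BestSplitDominatesClassification are only declared here; their definitions live in tree_utils.cc, outside this diff. As a rough standalone illustration of the idea described in the comments above, the sketch below compares the Gini impurity of the best candidate split against the runner-up. The way dominate_fraction is applied, and the use of plain vectors instead of accumulator slices of a Tensor, are assumptions made for illustration, not the patch's actual criterion.

#include <algorithm>
#include <vector>

// Gini impurity of a vector of per-class counts: 1 - sum_i (c_i / total)^2.
float GiniImpurity(const std::vector<float>& class_counts) {
  float total = 0.0f;
  for (float c : class_counts) total += c;
  if (total == 0.0f) return 0.0f;
  float sum_squares = 0.0f;
  for (float c : class_counts) sum_squares += (c / total) * (c / total);
  return 1.0f - sum_squares;
}

// Illustrative dominance test: the best split "dominates" when its impurity
// is at most dominate_fraction times the runner-up's impurity. The real
// BestSplitDominatesClassification works on the accumulator's slice of the
// split_counts Tensor and may use a different statistical test.
bool BestSplitDominatesSketch(
    const std::vector<std::vector<float>>& split_counts,
    float dominate_fraction) {
  std::vector<float> impurities;
  impurities.reserve(split_counts.size());
  for (const auto& counts : split_counts) {
    impurities.push_back(GiniImpurity(counts));
  }
  if (impurities.size() < 2) return true;  // Nothing to compete against.
  std::sort(impurities.begin(), impurities.end());
  return impurities[0] <= dominate_fraction * impurities[1];
}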
@@ -90,7 +114,74 @@ void Initialize(Tensor counts, T val = 0) {
// Returns true if the point falls to the right (i.e., the selected feature
// of the input point is greater than the bias threshold), and false if it
// falls to the left.
-bool DecideNode(const Tensor& point, int32 feature, float bias);
+// Even though our input data is forced into float Tensors, it could have
+// originally been something else (e.g., categorical string data), which
+// we treat differently.
+bool DecideNode(const Tensor& point, int32 feature, float bias,
+ DataColumnTypes type = kDataFloat);
+
+// Returns Decide(input_data(i, feature), bias, type).
+template <typename T>
+bool DecideDenseNode(const T& input_data,
+ int32 i, int32 feature, float bias,
+ DataColumnTypes type = kDataFloat) {
+ CHECK_LT(i, input_data.dimensions()[0]);
+ CHECK_LT(feature, input_data.dimensions()[1]);
+ return Decide(input_data(i, feature), bias, type);
+}
+
+// If T is a sparse float matrix represented by sparse_input_indices and
+// sparse_input_values, FindSparseValue returns T(i,j), or 0.0 if (i,j)
+// isn't present in sparse_input_indices. sparse_input_indices is assumed
+// to be sorted lexicographically by (row, column).
+template <typename T1, typename T2>
+float FindSparseValue(
+ const T1& sparse_input_indices,
+ const T2& sparse_input_values,
+ int32 i, int32 j) {
+ int32 low = 0;
+ int32 high = sparse_input_values.dimension(0);
+ while (low < high) {
+ int32 mid = (low + high) / 2;
+ int64 midi = internal::SubtleMustCopy(sparse_input_indices(mid, 0));
+ int64 midj = internal::SubtleMustCopy(sparse_input_indices(mid, 1));
+ if (midi == i) {
+ if (midj == j) {
+ return sparse_input_values(mid);
+ }
+ if (midj < j) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ continue;
+ }
+ if (midi < i) {
+ low = mid + 1;
+ } else {
+ high = mid;
+ }
+ }
+ return 0.0;
+}
+
+// Returns t(i, feature) > bias, where t is the sparse tensor represented by
+// sparse_input_indices and sparse_input_values.
+template <typename T1, typename T2>
+bool DecideSparseNode(
+ const T1& sparse_input_indices,
+ const T2& sparse_input_values,
+ int32 i, int32 feature, float bias,
+ DataColumnTypes type = kDataFloat) {
+ return Decide(
+ FindSparseValue(sparse_input_indices, sparse_input_values, i, feature),
+ bias, type);
+}
+
+// Returns the left/right decision between the input value and the threshold
+// bias. For floating point data, the decision is value > bias; for
+// categorical data, it is value != bias.
+bool Decide(float value, float bias, DataColumnTypes type = kDataFloat);
// Returns true if all the splits are initialized. Since they get initialized
// in order, we can simply infer this from the last split.
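Decide() itself is only declared in this header; its definition lives in tree_utils.cc and is not shown in this diff. A minimal standalone sketch consistent with the comment above (mirroring the documented semantics rather than the actual implementation) might look like this:

// Sketch only: mirrors the documented semantics of Decide(), not the code in
// tree_utils.cc. The enum is repeated here so the snippet is self-contained.
enum DataColumnTypes { kDataFloat = 0, kDataCategorical = 1 };

bool Decide(float value, float bias, DataColumnTypes type = kDataFloat) {
  if (type == kDataCategorical) {
    // Categorical columns hold float-encoded category ids, so a split asks
    // whether the value is a different category than the one stored in bias.
    return value != bias;
  }
  // Numeric columns: the point falls to the right when it exceeds the bias.
  return value > bias;
}

Both DecideDenseNode and DecideSparseNode above funnel their lookup result into this single comparison, so the float/categorical distinction only has to be handled in one place.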