author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-21 11:18:05 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-09-21 11:21:46 -0700
commit    054b88233bf6d6bc5b953fca50dbb01d108b2d18 (patch)
tree      954f314c6507bda5ee323d32acc81900e89571ce /tensorflow/contrib/tensor_forest
parent    2679dcfbaa491c764caa9e2d69b071dbc1b7978b (diff)
Add fixed space sparse class stats handling.
PiperOrigin-RevId: 169570470
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc                 | 312
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h                  | 133
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc            |  72
-rw-r--r--  tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc |   4
-rw-r--r--  tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto         |   6
5 files changed, 439 insertions, 88 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
index 63bfc1aef1..3ce630e3a9 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
@@ -159,7 +159,7 @@ void ClassificationStats::AdditionalInitializationExample(
}
bool ClassificationStats::IsFinished() const {
- bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1;
+ bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
return basic || finish_early_;
}
@@ -193,8 +193,11 @@ void ClassificationStats::AddExample(
left_gini_->update(i, left_count(i, int_label), weight);
}
ClassificationAddLeftExample(i, int_label, weight);
- } else if (right_gini_ != nullptr) {
- right_gini_->update(i, right_count(i, int_label), weight);
+ } else {
+ if (right_gini_ != nullptr) {
+ right_gini_->update(i, right_count(i, int_label), weight);
+ }
+ ClassificationAddRightExample(i, int_label, weight);
}
}
@@ -374,6 +377,41 @@ void ClassificationStats::CheckFinishEarlyBootstrap() {
finish_early_ = worst_g1 < best_g2;
}
+bool ClassificationStats::BestSplit(SplitCandidate* best) const {
+ float min_score = FLT_MAX;
+ int best_index = -1;
+ float best_left_sum, best_right_sum;
+
+ // Calculate sums.
+ for (int i = 0; i < num_splits(); ++i) {
+ float left_sum, right_sum;
+ const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
+ // Find the lowest gini.
+ if (left_sum > 0 && right_sum > 0 &&
+ split_score < min_score) { // useless check
+ min_score = split_score;
+ best_index = i;
+ best_left_sum = left_sum;
+ best_right_sum = right_sum;
+ }
+ }
+
+ // This could happen if all the splits are useless.
+ if (best_index < 0) {
+ return false;
+ }
+
+ // Fill in stats to be used for leaf model.
+ *best->mutable_split() = splits_[best_index];
+ auto* left = best->mutable_left_stats();
+ left->set_weight_sum(best_left_sum);
+ auto* right = best->mutable_right_stats();
+ right->set_weight_sum(best_right_sum);
+ InitLeafClassStats(best_index, left, right);
+
+ return true;
+}
+
// ------------------------ Dense Classification --------------------------- //
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
Initialize();
@@ -449,52 +487,20 @@ float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
return left_score + right_score;
}
-bool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
- float min_score = FLT_MAX;
- int best_index = -1;
- float best_left_sum, best_right_sum;
-
- // Calculate sums.
- for (int i = 0; i < num_splits(); ++i) {
- float left_sum, right_sum;
- const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
- // Find the lowest gini.
- if (left_sum > 0 && right_sum > 0 &&
- split_score < min_score) { // useless check
- min_score = split_score;
- best_index = i;
- best_left_sum = left_sum;
- best_right_sum = right_sum;
- }
- }
-
- // This could happen if all the splits are useless.
- if (best_index < 0) {
- return false;
- }
-
- // Fill in stats to be used for leaf model.
- *best->mutable_split() = splits_[best_index];
- // Left
- auto* left = best->mutable_left_stats();
- auto* left_class_stats = left->mutable_classification();
- left->set_weight_sum(best_left_sum);
+void DenseClassificationGrowStats::InitLeafClassStats(
+ int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
+ auto* left_class_stats = left_stats->mutable_classification();
auto* left_counts = left_class_stats->mutable_dense_counts();
for (int i = 0; i < params_.num_outputs(); ++i) {
- left_counts->add_value()->set_float_value(
- left_count(best_index, i));
+ left_counts->add_value()->set_float_value(left_count(best_split_index, i));
}
- // Right
- auto* right = best->mutable_right_stats();
- auto* right_class_stats = right->mutable_classification();
- right->set_weight_sum(best_right_sum);
+ auto* right_class_stats = right_stats->mutable_classification();
auto* right_counts = right_class_stats->mutable_dense_counts();
for (int i = 0; i < params_.num_outputs(); ++i) {
- right_counts->add_value()->set_float_value(
- total_counts_[i] - left_count(best_index, i));
+ right_counts->add_value()->set_float_value(total_counts_[i] -
+ left_count(best_split_index, i));
}
- return true;
}
// ------------------------ Sparse Classification --------------------------- //
@@ -584,49 +590,18 @@ float SparseClassificationGrowStats::GiniScore(
return left_score + right_score;
}
-bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
- float min_score = FLT_MAX;
- int best_index = -1;
- float best_left_sum = -1;
- float best_right_sum = -1;
-
- // Find the lowest gini.
- for (int i = 0; i < num_splits(); ++i) {
- float left_sum, right_sum;
- const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
- if (left_sum > 0 && right_sum > 0 &&
- split_score < min_score) { // useless check
- min_score = split_score;
- best_index = i;
- best_left_sum = left_sum;
- best_right_sum = right_sum;
- }
- }
-
- // This could happen if all the splits are useless.
- if (best_index < 0) {
- return false;
- }
-
- // Fill in stats to be used for leaf model.
- *best->mutable_split() = splits_[best_index];
- // Left
- auto* left = best->mutable_left_stats();
- auto* left_class_stats = left->mutable_classification();
- left->set_weight_sum(best_left_sum);
+void SparseClassificationGrowStats::InitLeafClassStats(
+ int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
+ auto* left_class_stats = left_stats->mutable_classification();
auto* left_counts =
left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
-
- // Right
- auto* right = best->mutable_right_stats();
- auto* right_class_stats = right->mutable_classification();
- right->set_weight_sum(best_right_sum);
+ auto* right_class_stats = right_stats->mutable_classification();
auto* right_counts =
right_class_stats->mutable_sparse_counts()->mutable_sparse_value();
for (const auto& entry : total_counts_) {
- auto it = left_counts_[best_index].find(entry.first);
- if (it == left_counts_[best_index].end()) {
+ auto it = left_counts_[best_split_index].find(entry.first);
+ if (it == left_counts_[best_split_index].end()) {
(*right_counts)[entry.first].set_float_value(entry.second);
} else {
const float left = it->second;
@@ -637,7 +612,184 @@ bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
}
}
}
- return true;
+}
+
+// -------------------- FixedSizeClassStats --------------------------------- //
+
+// FixedSizeClassStats implements the "SpaceSaving" algorithm by
+// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi. See for example
+// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf
+
+int argmin(const std::unordered_map<int, float>& m) {
+ int c = -1;
+ float f = FLT_MAX;
+ for (const auto it : m) {
+ if (it.second < f) {
+ f = it.second;
+ c = it.first;
+ }
+ }
+ return c;
+}
+
+void FixedSizeClassStats::accumulate(int c, float w) {
+ auto it = class_weights_.find(c);
+ if (it != class_weights_.end()) {
+ it->second += w;
+ if (c == smallest_weight_class_) {
+ smallest_weight_class_ = argmin(class_weights_);
+ }
+ return;
+ }
+
+ if (class_weights_.size() < n_) {
+ class_weights_.insert(it, std::pair<int, float>(c, w));
+ if (class_weights_.size() == n_) {
+ // Can't assume last added has the smallest weight, because the
+ // w's might be all different.
+ smallest_weight_class_ = argmin(class_weights_);
+ }
+ return;
+ }
+
+ // This is the slightly unintuitive heart of the SpaceSaving algorithm:
+ // if the map is full and we see a new class, we find the entry with the
+ // smallest weight and "take it over": we add our weight to its weight,
+ // and assign it all to the new seen class.
+ it = class_weights_.find(smallest_weight_class_);
+ float new_weight = it->second + w;
+ class_weights_.erase(it);
+ class_weights_[c] = new_weight;
+ smallest_weight_class_ = argmin(class_weights_);
+}
+
+float FixedSizeClassStats::get_weight(int c) const {
+ // Every entry in class_weights_ might be overstated by as much as the
+ // smallest_weight. We therefore assume that each has been overstated
+ // by smallest_weight / 2.0, and we re-distribute that mass over all
+ // num_classes_ classes.
+ float smallest_weight = 0.0;
+ auto it = class_weights_.find(smallest_weight_class_);
+ if (it != class_weights_.end()) {
+ smallest_weight = it->second;
+ }
+ float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
+ it = class_weights_.find(c);
+ if (it != class_weights_.end()) {
+ w += it->second - smallest_weight / 2.0;
+ }
+ return w;
+}
+
+void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
+ *sum = 0.0;
+ *square = 0.0;
+
+ float smallest_weight = 0.0;
+ auto it = class_weights_.find(smallest_weight_class_);
+ if (it != class_weights_.end()) {
+ smallest_weight = it->second;
+ }
+
+ float w;
+ for (const auto it : class_weights_) {
+ *sum += it.second;
+ w = get_weight(it.first);
+ *square += w * w;
+ }
+
+ w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
+ *square += (num_classes_ - n_) * w * w;
+}
+
+void FixedSizeClassStats::ExtractFromProto(
+ const decision_trees::SparseVector& sparse_vector) {
+ for (const auto& it : sparse_vector.sparse_value()) {
+ class_weights_[it.first] = it.second.float_value();
+ }
+ if (class_weights_.size() == n_) {
+ smallest_weight_class_ = argmin(class_weights_);
+ }
+}
+
+void FixedSizeClassStats::PackToProto(
+ decision_trees::SparseVector* sparse_vector) const {
+ for (const auto it : class_weights_) {
+ (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
+ it.second);
+ }
+}
+
+// --------------------- FixedSizeSparseClassificationGrowStats ------------- //
+
+void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
+ const FertileSlot& slot) {
+ Initialize();
+ if (!slot.has_post_init_leaf_stats()) {
+ return;
+ }
+ weight_sum_ = slot.post_init_leaf_stats().weight_sum();
+
+ // Candidate counts and splits.
+ int split_num = 0;
+ left_counts_.clear();
+ right_counts_.clear();
+ for (const auto& cand : slot.candidates()) {
+ AddSplit(cand.split(), nullptr, nullptr, -1);
+ const auto& left_stats = cand.left_stats().classification().sparse_counts();
+ left_counts_.emplace_back(params_.num_classes_to_track(),
+ params_.num_outputs());
+ left_counts_[split_num].ExtractFromProto(left_stats);
+ const auto& right_stats =
+ cand.right_stats().classification().sparse_counts();
+ right_counts_.emplace_back(params_.num_classes_to_track(),
+ params_.num_outputs());
+ right_counts_[split_num].ExtractFromProto(right_stats);
+ ++split_num;
+ }
+}
+
+void FixedSizeSparseClassificationGrowStats::PackToProto(
+ FertileSlot* slot) const {
+ auto* slot_stats = slot->mutable_post_init_leaf_stats();
+ slot_stats->set_weight_sum(weight_sum_);
+
+ for (int split_num = 0; split_num < num_splits(); ++split_num) {
+ auto* cand = slot->add_candidates();
+ *cand->mutable_split() = splits_[split_num];
+ auto* left_stats = cand->mutable_left_stats()
+ ->mutable_classification()
+ ->mutable_sparse_counts();
+ left_counts_[split_num].PackToProto(left_stats);
+ auto* right_stats = cand->mutable_right_stats()
+ ->mutable_classification()
+ ->mutable_sparse_counts();
+ right_counts_[split_num].PackToProto(right_stats);
+ }
+}
+
+float FixedSizeSparseClassificationGrowStats::GiniScore(
+ int split, float* left_sum, float* right_sum) const {
+ float left_square, right_square;
+ left_counts_[split].set_sum_and_square(left_sum, &left_square);
+ right_counts_[split].set_sum_and_square(right_sum, &right_square);
+ const int32 num_classes = params_.num_outputs();
+ const float left_score =
+ WeightedSmoothedGini(*left_sum, left_square, num_classes);
+ const float right_score =
+ WeightedSmoothedGini(*right_sum, right_square, num_classes);
+ return left_score + right_score;
+}
+
+void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
+ int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
+ auto* left_class_stats = left_stats->mutable_classification();
+ auto* left_counts = left_class_stats->mutable_sparse_counts();
+ left_counts_[best_split_index].PackToProto(left_counts);
+
+ auto* right_class_stats = right_stats->mutable_classification();
+ auto* right_counts = right_class_stats->mutable_sparse_counts();
+ right_counts_[best_split_index].PackToProto(right_counts);
}
// --------------------- Least Squares Regression --------------------------- //
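The FixedSizeClassStats added above tracks only the n heaviest class labels using the SpaceSaving scheme described in its comments: a known class simply adds to its counter, a new class takes a free slot if one exists, and otherwise the smallest counter is evicted and its weight is credited to the newcomer. Below is a minimal, self-contained sketch of that update rule (hypothetical code, not part of this commit; the real class also caches the class with the smallest weight instead of recomputing it on every lookup):

// Standalone sketch of the SpaceSaving update used by
// FixedSizeClassStats::accumulate (illustrative only).
#include <cstdio>
#include <unordered_map>

int ArgMin(const std::unordered_map<int, float>& m) {
  int c = -1;
  float best = 1e30f;
  for (const auto& kv : m) {
    if (kv.second < best) {
      best = kv.second;
      c = kv.first;
    }
  }
  return c;
}

void Accumulate(std::unordered_map<int, float>* weights, int capacity,
                int label, float w) {
  auto it = weights->find(label);
  if (it != weights->end()) {        // Known class: just add the weight.
    it->second += w;
  } else if (static_cast<int>(weights->size()) < capacity) {
    (*weights)[label] = w;           // Room left: start a new counter.
  } else {                           // Full: evict the smallest counter and
    int victim = ArgMin(*weights);   // credit its weight to the new class.
    float inherited = (*weights)[victim];
    weights->erase(victim);
    (*weights)[label] = inherited + w;
  }
}

int main() {
  std::unordered_map<int, float> weights;
  const int capacity = 2;
  Accumulate(&weights, capacity, /*label=*/1, 1.0f);
  Accumulate(&weights, capacity, /*label=*/2, 2.0f);
  Accumulate(&weights, capacity, /*label=*/3, 0.5f);  // Takes over class 1.
  for (const auto& kv : weights) {
    std::printf("class %d -> %.1f\n", kv.first, kv.second);
  }
  // Class 3 now reports 1.5: its own 0.5 plus the 1.0 inherited from class 1,
  // which is why get_weight() treats every counter as possibly overstated.
}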
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
index ba73d1d246..3e41ab50b9 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h
@@ -189,15 +189,29 @@ class ClassificationStats : public GrowStats {
half_initialized_splits_.empty());
}
+ bool BestSplit(SplitCandidate* best) const override;
+ // When best_split_index has been chosen as the best split,
+ // InitLeafClassStats is used to initialize the LeafStat's of the two
+ // new leaves.
+ virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
+ LeafStat* right_stats) const = 0;
+
protected:
virtual float GiniScore(int split, float* left_sum,
float* right_sum) const = 0;
- virtual int num_outputs_seen() const = 0;
+
+ // is_pure should return true if at most one class label has been seen
+ // at the node, and false if two or more have been seen.
+ virtual bool is_pure() const = 0;
virtual float left_count(int split, int class_num) const = 0;
virtual float right_count(int split, int class_num) const = 0;
virtual void ClassificationAddLeftExample(
int split, int64 int_label, float weight) = 0;
+ virtual void ClassificationAddRightExample(int split, int64 int_label,
+ float weight) {
+ // Does nothing by default, but sub-classes can override.
+ }
virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0;
virtual void ClassificationAddSplitStats() = 0;
@@ -301,7 +315,8 @@ class DenseClassificationGrowStats : public ClassificationStats {
void ExtractFromProto(const FertileSlot& slot) override;
void PackToProto(FertileSlot* slot) const override;
- bool BestSplit(SplitCandidate* best) const override;
+ void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
+ LeafStat* right_stats) const;
protected:
void ClassificationAddSplitStats() override {
@@ -317,9 +332,7 @@ class DenseClassificationGrowStats : public ClassificationStats {
num_outputs_seen_ = 0;
}
- int num_outputs_seen() const override {
- return num_outputs_seen_;
- }
+ bool is_pure() const override { return num_outputs_seen_ <= 1; }
void ClassificationAddLeftExample(int split, int64 int_label,
float weight) override {
@@ -369,7 +382,8 @@ class SparseClassificationGrowStats : public ClassificationStats {
void ExtractFromProto(const FertileSlot& slot) override;
void PackToProto(FertileSlot* slot) const override;
- bool BestSplit(SplitCandidate* best) const override;
+ void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
+ LeafStat* right_stats) const;
protected:
void ClassificationAddSplitStats() override {
@@ -384,7 +398,7 @@ class SparseClassificationGrowStats : public ClassificationStats {
left_counts_.clear();
}
- int num_outputs_seen() const override { return total_counts_.size(); }
+ bool is_pure() const override { return total_counts_.size() <= 1; }
void ClassificationAddLeftExample(int split, int64 int_label,
float weight) override {
@@ -412,6 +426,111 @@ class SparseClassificationGrowStats : public ClassificationStats {
std::vector<std::unordered_map<int, float>> left_counts_;
};
+// Accumulates weights for the most popular classes while only using a
+// fixed amount of space.
+class FixedSizeClassStats {
+ public:
+ // n specifies how many classes are tracked.
+ FixedSizeClassStats(int n, int num_classes)
+ : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {}
+
+ // Add weight w to the class c.
+ void accumulate(int c, float w);
+
+ // Return the approximate accumulated weight for class c. If c isn't one
+ // of the n-most popular classes, this can be 0 even if c has accumulated
+ // some weight.
+ float get_weight(int c) const;
+
+ // Put the sum of all weights seen into *sum, and
+ // \sum_c get_weight(c)^2
+ // into *square. *sum will be exact, but *square will be approximate.
+ void set_sum_and_square(float* sum, float* square) const;
+
+ void ExtractFromProto(const decision_trees::SparseVector& sparse_vector);
+ void PackToProto(decision_trees::SparseVector* sparse_vector) const;
+
+ private:
+ // For our typical use cases, n_ is between 10 and 100, so there's no
+ // need to track the smallest weight with a min_heap or the like.
+ int n_;
+ int num_classes_;
+
+ // This tracks the class of the smallest weight, but isn't set until
+ // class_weights_.size() == n_.
+ int smallest_weight_class_;
+
+ std::unordered_map<int, float> class_weights_;
+};
+
+// Tracks classification stats sparsely in a fixed amount of space.
+class FixedSizeSparseClassificationGrowStats : public ClassificationStats {
+ public:
+ FixedSizeSparseClassificationGrowStats(const TensorForestParams& params,
+ int32 depth)
+ : ClassificationStats(params, depth) {}
+
+ void Initialize() override { Clear(); }
+
+ void ExtractFromProto(const FertileSlot& slot) override;
+ void PackToProto(FertileSlot* slot) const override;
+
+ void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
+ LeafStat* right_stats) const;
+
+ protected:
+ void ClassificationAddSplitStats() override {
+ FixedSizeClassStats stats(params_.num_classes_to_track(),
+ params_.num_outputs());
+ left_counts_.resize(num_splits(), stats);
+ right_counts_.resize(num_splits(), stats);
+ }
+ void ClassificationRemoveSplitStats(int split_num) override {
+ left_counts_.erase(left_counts_.begin() + split_num,
+ left_counts_.begin() + (split_num + 1));
+ right_counts_.erase(right_counts_.begin() + split_num,
+ right_counts_.begin() + (split_num + 1));
+ }
+ void ClearInternal() override {
+ left_counts_.clear();
+ right_counts_.clear();
+ }
+
+ bool is_pure() const override { return first_two_classes_seen_.size() <= 1; }
+
+ void ClassificationAddLeftExample(int split, int64 int_label,
+ float weight) override {
+ left_counts_[split].accumulate(int_label, weight);
+ }
+ void ClassificationAddRightExample(int split, int64 int_label,
+ float weight) override {
+ right_counts_[split].accumulate(int_label, weight);
+ }
+ void ClassificationAddTotalExample(int64 int_label, float weight) override {
+ if (is_pure()) {
+ first_two_classes_seen_.insert(int_label);
+ }
+ }
+
+ float GiniScore(int split, float* left_sum, float* right_sum) const override;
+
+ float left_count(int split, int class_num) const override {
+ return left_counts_[split].get_weight(class_num);
+ }
+
+ float right_count(int split, int class_num) const override {
+ return right_counts_[split].get_weight(class_num);
+ }
+
+ private:
+ std::vector<FixedSizeClassStats> left_counts_;
+ std::vector<FixedSizeClassStats> right_counts_;
+
+ // We keep track of the first two class labels seen, so we can tell if
+ // the node is pure (= all of one class) or not.
+ std::set<int> first_two_classes_seen_;
+};
+
// Tracks regression stats using least-squares minimization.
class LeastSquaresRegressionGrowStats : public GrowStats {
public:
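FixedSizeSparseClassificationGrowStats::GiniScore (in the .cc diff above) reduces each candidate split's child to a pair (weight sum, sum of squared per-class weights) via set_sum_and_square and feeds it to WeightedSmoothedGini. As a rough illustration of what such a pair encodes, the sketch below computes the plain, unsmoothed weighted Gini impurity, sum * (1 - square / sum^2) = sum - square / sum; this is a simplification under my own assumptions, since the real helper additionally smooths the counts.

// Hypothetical sketch: mapping a (sum, sum-of-squares) pair to an
// unsmoothed weighted Gini impurity.
#include <cstdio>

// gini = 1 - sum_c p_c^2 with p_c = w_c / sum; weighting by sum gives
// sum - square / sum.
float WeightedGini(float sum, float square) {
  if (sum <= 0.0f) return 0.0f;
  return sum - square / sum;
}

int main() {
  // Two classes with weights 3 and 1: sum = 4, square = 9 + 1 = 10.
  std::printf("weighted gini = %.3f\n", WeightedGini(4.0f, 10.0f));  // 1.500
  // A pure node (all weight in one class) scores 0.
  std::printf("pure node     = %.3f\n", WeightedGini(4.0f, 16.0f));  // 0.000
}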
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc
index fa959e8373..ceb58d2ead 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc
@@ -29,6 +29,8 @@ using tensorflow::tensorforest::TestableInputTarget;
using tensorflow::tensorforest::FertileSlot;
using tensorflow::tensorforest::DenseClassificationGrowStats;
using tensorflow::tensorforest::SparseClassificationGrowStats;
+using tensorflow::tensorforest::FixedSizeClassStats;
+using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats;
using tensorflow::tensorforest::LeastSquaresRegressionGrowStats;
using tensorflow::tensorforest::TensorForestParams;
using tensorflow::tensorforest::SPLIT_FINISH_BASIC;
@@ -327,7 +329,6 @@ TEST(GrowStatsLeastSquaresRegressionTest, Basic) {
ASSERT_EQ(serialized_again, serialized);
}
-
TEST(GrowStatsSparseClassificationTest, Basic) {
TensorForestParams params;
params.set_num_outputs(2);
@@ -360,5 +361,74 @@ TEST(GrowStatsSparseClassificationTest, Basic) {
ASSERT_EQ(serialized_again, serialized);
}
+TEST(FixedSizeClassStats, Exact) {
+ FixedSizeClassStats stats(10, 100);
+
+ stats.accumulate(1, 1.0);
+ stats.accumulate(2, 2.0);
+ stats.accumulate(3, 3.0);
+
+ EXPECT_EQ(stats.get_weight(1), 1.0);
+ EXPECT_EQ(stats.get_weight(2), 2.0);
+ EXPECT_EQ(stats.get_weight(3), 3.0);
+
+ float sum;
+ float square;
+ stats.set_sum_and_square(&sum, &square);
+
+ EXPECT_EQ(sum, 6.0);
+ EXPECT_EQ(square, 14.0);
+}
+
+TEST(FixedSizeClassStats, Approximate) {
+ FixedSizeClassStats stats(5, 10);
+
+ for (int i = 1; i <= 10; i++) {
+ stats.accumulate(i, i * 1.0);
+ }
+
+ // We should be off by no more than *half* of the least weight
+ // in the class_weights_, which is 7.
+ float tolerance = 3.5;
+ for (int i = 1; i <= 10; i++) {
+ float diff = stats.get_weight(i) - i * 1.0;
+ EXPECT_LE(diff, tolerance);
+ EXPECT_GE(diff, -tolerance);
+ }
+}
+
+TEST(GrowStatsFixedSizeSparseClassificationTest, Basic) {
+ TensorForestParams params;
+ params.set_num_outputs(2);
+ params.set_num_classes_to_track(5);
+ params.mutable_split_after_samples()->set_constant_value(2);
+ params.mutable_num_splits_to_consider()->set_constant_value(2);
+ std::unique_ptr<FixedSizeSparseClassificationGrowStats> stat(
+ new FixedSizeSparseClassificationGrowStats(params, 1));
+ stat->Initialize();
+
+ std::vector<float> labels = {100, 1000, 1};
+ std::vector<float> weights = {2.3, 20.3, 1.1};
+ std::unique_ptr<TestableInputTarget> target(
+ new TestableInputTarget(labels, weights, 1));
+ std::vector<int> branches = {1, 0, 1, 1, 0, 0};
+
+ RunBatch(stat.get(), target.get());
+ CHECK(stat->IsFinished());
+
+ FertileSlot slot;
+ stat->PackToProto(&slot);
+
+ string serialized = slot.DebugString();
+
+ std::unique_ptr<FixedSizeSparseClassificationGrowStats> new_stat(
+ new FixedSizeSparseClassificationGrowStats(params, 1));
+ new_stat->ExtractFromProto(slot);
+ FertileSlot second_one;
+ new_stat->PackToProto(&second_one);
+ string serialized_again = second_one.DebugString();
+ ASSERT_EQ(serialized_again, serialized);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
index e5d1beae7f..cdb1d80a4b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
@@ -55,6 +55,10 @@ std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats(
return std::unique_ptr<GrowStats>(new LeastSquaresRegressionGrowStats(
params_, depth));
+ case STATS_FIXED_SIZE_SPARSE_GINI:
+ return std::unique_ptr<GrowStats>(
+ new FixedSizeSparseClassificationGrowStats(params_, depth));
+
default:
LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type();
return nullptr;
diff --git a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto
index 29d115ab69..4545a8a675 100644
--- a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto
+++ b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto
@@ -20,7 +20,9 @@ enum StatsModelType {
STATS_DENSE_GINI = 0;
STATS_SPARSE_GINI = 1;
STATS_LEAST_SQUARES_REGRESSION = 2;
+ // STATS_SPARSE_THEN_DENSE_GINI is deprecated and no longer supported.
STATS_SPARSE_THEN_DENSE_GINI = 3;
+ STATS_FIXED_SIZE_SPARSE_GINI = 4;
}
// Allows selection of operations on the collection of split candidates.
@@ -145,4 +147,8 @@ message TensorForestParams {
// --------- Parameters for experimental features ---------------------- //
string graph_dir = 16;
int32 num_select_features = 17;
+
+ // When using a FixedSizeSparseClassificationGrowStats, keep track of
+ // this many classes.
+ int32 num_classes_to_track = 24;
}
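
For completeness, a hedged sketch of how the new stats type might be selected in code, modeled on the unit test above. The generated proto header path and the set_stats_type setter are assumptions inferred from the proto file and the switch in split_collection_operators.cc, not something this commit shows directly.

// Illustrative configuration only; setter for stats_type is assumed.
#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"

namespace tensorflow {
namespace tensorforest {

TensorForestParams MakeFixedSizeSparseParams() {
  TensorForestParams params;
  params.set_num_outputs(2);
  // New field from this change: how many class labels each counter tracks.
  params.set_num_classes_to_track(100);
  // Assumed setter name for the StatsModelType field switched on in
  // SplitCollectionOperator::CreateGrowStats.
  params.set_stats_type(STATS_FIXED_SIZE_SPARSE_GINI);
  params.mutable_split_after_samples()->set_constant_value(25);
  params.mutable_num_splits_to_consider()->set_constant_value(10);
  return params;
}

}  // namespace tensorforest
}  // namespace tensorflow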