diff options
author | 2017-09-21 11:18:05 -0700 | |
---|---|---|
committer | 2017-09-21 11:21:46 -0700 | |
commit | 054b88233bf6d6bc5b953fca50dbb01d108b2d18 (patch) | |
tree | 954f314c6507bda5ee323d32acc81900e89571ce /tensorflow/contrib/tensor_forest | |
parent | 2679dcfbaa491c764caa9e2d69b071dbc1b7978b (diff) |
Add fixed space sparse class stats handling.
PiperOrigin-RevId: 169570470
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
5 files changed, 439 insertions, 88 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc index 63bfc1aef1..3ce630e3a9 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -159,7 +159,7 @@ void ClassificationStats::AdditionalInitializationExample( } bool ClassificationStats::IsFinished() const { - bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1; + bool basic = (weight_sum_ >= split_after_samples_) && !is_pure(); return basic || finish_early_; } @@ -193,8 +193,11 @@ void ClassificationStats::AddExample( left_gini_->update(i, left_count(i, int_label), weight); } ClassificationAddLeftExample(i, int_label, weight); - } else if (right_gini_ != nullptr) { - right_gini_->update(i, right_count(i, int_label), weight); + } else { + if (right_gini_ != nullptr) { + right_gini_->update(i, right_count(i, int_label), weight); + } + ClassificationAddRightExample(i, int_label, weight); } } @@ -374,6 +377,41 @@ void ClassificationStats::CheckFinishEarlyBootstrap() { finish_early_ = worst_g1 < best_g2; } +bool ClassificationStats::BestSplit(SplitCandidate* best) const { + float min_score = FLT_MAX; + int best_index = -1; + float best_left_sum, best_right_sum; + + // Calculate sums. + for (int i = 0; i < num_splits(); ++i) { + float left_sum, right_sum; + const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); + // Find the lowest gini. + if (left_sum > 0 && right_sum > 0 && + split_score < min_score) { // useless check + min_score = split_score; + best_index = i; + best_left_sum = left_sum; + best_right_sum = right_sum; + } + } + + // This could happen if all the splits are useless. + if (best_index < 0) { + return false; + } + + // Fill in stats to be used for leaf model. + *best->mutable_split() = splits_[best_index]; + auto* left = best->mutable_left_stats(); + left->set_weight_sum(best_left_sum); + auto* right = best->mutable_right_stats(); + right->set_weight_sum(best_right_sum); + InitLeafClassStats(best_index, left, right); + + return true; +} + // ------------------------ Dense Classification --------------------------- // void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { Initialize(); @@ -449,52 +487,20 @@ float DenseClassificationGrowStats::GiniScore(int split, float* left_sum, return left_score + right_score; } -bool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const { - float min_score = FLT_MAX; - int best_index = -1; - float best_left_sum, best_right_sum; - - // Calculate sums. - for (int i = 0; i < num_splits(); ++i) { - float left_sum, right_sum; - const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); - // Find the lowest gini. - if (left_sum > 0 && right_sum > 0 && - split_score < min_score) { // useless check - min_score = split_score; - best_index = i; - best_left_sum = left_sum; - best_right_sum = right_sum; - } - } - - // This could happen if all the splits are useless. - if (best_index < 0) { - return false; - } - - // Fill in stats to be used for leaf model. - *best->mutable_split() = splits_[best_index]; - // Left - auto* left = best->mutable_left_stats(); - auto* left_class_stats = left->mutable_classification(); - left->set_weight_sum(best_left_sum); +void DenseClassificationGrowStats::InitLeafClassStats( + int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const { + auto* left_class_stats = left_stats->mutable_classification(); auto* left_counts = left_class_stats->mutable_dense_counts(); for (int i = 0; i < params_.num_outputs(); ++i) { - left_counts->add_value()->set_float_value( - left_count(best_index, i)); + left_counts->add_value()->set_float_value(left_count(best_split_index, i)); } - // Right - auto* right = best->mutable_right_stats(); - auto* right_class_stats = right->mutable_classification(); - right->set_weight_sum(best_right_sum); + auto* right_class_stats = right_stats->mutable_classification(); auto* right_counts = right_class_stats->mutable_dense_counts(); for (int i = 0; i < params_.num_outputs(); ++i) { - right_counts->add_value()->set_float_value( - total_counts_[i] - left_count(best_index, i)); + right_counts->add_value()->set_float_value(total_counts_[i] - + left_count(best_split_index, i)); } - return true; } // ------------------------ Sparse Classification --------------------------- // @@ -584,49 +590,18 @@ float SparseClassificationGrowStats::GiniScore( return left_score + right_score; } -bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const { - float min_score = FLT_MAX; - int best_index = -1; - float best_left_sum = -1; - float best_right_sum = -1; - - // Find the lowest gini. - for (int i = 0; i < num_splits(); ++i) { - float left_sum, right_sum; - const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); - if (left_sum > 0 && right_sum > 0 && - split_score < min_score) { // useless check - min_score = split_score; - best_index = i; - best_left_sum = left_sum; - best_right_sum = right_sum; - } - } - - // This could happen if all the splits are useless. - if (best_index < 0) { - return false; - } - - // Fill in stats to be used for leaf model. - *best->mutable_split() = splits_[best_index]; - // Left - auto* left = best->mutable_left_stats(); - auto* left_class_stats = left->mutable_classification(); - left->set_weight_sum(best_left_sum); +void SparseClassificationGrowStats::InitLeafClassStats( + int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const { + auto* left_class_stats = left_stats->mutable_classification(); auto* left_counts = left_class_stats->mutable_sparse_counts()->mutable_sparse_value(); - - // Right - auto* right = best->mutable_right_stats(); - auto* right_class_stats = right->mutable_classification(); - right->set_weight_sum(best_right_sum); + auto* right_class_stats = right_stats->mutable_classification(); auto* right_counts = right_class_stats->mutable_sparse_counts()->mutable_sparse_value(); for (const auto& entry : total_counts_) { - auto it = left_counts_[best_index].find(entry.first); - if (it == left_counts_[best_index].end()) { + auto it = left_counts_[best_split_index].find(entry.first); + if (it == left_counts_[best_split_index].end()) { (*right_counts)[entry.first].set_float_value(entry.second); } else { const float left = it->second; @@ -637,7 +612,184 @@ bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const { } } } - return true; +} + +// -------------------- FixedSizeClassStats --------------------------------- // + +// FixedSizeClassStats implements the "SpaceSaving" algorithm by +// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi. See for example +// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf + +int argmin(const std::unordered_map<int, float>& m) { + int c = -1; + float f = FLT_MAX; + for (const auto it : m) { + if (it.second < f) { + f = it.second; + c = it.first; + } + } + return c; +} + +void FixedSizeClassStats::accumulate(int c, float w) { + auto it = class_weights_.find(c); + if (it != class_weights_.end()) { + it->second += w; + if (c == smallest_weight_class_) { + smallest_weight_class_ = argmin(class_weights_); + } + return; + } + + if (class_weights_.size() < n_) { + class_weights_.insert(it, std::pair<int, float>(c, w)); + if (class_weights_.size() == n_) { + // Can't assume last added has the smallest weight, because the + // w's might be all different. + smallest_weight_class_ = argmin(class_weights_); + } + return; + } + + // This is the slightly unintuitive heart of the SpaceSaving algorithm: + // if the map is full and we see a new class, we find the entry with the + // smallest weight and "take it over": we add our weight to its weight, + // and assign it all to the new seen class. + it = class_weights_.find(smallest_weight_class_); + float new_weight = it->second + w; + class_weights_.erase(it); + class_weights_[c] = new_weight; + smallest_weight_class_ = argmin(class_weights_); +} + +float FixedSizeClassStats::get_weight(int c) const { + // Every entry in class_weights_ might be overstated by as much as the + // smallest_weight. We therefore assume that each has been overstated + // by smallest_weight / 2.0, and we re-distribute that mass over all + // num_classes_ classes. + float smallest_weight = 0.0; + auto it = class_weights_.find(smallest_weight_class_); + if (it != class_weights_.end()) { + smallest_weight = it->second; + } + float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_); + it = class_weights_.find(c); + if (it != class_weights_.end()) { + w += it->second - smallest_weight / 2.0; + } + return w; +} + +void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const { + *sum = 0.0; + *square = 0.0; + + float smallest_weight = 0.0; + auto it = class_weights_.find(smallest_weight_class_); + if (it != class_weights_.end()) { + smallest_weight = it->second; + } + + float w; + for (const auto it : class_weights_) { + *sum += it.second; + w = get_weight(it.first); + *square += w * w; + } + + w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_); + *square += (num_classes_ - n_) * w * w; +} + +void FixedSizeClassStats::ExtractFromProto( + const decision_trees::SparseVector& sparse_vector) { + for (const auto& it : sparse_vector.sparse_value()) { + class_weights_[it.first] = it.second.float_value(); + } + if (class_weights_.size() == n_) { + smallest_weight_class_ = argmin(class_weights_); + } +} + +void FixedSizeClassStats::PackToProto( + decision_trees::SparseVector* sparse_vector) const { + for (const auto it : class_weights_) { + (*sparse_vector->mutable_sparse_value())[it.first].set_float_value( + it.second); + } +} + +// --------------------- FixedSizeSparseClassificationGrowStats ------------- // + +void FixedSizeSparseClassificationGrowStats::ExtractFromProto( + const FertileSlot& slot) { + Initialize(); + if (!slot.has_post_init_leaf_stats()) { + return; + } + weight_sum_ = slot.post_init_leaf_stats().weight_sum(); + + // Candidate counts and splits. + int split_num = 0; + left_counts_.clear(); + right_counts_.clear(); + for (const auto& cand : slot.candidates()) { + AddSplit(cand.split(), nullptr, nullptr, -1); + const auto& left_stats = cand.left_stats().classification().sparse_counts(); + left_counts_.emplace_back(params_.num_classes_to_track(), + params_.num_outputs()); + left_counts_[split_num].ExtractFromProto(left_stats); + const auto& right_stats = + cand.right_stats().classification().sparse_counts(); + right_counts_.emplace_back(params_.num_classes_to_track(), + params_.num_outputs()); + right_counts_[split_num].ExtractFromProto(right_stats); + ++split_num; + } +} + +void FixedSizeSparseClassificationGrowStats::PackToProto( + FertileSlot* slot) const { + auto* slot_stats = slot->mutable_post_init_leaf_stats(); + slot_stats->set_weight_sum(weight_sum_); + + for (int split_num = 0; split_num < num_splits(); ++split_num) { + auto* cand = slot->add_candidates(); + *cand->mutable_split() = splits_[split_num]; + auto* left_stats = cand->mutable_left_stats() + ->mutable_classification() + ->mutable_sparse_counts(); + left_counts_[split_num].PackToProto(left_stats); + auto* right_stats = cand->mutable_right_stats() + ->mutable_classification() + ->mutable_sparse_counts(); + right_counts_[split_num].PackToProto(right_stats); + } +} + +float FixedSizeSparseClassificationGrowStats::GiniScore( + int split, float* left_sum, float* right_sum) const { + float left_square, right_square; + left_counts_[split].set_sum_and_square(left_sum, &left_square); + right_counts_[split].set_sum_and_square(right_sum, &right_square); + const int32 num_classes = params_.num_outputs(); + const float left_score = + WeightedSmoothedGini(*left_sum, left_square, num_classes); + const float right_score = + WeightedSmoothedGini(*right_sum, right_square, num_classes); + return left_score + right_score; +} + +void FixedSizeSparseClassificationGrowStats::InitLeafClassStats( + int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const { + auto* left_class_stats = left_stats->mutable_classification(); + auto* left_counts = left_class_stats->mutable_sparse_counts(); + left_counts_[best_split_index].PackToProto(left_counts); + + auto* right_class_stats = right_stats->mutable_classification(); + auto* right_counts = right_class_stats->mutable_sparse_counts(); + right_counts_[best_split_index].PackToProto(right_counts); } // --------------------- Least Squares Regression --------------------------- // diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index ba73d1d246..3e41ab50b9 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -189,15 +189,29 @@ class ClassificationStats : public GrowStats { half_initialized_splits_.empty()); } + bool BestSplit(SplitCandidate* best) const override; + // When best_split_index has been chosen as the best split, + // InitLeafClassStats is used to initialize the LeafStat's of the two + // new leaves. + virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats, + LeafStat* right_stats) const = 0; + protected: virtual float GiniScore(int split, float* left_sum, float* right_sum) const = 0; - virtual int num_outputs_seen() const = 0; + + // is_pure should return true if at most one class label has been seen + // at the node, and false if two or more have been seen. + virtual bool is_pure() const = 0; virtual float left_count(int split, int class_num) const = 0; virtual float right_count(int split, int class_num) const = 0; virtual void ClassificationAddLeftExample( int split, int64 int_label, float weight) = 0; + virtual void ClassificationAddRightExample(int split, int64 int_label, + float weight) { + // Does nothing by default, but sub-classes can override. + } virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0; virtual void ClassificationAddSplitStats() = 0; @@ -301,7 +315,8 @@ class DenseClassificationGrowStats : public ClassificationStats { void ExtractFromProto(const FertileSlot& slot) override; void PackToProto(FertileSlot* slot) const override; - bool BestSplit(SplitCandidate* best) const override; + void InitLeafClassStats(int best_split_index, LeafStat* left_stats, + LeafStat* right_stats) const; protected: void ClassificationAddSplitStats() override { @@ -317,9 +332,7 @@ class DenseClassificationGrowStats : public ClassificationStats { num_outputs_seen_ = 0; } - int num_outputs_seen() const override { - return num_outputs_seen_; - } + bool is_pure() const override { return num_outputs_seen_ <= 1; } void ClassificationAddLeftExample(int split, int64 int_label, float weight) override { @@ -369,7 +382,8 @@ class SparseClassificationGrowStats : public ClassificationStats { void ExtractFromProto(const FertileSlot& slot) override; void PackToProto(FertileSlot* slot) const override; - bool BestSplit(SplitCandidate* best) const override; + void InitLeafClassStats(int best_split_index, LeafStat* left_stats, + LeafStat* right_stats) const; protected: void ClassificationAddSplitStats() override { @@ -384,7 +398,7 @@ class SparseClassificationGrowStats : public ClassificationStats { left_counts_.clear(); } - int num_outputs_seen() const override { return total_counts_.size(); } + bool is_pure() const override { return total_counts_.size() <= 1; } void ClassificationAddLeftExample(int split, int64 int_label, float weight) override { @@ -412,6 +426,111 @@ class SparseClassificationGrowStats : public ClassificationStats { std::vector<std::unordered_map<int, float>> left_counts_; }; +// Accumulates weights for the most popular classes while only using a +// fixed amount of space. +class FixedSizeClassStats { + public: + // n specifies how many classes are tracked. + FixedSizeClassStats(int n, int num_classes) + : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {} + + // Add weight w to the class c. + void accumulate(int c, float w); + + // Return the approximate accumulated weight for class c. If c isn't one + // of the n-most popular classes, this can be 0 even if c has accumulated + // some weight. + float get_weight(int c) const; + + // Put the sum of all weights seen into *sum, and + // \sum_c get_weight(c)^2 + // into *square. *sum will be exact, but *square will be approximate. + void set_sum_and_square(float* sum, float* square) const; + + void ExtractFromProto(const decision_trees::SparseVector& sparse_vector); + void PackToProto(decision_trees::SparseVector* sparse_vector) const; + + private: + // For our typical use cases, n_ is between 10 and 100, so there's no + // need to track the smallest weight with a min_heap or the like. + int n_; + int num_classes_; + + // This tracks the class of the smallest weight, but isn't set until + // class_weights_.size() == n_. + int smallest_weight_class_; + + std::unordered_map<int, float> class_weights_; +}; + +// Tracks classification stats sparsely in a fixed amount of space. +class FixedSizeSparseClassificationGrowStats : public ClassificationStats { + public: + FixedSizeSparseClassificationGrowStats(const TensorForestParams& params, + int32 depth) + : ClassificationStats(params, depth) {} + + void Initialize() override { Clear(); } + + void ExtractFromProto(const FertileSlot& slot) override; + void PackToProto(FertileSlot* slot) const override; + + void InitLeafClassStats(int best_split_index, LeafStat* left_stats, + LeafStat* right_stats) const; + + protected: + void ClassificationAddSplitStats() override { + FixedSizeClassStats stats(params_.num_classes_to_track(), + params_.num_outputs()); + left_counts_.resize(num_splits(), stats); + right_counts_.resize(num_splits(), stats); + } + void ClassificationRemoveSplitStats(int split_num) override { + left_counts_.erase(left_counts_.begin() + split_num, + left_counts_.begin() + (split_num + 1)); + right_counts_.erase(right_counts_.begin() + split_num, + right_counts_.begin() + (split_num + 1)); + } + void ClearInternal() override { + left_counts_.clear(); + right_counts_.clear(); + } + + bool is_pure() const override { return first_two_classes_seen_.size() <= 1; } + + void ClassificationAddLeftExample(int split, int64 int_label, + float weight) override { + left_counts_[split].accumulate(int_label, weight); + } + void ClassificationAddRightExample(int split, int64 int_label, + float weight) override { + right_counts_[split].accumulate(int_label, weight); + } + void ClassificationAddTotalExample(int64 int_label, float weight) override { + if (is_pure()) { + first_two_classes_seen_.insert(int_label); + } + } + + float GiniScore(int split, float* left_sum, float* right_sum) const override; + + float left_count(int split, int class_num) const override { + return left_counts_[split].get_weight(class_num); + } + + float right_count(int split, int class_num) const override { + return right_counts_[split].get_weight(class_num); + } + + private: + std::vector<FixedSizeClassStats> left_counts_; + std::vector<FixedSizeClassStats> right_counts_; + + // We keep track of the first two class labels seen, so we can tell if + // the node is pure (= all of one class) or not. + std::set<int> first_two_classes_seen_; +}; + // Tracks regression stats using least-squares minimization. class LeastSquaresRegressionGrowStats : public GrowStats { public: diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc index fa959e8373..ceb58d2ead 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc @@ -29,6 +29,8 @@ using tensorflow::tensorforest::TestableInputTarget; using tensorflow::tensorforest::FertileSlot; using tensorflow::tensorforest::DenseClassificationGrowStats; using tensorflow::tensorforest::SparseClassificationGrowStats; +using tensorflow::tensorforest::FixedSizeClassStats; +using tensorflow::tensorforest::FixedSizeSparseClassificationGrowStats; using tensorflow::tensorforest::LeastSquaresRegressionGrowStats; using tensorflow::tensorforest::TensorForestParams; using tensorflow::tensorforest::SPLIT_FINISH_BASIC; @@ -327,7 +329,6 @@ TEST(GrowStatsLeastSquaresRegressionTest, Basic) { ASSERT_EQ(serialized_again, serialized); } - TEST(GrowStatsSparseClassificationTest, Basic) { TensorForestParams params; params.set_num_outputs(2); @@ -360,5 +361,74 @@ TEST(GrowStatsSparseClassificationTest, Basic) { ASSERT_EQ(serialized_again, serialized); } +TEST(FixedSizeClassStats, Exact) { + FixedSizeClassStats stats(10, 100); + + stats.accumulate(1, 1.0); + stats.accumulate(2, 2.0); + stats.accumulate(3, 3.0); + + EXPECT_EQ(stats.get_weight(1), 1.0); + EXPECT_EQ(stats.get_weight(2), 2.0); + EXPECT_EQ(stats.get_weight(3), 3.0); + + float sum; + float square; + stats.set_sum_and_square(&sum, &square); + + EXPECT_EQ(sum, 6.0); + EXPECT_EQ(square, 14.0); +} + +TEST(FixedSizeClassStats, Approximate) { + FixedSizeClassStats stats(5, 10); + + for (int i = 1; i <= 10; i++) { + stats.accumulate(i, i * 1.0); + } + + // We should be off by no more than *half* of the least weight + // in the class_weights_, which is 7. + float tolerance = 3.5; + for (int i = 1; i <= 10; i++) { + float diff = stats.get_weight(i) - i * 1.0; + EXPECT_LE(diff, tolerance); + EXPECT_GE(diff, -tolerance); + } +} + +TEST(GrowStatsFixedSizeSparseClassificationTest, Basic) { + TensorForestParams params; + params.set_num_outputs(2); + params.set_num_classes_to_track(5); + params.mutable_split_after_samples()->set_constant_value(2); + params.mutable_num_splits_to_consider()->set_constant_value(2); + std::unique_ptr<FixedSizeSparseClassificationGrowStats> stat( + new FixedSizeSparseClassificationGrowStats(params, 1)); + stat->Initialize(); + + std::vector<float> labels = {100, 1000, 1}; + std::vector<float> weights = {2.3, 20.3, 1.1}; + std::unique_ptr<TestableInputTarget> target( + new TestableInputTarget(labels, weights, 1)); + std::vector<int> branches = {1, 0, 1, 1, 0, 0}; + + RunBatch(stat.get(), target.get()); + CHECK(stat->IsFinished()); + + FertileSlot slot; + stat->PackToProto(&slot); + + string serialized = slot.DebugString(); + + std::unique_ptr<FixedSizeSparseClassificationGrowStats> new_stat( + new FixedSizeSparseClassificationGrowStats(params, 1)); + new_stat->ExtractFromProto(slot); + FertileSlot second_one; + new_stat->PackToProto(&second_one); + string serialized_again = second_one.DebugString(); + ASSERT_EQ(serialized_again, serialized); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index e5d1beae7f..cdb1d80a4b 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -55,6 +55,10 @@ std::unique_ptr<GrowStats> SplitCollectionOperator::CreateGrowStats( return std::unique_ptr<GrowStats>(new LeastSquaresRegressionGrowStats( params_, depth)); + case STATS_FIXED_SIZE_SPARSE_GINI: + return std::unique_ptr<GrowStats>( + new FixedSizeSparseClassificationGrowStats(params_, depth)); + default: LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type(); return nullptr; diff --git a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto index 29d115ab69..4545a8a675 100644 --- a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto +++ b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto @@ -20,7 +20,9 @@ enum StatsModelType { STATS_DENSE_GINI = 0; STATS_SPARSE_GINI = 1; STATS_LEAST_SQUARES_REGRESSION = 2; + // STATS_SPARSE_THEN_DENSE_GINI is deprecated and no longer supported. STATS_SPARSE_THEN_DENSE_GINI = 3; + STATS_FIXED_SIZE_SPARSE_GINI = 4; } // Allows selection of operations on the collection of split candidates. @@ -145,4 +147,8 @@ message TensorForestParams { // --------- Parameters for experimental features ---------------------- // string graph_dir = 16; int32 num_select_features = 17; + + // When using a FixedSizeSparseClassificationGrowStats, keep track of + // this many classes. + int32 num_classes_to_track = 24; } |