diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc')
-rw-r--r-- | tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc | 33 |
1 files changed, 15 insertions, 18 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc index 49e425642d..d43c068e46 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc @@ -17,6 +17,8 @@ namespace tensorflow { namespace tensorforest { +using decision_trees::Leaf; + std::unique_ptr<LeafModelOperator> LeafModelOperatorFactory::CreateLeafModelOperator( const TensorForestParams& params) { @@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue( } void DenseClassificationLeafModelOperator::UpdateModel( - LeafStat* leaf, const InputTarget* target, - int example) const { + Leaf* leaf, const InputTarget* target, int example) const { const int32 int_label = target->GetTargetAsClassIndex(example, 0); QCHECK_LT(int_label, params_.num_outputs()) << "Got label greater than indicated number of classes. Is " "params.num_classes set correctly?"; QCHECK_GE(int_label, 0); - auto* val = leaf->mutable_classification()->mutable_dense_counts() - ->mutable_value(int_label); + auto* val = leaf->mutable_vector()->mutable_value(int_label); + float weight = target->GetTargetWeight(example); val->set_float_value(val->float_value() + weight); - leaf->set_weight_sum(leaf->weight_sum() + weight); } -void DenseClassificationLeafModelOperator::InitModel( - LeafStat* leaf) const { +void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const { for (int i = 0; i < params_.num_outputs(); ++i) { - leaf->mutable_classification()->mutable_dense_counts()->add_value(); + leaf->mutable_vector()->add_value(); } } @@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue( } void SparseClassificationLeafModelOperator::UpdateModel( - LeafStat* leaf, const InputTarget* target, - int example) const { + Leaf* leaf, const InputTarget* target, int example) const { const int32 int_label = target->GetTargetAsClassIndex(example, 0); QCHECK_LT(int_label, params_.num_outputs()) << "Got label greater than indicated number of classes. Is " "params.num_classes set correctly?"; QCHECK_GE(int_label, 0); const float weight = target->GetTargetWeight(example); - leaf->set_weight_sum(leaf->weight_sum() + weight); - auto value_map = leaf->mutable_classification()->mutable_sparse_counts() - ->mutable_sparse_value(); + + auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value(); auto it = value_map->find(int_label); if (it == value_map->end()) { (*value_map)[int_label].set_float_value(weight); @@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue( } void SparseOrDenseClassificationLeafModelOperator::UpdateModel( - LeafStat* leaf, const InputTarget* target, int example) const { - if (leaf->classification().has_dense_counts()) { + Leaf* leaf, const InputTarget* target, int example) const { + if (leaf->has_vector()) { return dense_->UpdateModel(leaf, target, example); } else { return sparse_->UpdateModel(leaf, target, example); @@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue( return leaf.vector().value(o).float_value(); } -void RegressionLeafModelOperator::InitModel( - LeafStat* leaf) const { +void RegressionLeafModelOperator::InitModel(Leaf* leaf) const { for (int i = 0; i < params_.num_outputs(); ++i) { - leaf->mutable_regression()->mutable_mean_output()->add_value(); + leaf->mutable_vector()->add_value(); } } void RegressionLeafModelOperator::ExportModel( const LeafStat& stat, decision_trees::Leaf* leaf) const { + leaf->clear_vector(); for (int i = 0; i < params_.num_outputs(); ++i) { const float new_val = stat.regression().mean_output().value(i).float_value() / |