aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc')
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc33
1 files changed, 15 insertions, 18 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
index 49e425642d..d43c068e46 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
@@ -17,6 +17,8 @@
namespace tensorflow {
namespace tensorforest {
+using decision_trees::Leaf;
+
std::unique_ptr<LeafModelOperator>
LeafModelOperatorFactory::CreateLeafModelOperator(
const TensorForestParams& params) {
@@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue(
}
void DenseClassificationLeafModelOperator::UpdateModel(
- LeafStat* leaf, const InputTarget* target,
- int example) const {
+ Leaf* leaf, const InputTarget* target, int example) const {
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
QCHECK_LT(int_label, params_.num_outputs())
<< "Got label greater than indicated number of classes. Is "
"params.num_classes set correctly?";
QCHECK_GE(int_label, 0);
- auto* val = leaf->mutable_classification()->mutable_dense_counts()
- ->mutable_value(int_label);
+ auto* val = leaf->mutable_vector()->mutable_value(int_label);
+
float weight = target->GetTargetWeight(example);
val->set_float_value(val->float_value() + weight);
- leaf->set_weight_sum(leaf->weight_sum() + weight);
}
-void DenseClassificationLeafModelOperator::InitModel(
- LeafStat* leaf) const {
+void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
for (int i = 0; i < params_.num_outputs(); ++i) {
- leaf->mutable_classification()->mutable_dense_counts()->add_value();
+ leaf->mutable_vector()->add_value();
}
}
@@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue(
}
void SparseClassificationLeafModelOperator::UpdateModel(
- LeafStat* leaf, const InputTarget* target,
- int example) const {
+ Leaf* leaf, const InputTarget* target, int example) const {
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
QCHECK_LT(int_label, params_.num_outputs())
<< "Got label greater than indicated number of classes. Is "
"params.num_classes set correctly?";
QCHECK_GE(int_label, 0);
const float weight = target->GetTargetWeight(example);
- leaf->set_weight_sum(leaf->weight_sum() + weight);
- auto value_map = leaf->mutable_classification()->mutable_sparse_counts()
- ->mutable_sparse_value();
+
+ auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
auto it = value_map->find(int_label);
if (it == value_map->end()) {
(*value_map)[int_label].set_float_value(weight);
@@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
}
void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
- LeafStat* leaf, const InputTarget* target, int example) const {
- if (leaf->classification().has_dense_counts()) {
+ Leaf* leaf, const InputTarget* target, int example) const {
+ if (leaf->has_vector()) {
return dense_->UpdateModel(leaf, target, example);
} else {
return sparse_->UpdateModel(leaf, target, example);
@@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue(
return leaf.vector().value(o).float_value();
}
-void RegressionLeafModelOperator::InitModel(
- LeafStat* leaf) const {
+void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
for (int i = 0; i < params_.num_outputs(); ++i) {
- leaf->mutable_regression()->mutable_mean_output()->add_value();
+ leaf->mutable_vector()->add_value();
}
}
void RegressionLeafModelOperator::ExportModel(
const LeafStat& stat, decision_trees::Leaf* leaf) const {
+ leaf->clear_vector();
for (int i = 0; i < params_.num_outputs(); ++i) {
const float new_val =
stat.regression().mean_output().value(i).float_value() /