diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-16 09:07:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 09:12:14 -0700 |
commit | 0594892f3544ebb5d01d66fc09793e267f6e3e89 (patch) | |
tree | 93eac5b3cbd6dc82145740839fa4eb7011950137 /tensorflow/contrib/tensor_forest | |
parent | 2ff193eac54f03af1dc2490f691595e00f3c96b2 (diff) |
Support GREATER_OR_EQUAL and GREATER_THAN splits in tensor_forest evaluation
PiperOrigin-RevId: 208992685
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
3 files changed, 48 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc index 6cb2c881e2..7716536ba4 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc @@ -54,17 +54,24 @@ InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator( CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_)) << "Invalid feature ID: [" << test.feature_id().id().value() << "]"; threshold_ = test.threshold().float_value(); - include_equals_ = - test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL; + _test_type = test.type(); } int32 InequalityDecisionNodeEvaluator::Decide( const std::unique_ptr<TensorDataSet>& dataset, int example) const { const float val = dataset->GetExampleValue(example, feature_num_); - if (val < threshold_ || (include_equals_ && val == threshold_)) { - return left_child_id_; - } else { - return right_child_id_; + switch (_test_type) { + case decision_trees::InequalityTest::LESS_OR_EQUAL: + return val <= threshold_ ? left_child_id_ : right_child_id_; + case decision_trees::InequalityTest::LESS_THAN: + return val < threshold_ ? left_child_id_ : right_child_id_; + case decision_trees::InequalityTest::GREATER_OR_EQUAL: + return val >= threshold_ ? left_child_id_ : right_child_id_; + case decision_trees::InequalityTest::GREATER_THAN: + return val > threshold_ ? left_child_id_ : right_child_id_; + default: + LOG(ERROR) << "Unknown split test type: " << _test_type; + return -1; } } diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h index 3db351c328..6497787f84 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h @@ -55,9 +55,7 @@ class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { protected: int32 feature_num_; float threshold_; - - // If decision is '<=' as opposed to '<'. - bool include_equals_; + ::tensorflow::decision_trees::InequalityTest_Type _test_type; }; // Evaluator for splits with multiple weighted features. diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc index af5cf72a3c..3db1335563 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc @@ -60,6 +60,40 @@ TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) { ASSERT_EQ(eval->Decide(dataset, 4), 1); } +TEST(InequalityDecisionNodeEvaluatorTest, TestGreaterOrEqual) { + InequalityTest test; + test.mutable_feature_id()->mutable_id()->set_value("0"); + test.mutable_threshold()->set_float_value(3.0); + test.set_type(InequalityTest::GREATER_OR_EQUAL); + std::unique_ptr<InequalityDecisionNodeEvaluator> eval( + new InequalityDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1)); + + ASSERT_EQ(eval->Decide(dataset, 2), 1); + ASSERT_EQ(eval->Decide(dataset, 3), 0); + ASSERT_EQ(eval->Decide(dataset, 4), 0); +} + +TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyGreater) { + InequalityTest test; + test.mutable_feature_id()->mutable_id()->set_value("0"); + test.mutable_threshold()->set_float_value(3.0); + test.set_type(InequalityTest::GREATER_THAN); + std::unique_ptr<InequalityDecisionNodeEvaluator> eval( + new InequalityDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1)); + + ASSERT_EQ(eval->Decide(dataset, 2), 1); + ASSERT_EQ(eval->Decide(dataset, 3), 1); + ASSERT_EQ(eval->Decide(dataset, 4), 0); +} + TEST(MatchingDecisionNodeEvaluatorTest, Basic) { MatchingValuesTest test; test.mutable_feature_id()->mutable_id()->set_value("0"); |