aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-16 09:07:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 09:12:14 -0700
commit0594892f3544ebb5d01d66fc09793e267f6e3e89 (patch)
tree93eac5b3cbd6dc82145740839fa4eb7011950137 /tensorflow/contrib/tensor_forest
parent2ff193eac54f03af1dc2490f691595e00f3c96b2 (diff)
Support GREATER_OR_EQUAL and GREATER_THAN splits in tensor_forest evaluation
PiperOrigin-RevId: 208992685
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc19
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h4
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc34
3 files changed, 48 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
index 6cb2c881e2..7716536ba4 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc
@@ -54,17 +54,24 @@ InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator(
CHECK(safe_strto32(test.feature_id().id().value(), &feature_num_))
<< "Invalid feature ID: [" << test.feature_id().id().value() << "]";
threshold_ = test.threshold().float_value();
- include_equals_ =
- test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL;
+ _test_type = test.type();
}
int32 InequalityDecisionNodeEvaluator::Decide(
const std::unique_ptr<TensorDataSet>& dataset, int example) const {
const float val = dataset->GetExampleValue(example, feature_num_);
- if (val < threshold_ || (include_equals_ && val == threshold_)) {
- return left_child_id_;
- } else {
- return right_child_id_;
+ switch (_test_type) {
+ case decision_trees::InequalityTest::LESS_OR_EQUAL:
+ return val <= threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::LESS_THAN:
+ return val < threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::GREATER_OR_EQUAL:
+ return val >= threshold_ ? left_child_id_ : right_child_id_;
+ case decision_trees::InequalityTest::GREATER_THAN:
+ return val > threshold_ ? left_child_id_ : right_child_id_;
+ default:
+ LOG(ERROR) << "Unknown split test type: " << _test_type;
+ return -1;
}
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
index 3db351c328..6497787f84 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h
@@ -55,9 +55,7 @@ class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator {
protected:
int32 feature_num_;
float threshold_;
-
- // If decision is '<=' as opposed to '<'.
- bool include_equals_;
+ ::tensorflow::decision_trees::InequalityTest_Type _test_type;
};
// Evaluator for splits with multiple weighted features.
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
index af5cf72a3c..3db1335563 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc
@@ -60,6 +60,40 @@ TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) {
ASSERT_EQ(eval->Decide(dataset, 4), 1);
}
+TEST(InequalityDecisionNodeEvaluatorTest, TestGreaterOrEqual) {
+ InequalityTest test;
+ test.mutable_feature_id()->mutable_id()->set_value("0");
+ test.mutable_threshold()->set_float_value(3.0);
+ test.set_type(InequalityTest::GREATER_OR_EQUAL);
+ std::unique_ptr<InequalityDecisionNodeEvaluator> eval(
+ new InequalityDecisionNodeEvaluator(test, 0, 1));
+
+ std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
+ new tensorflow::tensorforest::TestableDataSet(
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));
+
+ ASSERT_EQ(eval->Decide(dataset, 2), 1);
+ ASSERT_EQ(eval->Decide(dataset, 3), 0);
+ ASSERT_EQ(eval->Decide(dataset, 4), 0);
+}
+
+TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyGreater) {
+ InequalityTest test;
+ test.mutable_feature_id()->mutable_id()->set_value("0");
+ test.mutable_threshold()->set_float_value(3.0);
+ test.set_type(InequalityTest::GREATER_THAN);
+ std::unique_ptr<InequalityDecisionNodeEvaluator> eval(
+ new InequalityDecisionNodeEvaluator(test, 0, 1));
+
+ std::unique_ptr<tensorflow::tensorforest::TensorDataSet> dataset(
+ new tensorflow::tensorforest::TestableDataSet(
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1));
+
+ ASSERT_EQ(eval->Decide(dataset, 2), 1);
+ ASSERT_EQ(eval->Decide(dataset, 3), 1);
+ ASSERT_EQ(eval->Decide(dataset, 4), 0);
+}
+
TEST(MatchingDecisionNodeEvaluatorTest, Basic) {
MatchingValuesTest test;
test.mutable_feature_id()->mutable_id()->set_value("0");