aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py')
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 5cd37ec67e..2589504762 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -59,7 +59,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -132,7 +133,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -171,7 +173,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
class_id=-1,
feature_column_group_id=0,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
# .assertEmpty doesn't exist on ubuntu-contrib
self.assertEqual(0, len(partitions))