aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc7
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc60
-rw-r--r--tensorflow/contrib/boosted_trees/ops/prediction_ops.cc7
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc3
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py50
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py8
7 files changed, 107 insertions, 46 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
index 2e99f42ff0..daca049548 100644
--- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc
@@ -88,10 +88,9 @@ class GradientTreesPredictionOp : public OpKernel {
context, ParseProtoUnlimited(&learner_config, learner_config_str),
errors::InvalidArgument("Unable to parse learner config config."));
- prediction_vector_size_ =
- learner_config.multi_class_strategy() == LearnerConfig::TREE_PER_CLASS
- ? num_classes_ - 1
- : num_classes_;
+ bool reduce_dim;
+ OP_REQUIRES_OK(context, context->GetAttr("reduce_dim", &reduce_dim));
+ prediction_vector_size_ = reduce_dim ? num_classes_ - 1 : num_classes_;
only_finalized_trees_ =
learner_config.growing_mode() == learner_config.WHOLE_TREE;
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index a2bf374cf5..29635bb3c4 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -70,6 +70,32 @@ class BaseBuildSplitOp : public OpKernel {
multiclass_strategy_, grad_stats);
}
+ void ReadClassId(OpKernelContext* const context, int32* class_id) {
+ const Tensor* class_id_t;
+ OP_REQUIRES_OK(context, context->input("class_id", &class_id_t));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(class_id_t->shape()),
+ errors::InvalidArgument("class_id must be a scalar."));
+ *class_id = class_id_t->scalar<int32>()();
+ }
+
+ void FillLeaf(const int class_id, const NodeStats& best_node_stats,
+ boosted_trees::trees::Leaf* leaf) const {
+ if (class_id == -1) {
+ // This would be the case either for TREE_PER_CLASS with only 2 classes,
+ // or for other multiclass strategies.
+ for (float f : best_node_stats.weight_contribution) {
+ leaf->mutable_vector()->add_value(f);
+ }
+ } else {
+ CHECK(best_node_stats.weight_contribution.size() == 1)
+ << "Weight contribution size = "
+ << best_node_stats.weight_contribution.size();
+ leaf->mutable_sparse_vector()->add_index(class_id);
+ leaf->mutable_sparse_vector()->add_value(
+ best_node_stats.weight_contribution[0]);
+ }
+ }
+
protected:
LearnerConfig_MultiClassStrategy multiclass_strategy_;
int32 feature_column_group_id_;
@@ -110,6 +136,9 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ int class_id;
+ ReadClassId(context, &class_id);
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
partition_boundaries.push_back(0);
@@ -194,12 +223,9 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp {
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- for (float f : best_left_node_stats.weight_contribution) {
- left_child->mutable_vector()->add_value(f);
- }
- for (float f : best_right_node_stats.weight_contribution) {
- right_child->mutable_vector()->add_value(f);
- }
+
+ FillLeaf(class_id, best_left_node_stats, left_child);
+ FillLeaf(class_id, best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
best_gain - root_stats.gain - tree_complexity_regularization_;
@@ -244,6 +270,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ int class_id;
+ ReadClassId(context, &class_id);
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
std::vector<int32> non_empty_partitions;
@@ -369,12 +398,8 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp {
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- for (float f : best_left_node_stats.weight_contribution) {
- left_child->mutable_vector()->add_value(f);
- }
- for (float f : best_right_node_stats.weight_contribution) {
- right_child->mutable_vector()->add_value(f);
- }
+ FillLeaf(class_id, best_left_node_stats, left_child);
+ FillLeaf(class_id, best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
best_gain - root_stats.gain - tree_complexity_regularization_;
@@ -417,6 +442,9 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
const Tensor* hessians_t;
OP_REQUIRES_OK(context, context->input("hessians", &hessians_t));
+ int class_id;
+ ReadClassId(context, &class_id);
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
std::vector<int32> non_empty_partitions;
@@ -494,12 +522,8 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp {
equality_split->set_feature_id(feature_ids(best_feature_idx));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- for (float f : best_left_node_stats.weight_contribution) {
- left_child->mutable_vector()->add_value(f);
- }
- for (float f : best_right_node_stats.weight_contribution) {
- right_child->mutable_vector()->add_value(f);
- }
+ FillLeaf(class_id, best_left_node_stats, left_child);
+ FillLeaf(class_id, best_right_node_stats, right_child);
split_info.SerializeToString(&output_splits(root_idx));
gains(root_idx) =
best_gain - root_stats.gain - tree_complexity_regularization_;
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
index 8effb6f98f..3163590624 100644
--- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
@@ -29,9 +29,10 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
c->GetAttr("learner_config", &learner_config_str).IgnoreError();
LearnerConfig learner_config;
ParseProtoUnlimited(&learner_config, learner_config_str);
+
+ bool reduce_dim;
+ c->GetAttr("reduce_dim", &reduce_dim).IgnoreError();
// Sets the shape of the output as a matrix.
- const bool reduce_dim =
- learner_config.multi_class_strategy() == LearnerConfig::TREE_PER_CLASS;
c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim,
reduce_dim ? learner_config.num_classes() - 1
: learner_config.num_classes())});
@@ -51,6 +52,7 @@ REGISTER_OP("GradientTreesPrediction")
.Attr("apply_dropout: bool")
.Attr("apply_averaging: bool")
.Attr("center_bias: bool")
+ .Attr("reduce_dim: bool")
.Input("tree_ensemble_handle: resource")
.Input("seed: int64")
.Input("dense_float_features: num_dense_float_features * float")
@@ -75,6 +77,7 @@ num_sparse_float_features: Number of sparse float features.
num_sparse_int_features: Number of sparse int features.
use_locking: Whether to use locking.
seed: random seed to be used for dropout.
+reduce_dim: whether to reduce the dimension (legacy impl) or not.
apply_dropout: whether to apply dropout during prediction.
apply_averaging: whether averaging of tree ensembles should take place. If set
to true, will be based on AveragingConfig from learner_config.
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index 95ec738796..07cfd413bb 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -35,6 +35,7 @@ REGISTER_OP("BuildDenseInequalitySplits")
.Input("gradients: float32")
.Input("hessians: float32")
.Input("bucket_boundaries: float32")
+ .Input("class_id: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -95,6 +96,7 @@ REGISTER_OP("BuildSparseInequalitySplits")
.Input("gradients: float32")
.Input("hessians: float32")
.Input("bucket_boundaries: float32")
+ .Input("class_id: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -154,6 +156,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits")
.Input("feature_ids: int64")
.Input("gradients: float32")
.Input("hessians: float32")
+ .Input("class_id: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py
index 90cc07f661..8e62856854 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py
@@ -128,7 +128,8 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False)
+ center_bias=False,
+ reduce_dim=True)
self.assertAllClose(result.eval(), [[-0.4], [-0.4]])
stamp_token = model_ops.tree_ensemble_stamp_token(tree_ensemble_handle)
self.assertEqual(stamp_token.eval(), 3)
@@ -154,6 +155,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
_append_to_leaf(tree2.nodes.add().leaf, 0, 0.5)
_append_to_leaf(tree2.nodes.add().leaf, 1, 1.2)
_append_to_leaf(tree2.nodes.add().leaf, 0, -0.9)
+
tree_ensemble_handle = model_ops.tree_ensemble_variable(
stamp_token=7,
tree_ensemble_config=tree_ensemble_config.SerializeToString(),
@@ -187,7 +189,8 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False)
+ center_bias=False,
+ reduce_dim=True)
# Re-serialize tree.
stamp_token2, serialized_config2 = model_ops.tree_ensemble_serialize(
@@ -198,6 +201,8 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
# the second example will get the same bias class 1 -0.2 and leaf 3
# payload of class 1 1.2 hence [0.0, 1.0].
self.assertEqual(stamp_token2.eval(), 9)
+
+ # Class 2 does have scores in the leaf => it gets score 0.
self.assertEqual(serialized_config2.eval(), serialized_config)
self.assertAllClose(result.eval(), [[0.5, -0.2], [0, 1.0]])
@@ -267,7 +272,8 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False)
+ center_bias=False,
+ reduce_dim=True)
self.assertAllClose([[-1.1], [-1.1]], result.eval())
# Save before adding other trees.
val = my_saver.save(sess, save_path)
@@ -293,7 +299,8 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False)
+ center_bias=False,
+ reduce_dim=True)
self.assertAllClose(result.eval(), [[-11.1], [-11.1]])
# Start a second session. In that session the parameter nodes
@@ -315,7 +322,8 @@ class ModelOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False)
+ center_bias=False,
+ reduce_dim=True)
# Make sure we only have the first and second tree.
# The third tree was added after the save.
self.assertAllClose(result.eval(), [[-1.1], [-1.1]])
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index e5624da965..2b64235bb2 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -163,7 +163,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
self.assertAllEqual([[0], [0]], result.eval())
self.assertAllEqual(result_no_dropout.eval(), result.eval())
# Empty dropout.
@@ -200,7 +201,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
self.assertAllClose([[-0.4], [-0.4]], result.eval())
self.assertAllEqual(result_no_dropout.eval(), result.eval())
@@ -240,7 +242,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
self.assertAllClose([[-0.4, 0.9], [-0.4, 0.9]], result.eval())
self.assertAllEqual(result_no_dropout.eval(), result.eval())
@@ -294,7 +297,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# The first example will get bias -0.4 from first tree and
# leaf 4 payload of -0.9 hence -1.3, the second example will
@@ -353,7 +357,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# All the examples should get only the bias since the second tree is
# non-finalized
@@ -411,7 +416,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# The first example will get bias -0.4 from first tree and
# leaf 4 payload of -0.9 hence -1.3, the second example will
@@ -471,7 +477,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# The first example will get bias -0.4 from first tree and
# leaf 4 payload of -0.9 hence -1.3, the second example will
@@ -530,7 +537,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# The first example will get bias class 1 -0.2 from first tree and
# leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2],
# the second example will get the same bias class 1 -0.2 and leaf 3
@@ -541,7 +549,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
# Empty dropout.
self.assertAllEqual([[], []], dropout_info.eval())
- # For all non-tree-per class multiclass handling strategies, predictions vec
+ # For tree-per-class multiclass handling strategies, predictions vec
# will have the size of the number of classes.
# This test is when leafs have SPARSE weights stored (class id and
# contribution).
@@ -591,7 +599,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=False))
# The first example will get bias class 1 -0.2 from first tree and
# leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2],
# the second example will get the same bias class 1 -0.2 and leaf 3
@@ -651,7 +660,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=False))
# The first example will get bias class 1 -0.2 and -2 for class 2 from
# first tree and leaf 2 payload (sparse feature missing) of 0.5 hence
# 0.5, -0.2], the second example will get the same bias and leaf 3 payload
@@ -679,7 +689,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=apply_dropout,
apply_averaging=apply_averaging,
- center_bias=center_bias)
+ center_bias=center_bias,
+ reduce_dim=True)
def testDropout(self):
with self.test_session():
@@ -918,7 +929,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=True,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
_, result_no_dropout_2, dropout_info_2 = (
prediction_ops.gradient_trees_prediction(
@@ -932,7 +944,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=True,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# Different seed.
_, result_no_dropout_3, dropout_info_3 = (
@@ -947,7 +960,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=True,
apply_averaging=False,
- center_bias=False))
+ center_bias=False,
+ reduce_dim=True))
# First seed with centering bias.
_, result_no_dropout_4, dropout_info_4 = (
@@ -962,7 +976,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=True,
apply_averaging=False,
- center_bias=True))
+ center_bias=True,
+ reduce_dim=True))
# The same seed returns the same results.
self.assertAllEqual(dropout_info_1.eval(), dropout_info_2.eval())
@@ -987,7 +1002,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
learner_config=learner_config.SerializeToString(),
apply_dropout=False,
apply_averaging=False,
- center_bias=False)
+ center_bias=False,
+ reduce_dim=True)
self.assertAllCloseAccordingToType(result.eval(),
result_no_dropout.eval())
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index aa504d1742..edf088b5fa 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -54,6 +54,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
l2_regularization=1,
tree_complexity_regularization=0,
min_node_weight=0,
+ class_id=-1,
feature_column_group_id=0,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
partitions, gains, splits = sess.run([partitions, gains, splits])
@@ -125,6 +126,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
l2_regularization=1,
tree_complexity_regularization=0,
min_node_weight=0,
+ class_id=-1,
feature_column_group_id=0,
multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
partitions, gains, splits = sess.run([partitions, gains, splits])
@@ -163,6 +165,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
l2_regularization=1,
tree_complexity_regularization=0,
min_node_weight=0,
+ class_id=-1,
feature_column_group_id=0,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
partitions, gains, splits = sess.run([partitions, gains, splits])
@@ -200,6 +203,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
feature_column_group_id=0,
bias_feature_id=-1,
+ class_id=-1,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
self.assertAllEqual([0, 1], partitions)
@@ -283,6 +287,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
feature_column_group_id=0,
bias_feature_id=-1,
+ class_id=-1,
multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
@@ -326,6 +331,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
feature_column_group_id=0,
bias_feature_id=-1,
+ class_id=-1,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -420,6 +426,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
feature_column_group_id=0,
bias_feature_id=-1,
+ class_id=-1,
multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -456,6 +463,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
min_node_weight=0,
feature_column_group_id=0,
bias_feature_id=-1,
+ class_id=-1,
multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
self.assertEqual(0, len(partitions))