diff options
author | 2017-11-02 16:42:27 -0700 | |
---|---|---|
committer | 2017-11-02 16:46:12 -0700 | |
commit | 9db84049fe1d9c5c7c93d87b53528b8e8255afd9 (patch) | |
tree | a505fcec74e522f26acc5adb546c8be95b53c6ce | |
parent | 6ace5e0494d8142dc67ca0714893afc716125917 (diff) |
boosted_trees: some cleanups.
- removed learner_config.num_classes in training_ops which is unnecessary.
- replaced num_classes in gbdt_batch.py with logits_dimension where possible.
- simplified prediction_ops_test.
PiperOrigin-RevId: 174399706
3 files changed, 110 insertions, 213 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 4c56718f1b..2a5c7949f2 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -208,27 +208,19 @@ class CenterTreeEnsembleBiasOp : public OpKernel { int64 next_stamp_token = next_stamp_token_t->scalar<int64>()(); CHECK(stamp_token != next_stamp_token); + // Update the ensemble stamp. + ensemble_resource->set_stamp(next_stamp_token); + // Get the delta updates. const Tensor* delta_updates_t; OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t)); - OP_REQUIRES( - context, - delta_updates_t->dim_size(0) + 1 == learner_config_.num_classes(), - errors::InvalidArgument( - "Delta updates size must be consistent with label dimensions.")); auto delta_updates = delta_updates_t->vec<float>(); - - // Update the ensemble stamp. - ensemble_resource->set_stamp(next_stamp_token); + const int64 logits_dimension = delta_updates_t->dim_size(0); // Get the bias. - boosted_trees::trees::Leaf* const bias = RetrieveBias(ensemble_resource); + boosted_trees::trees::Leaf* const bias = + RetrieveBias(ensemble_resource, logits_dimension); CHECK(bias->has_vector()); - OP_REQUIRES( - context, - bias->vector().value_size() + 1 == learner_config_.num_classes(), - errors::InvalidArgument( - "Bias vector size must be consistent with label dimensions.")); // Update the bias. float total_delta = 0; @@ -256,7 +248,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel { private: // Helper method to retrieve the bias from the tree ensemble. boosted_trees::trees::Leaf* RetrieveBias( - boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource, + int64 logits_dimension) { const int32 num_trees = ensemble_resource->num_trees(); if (num_trees <= 0) { // Add a new bias leaf. @@ -264,7 +257,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { boosted_trees::trees::DecisionTreeConfig* const tree_config = ensemble_resource->AddNewTree(1.0); auto* const leaf = tree_config->add_nodes()->mutable_leaf(); - for (size_t idx = 0; idx + 1 < learner_config_.num_classes(); ++idx) { + for (size_t idx = 0; idx < logits_dimension; ++idx) { leaf->mutable_vector()->add_value(0.0); } ensemble_resource->LastTreeMetadata()->set_is_finalized(true); diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index cf09585113..79802922ca 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -136,6 +136,27 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self._sparse_int_shape1 = np.array([2, 2]) self._seed = 123 + def _get_predictions(self, + tree_ensemble_handle, + learner_config, + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=False): + return prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], + [self._sparse_float_indices1, self._sparse_float_indices2], + [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, self._sparse_float_shape2], + [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], + learner_config=learner_config, + apply_dropout=apply_dropout, + apply_averaging=apply_averaging, + center_bias=center_bias, + reduce_dim=reduce_dim) + def testEmptyEnsemble(self): with self.test_session(): # Empty tree ensenble. @@ -151,18 +172,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) self.assertAllEqual([[0], [0]], result.eval()) # Empty dropout. @@ -187,18 +199,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) self.assertAllClose([[-0.4], [-0.4]], result.eval()) @@ -226,18 +229,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 3 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) self.assertAllClose([[-0.4, 0.9], [-0.4, 0.9]], result.eval()) @@ -279,18 +273,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias -0.4 from first tree and @@ -338,18 +323,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # All the examples should get only the bias since the second tree is @@ -395,18 +371,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias -0.4 from first tree and @@ -454,18 +421,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias -0.4 from first tree and @@ -512,18 +470,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.TREE_PER_CLASS) - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], @@ -572,18 +521,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=False) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], @@ -631,18 +571,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=False) # The first example will get bias class 1 -0.2 and -2 for class 2 from # first tree and leaf 2 payload (sparse feature missing) of 0.5 hence @@ -653,26 +584,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) - def _get_predictions(self, - tree_ensemble_handle, - learner_config, - apply_dropout=False, - apply_averaging=False, - center_bias=False): - return prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - center_bias=center_bias, - reduce_dim=True) - def testDropout(self): with self.test_session(): # Empty tree ensenble. @@ -699,10 +610,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) # We expect approx 500 trees were dropped. dropout_info = dropout_info.eval() @@ -719,10 +631,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Don't apply dropout. result_no_dropout, no_dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) self.assertEqual(result.eval().size, result_no_dropout.eval().size) for i in range(result.eval().size): @@ -760,17 +673,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) result_center, dropout_info_center = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=True) + center_bias=True, + reduce_dim=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -830,17 +745,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) result_center, dropout_info_center = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=True) + center_bias=True, + reduce_dim=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -888,28 +805,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="empty") resources.initialize_resources(resources.shared_resources()).run() - _, dropout_info_1 = prediction_ops.gradient_trees_prediction( + _, dropout_info_1 = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, center_bias=False, reduce_dim=True) - _, dropout_info_2 = prediction_ops.gradient_trees_prediction( + _, dropout_info_2 = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, @@ -919,12 +824,12 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Different seed. _, dropout_info_3 = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, - 112314, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], + 112314, [self._dense_float_tensor], + [self._sparse_float_indices1, self._sparse_float_indices2], + [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, self._sparse_float_shape2], + [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, @@ -932,14 +837,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): reduce_dim=True) # First seed with centering bias. - _, dropout_info_4 = prediction_ops.gradient_trees_prediction( + _, dropout_info_4 = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, @@ -983,17 +882,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) result_no_dropout, _ = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) self.assertAllEqual([[], []], dropout_info.eval()) self.assertAllClose(result.eval(), result_no_dropout.eval()) @@ -1048,12 +949,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Do averaging. result, dropout_info = self._get_predictions( - tree_ensemble_handle, learner_config, apply_averaging=True) + tree_ensemble_handle, + learner_config.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_dropout_info = (self._get_predictions( + pattern_result, pattern_dropout_info = self._get_predictions( adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1116,15 +1021,22 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() result_1, dropout_info_1 = self._get_predictions( - tree_ensemble_handle, learner_config_1, apply_averaging=True) + tree_ensemble_handle, + learner_config_1.SerializeToString(), + apply_averaging=True, + reduce_dim=True) result_2, dropout_info_2 = self._get_predictions( - tree_ensemble_handle, learner_config_2, apply_averaging=True) + tree_ensemble_handle, + learner_config_2.SerializeToString(), + apply_averaging=True, + reduce_dim=True) pattern_result, pattern_dropout_info = self._get_predictions( adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False) + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result_1.eval(), pattern_result.eval()) self.assertAllEqual(result_2.eval(), pattern_result.eval()) @@ -1179,12 +1091,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() result, dropout_info = self._get_predictions( - tree_ensemble_handle, learner_config, apply_averaging=True) + tree_ensemble_handle, + learner_config.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_dropout_info = (self._get_predictions( + pattern_result, pattern_dropout_info = self._get_predictions( adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1224,10 +1140,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -1263,10 +1175,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -1302,10 +1210,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 5a917ca428..4d9fd75323 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -494,7 +494,6 @@ class GradientBoostedDecisionTreeModel(object): gate_gradients=0, aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - num_classes = self._learner_config.num_classes class_id = -1 # Handle different multiclass strategies. @@ -503,7 +502,7 @@ class GradientBoostedDecisionTreeModel(object): gradient_shape = tensor_shape.scalar() hessian_shape = tensor_shape.scalar() - if num_classes == 2: + if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( gradients, @@ -522,7 +521,7 @@ class GradientBoostedDecisionTreeModel(object): # Choose the class for which the tree is built (one vs rest). class_id = math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % num_classes) + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) # Use class id tensor to get the column with that index from gradients # and hessians. @@ -532,14 +531,15 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([num_classes]) + gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape(([num_classes, num_classes])) + hessian_shape = tensor_shape.TensorShape( + ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([num_classes])) + hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -804,10 +804,10 @@ class GradientBoostedDecisionTreeModel(object): # compute the full hessian with a single call to gradients, but instead # must compute it row-by-row. gradients_list = array_ops.unstack( - grads, num=self._learner_config.num_classes, axis=1) + grads, num=self._logits_dimension, axis=1) hessian_rows = [] - for row in range(self._learner_config.num_classes): + for row in range(self._logits_dimension): # If current row is i, K is number of classes,each row returns a tensor of # size batch_size x K representing for each example dx_i dx_1, dx_i dx_2 # etc dx_i dx_K @@ -830,7 +830,7 @@ class GradientBoostedDecisionTreeModel(object): diag_hessian_list = [] gradients_list = array_ops.unstack( - grads, num=self._learner_config.num_classes, axis=1) + grads, num=self._logits_dimension, axis=1) for row, row_grads in enumerate(gradients_list): # If current row is i, K is number of classes,each row returns a tensor of @@ -891,7 +891,7 @@ class GradientBoostedDecisionTreeModel(object): hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. - partition_ids = math_ops.range(predictions.get_shape()[1]) + partition_ids = math_ops.range(self._logits_dimension) feature_ids = array_ops.zeros_like(partition_ids, dtype=dtypes.int64) add_stats_op = bias_stats_accumulator.add( ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) |