diff options
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r-- | tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index e08b230f46..19e053fcb6 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -380,6 +380,8 @@ class GradientBoostedDecisionTreeModel(object): self._learner_config = learner_config self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() + self._max_tree_depth = variables.Variable( + initial_value=self._learner_config.constraints.max_tree_depth) self._attempted_trees = variables.Variable( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, @@ -1051,7 +1053,8 @@ class GradientBoostedDecisionTreeModel(object): splits=split_info_list, learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_not_ready_fn(): # Don't grow the ensemble, just update the stamp. @@ -1065,7 +1068,8 @@ class GradientBoostedDecisionTreeModel(object): splits=[], learner_config=self._learner_config_serialized, dropout_seed=dropout_seed, - center_bias=self._center_bias) + center_bias=self._center_bias, + max_tree_depth=self._max_tree_depth) def _grow_ensemble_fn(): # Conditionally grow an ensemble depending on whether the splits @@ -1105,6 +1109,9 @@ class GradientBoostedDecisionTreeModel(object): def get_number_of_trees_tensor(self): return self._finalized_trees, self._attempted_trees + def get_max_tree_depth(self): + return self._max_tree_depth + def train(self, loss, predictions_dict, labels): """Updates the accumalator stats and grows the ensemble. |