aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index e08b230f46..19e053fcb6 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -380,6 +380,8 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._max_tree_depth = variables.Variable(
+ initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
@@ -1051,7 +1053,8 @@ class GradientBoostedDecisionTreeModel(object):
splits=split_info_list,
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
- center_bias=self._center_bias)
+ center_bias=self._center_bias,
+ max_tree_depth=self._max_tree_depth)
def _grow_ensemble_not_ready_fn():
# Don't grow the ensemble, just update the stamp.
@@ -1065,7 +1068,8 @@ class GradientBoostedDecisionTreeModel(object):
splits=[],
learner_config=self._learner_config_serialized,
dropout_seed=dropout_seed,
- center_bias=self._center_bias)
+ center_bias=self._center_bias,
+ max_tree_depth=self._max_tree_depth)
def _grow_ensemble_fn():
# Conditionally grow an ensemble depending on whether the splits
@@ -1105,6 +1109,9 @@ class GradientBoostedDecisionTreeModel(object):
def get_number_of_trees_tensor(self):
return self._finalized_trees, self._attempted_trees
+ def get_max_tree_depth(self):
+ return self._max_tree_depth
+
def train(self, loss, predictions_dict, labels):
"""Updates the accumalator stats and grows the ensemble.