aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py')
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index ba5ef700c5..d0d1249bd6 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -51,6 +51,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import device_setter
+
# Key names for prediction dict.
ENSEMBLE_STAMP = "ensemble_stamp"
PREDICTIONS = "predictions"
@@ -898,7 +899,7 @@ class GradientBoostedDecisionTreeModel(object):
reset_ops = []
for handler in handlers:
- reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0))
+ reset_ops.append(handler.reset(stamp_token, next_stamp_token))
if self._center_bias:
reset_ops.append(
bias_stats_accumulator.flush(stamp_token, next_stamp_token))