aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 233e21f1cf..85b909e4f2 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -724,9 +724,9 @@ class GradientBoostedDecisionTreeModel(object):
active_handlers_current_layer = (
active_handlers_current_layer <
self._learner_config.feature_fraction_per_tree)
- active_handlers = array_ops.stack(active_handlers_current_layer,
- array_ops.ones(
- [len(handlers)], dtype=dtypes.bool))
+ active_handlers = array_ops.stack([
+ active_handlers_current_layer,
+ array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1)
else:
active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool)