aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-08 09:01:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 09:05:55 -0800
commit23384d7d8a60a36c68fbbdc509b22d385ea9a12c (patch)
treeafea8582801bfa30555f916ca68b5b30be00aa87
parenta47cd30d960b128e5ed405cb36e914aa36fe462a (diff)
Fix feature fraction per tree.
PiperOrigin-RevId: 188339438
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 233e21f1cf..85b909e4f2 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -724,9 +724,9 @@ class GradientBoostedDecisionTreeModel(object):
active_handlers_current_layer = (
active_handlers_current_layer <
self._learner_config.feature_fraction_per_tree)
- active_handlers = array_ops.stack(active_handlers_current_layer,
- array_ops.ones(
- [len(handlers)], dtype=dtypes.bool))
+ active_handlers = array_ops.stack([
+ active_handlers_current_layer,
+ array_ops.ones([len(handlers)], dtype=dtypes.bool)], axis=1)
else:
active_handlers = array_ops.ones([len(handlers), 2], dtype=dtypes.bool)