aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py')
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py75
1 files changed, 58 insertions, 17 deletions
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 2559fe9913..f45010ec26 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -64,6 +64,7 @@ from __future__ import print_function
import re
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import gen_quantile_ops
from tensorflow.contrib.boosted_trees.python.ops import gen_stats_accumulator_ops
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops
@@ -171,6 +172,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -192,6 +194,7 @@ class DenseSplitHandler(InequalitySplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(DenseSplitHandler, self).__init__(
@@ -209,6 +212,7 @@ class DenseSplitHandler(InequalitySplitHandler):
multiclass_strategy=multiclass_strategy,
loss_uses_sum_reduction=loss_uses_sum_reduction)
self._dense_float_column = dense_float_column
+ self._weak_learner_type = weak_learner_type
# Register dense_make_stats_update function as an Op to the graph.
g = ops.get_default_graph()
dense_make_stats_update.add_to_graph(g)
@@ -269,16 +273,17 @@ class DenseSplitHandler(InequalitySplitHandler):
next_stamp_token, self._multiclass_strategy, class_id,
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
- self._min_node_weight, self._loss_uses_sum_reduction))
-
+ self._min_node_weight, self._loss_uses_sum_reduction,
+ self._weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
-def _make_dense_split(
- quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
- next_stamp_token, multiclass_strategy, class_id, feature_column_id,
- l1_regularization, l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional, loss_uses_sum_reduction):
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional,
+ loss_uses_sum_reduction, weak_learner_type):
"""Function that builds splits for a dense feature column."""
# Get the bucket boundaries
are_splits_ready, buckets = (
@@ -327,7 +332,8 @@ def _make_dense_split(
l2_regularization=l2_regularization,
tree_complexity_regularization=tree_complexity_regularization,
min_node_weight=min_node_weight,
- multiclass_strategy=multiclass_strategy))
+ multiclass_strategy=multiclass_strategy,
+ weak_learner_type=weak_learner_type))
return are_splits_ready, partition_ids, gains, split_infos
@@ -507,7 +513,40 @@ def _make_sparse_split(
return are_splits_ready, partition_ids, gains, split_infos
-def _specialize_make_split(func, is_multi_dimentional):
+def _specialize_make_split_dense(func, is_multi_dimentional):
+ """Builds a specialized version of the function."""
+
+ @function.Defun(
+ dtypes.resource,
+ dtypes.resource,
+ dtypes.int64,
+ dtypes.int64,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.bool,
+ dtypes.int32,
+ noinline=True)
+ def f(quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
+ next_stamp_token, multiclass_strategy, class_id, feature_column_id,
+ l1_regularization, l2_regularization, tree_complexity_regularization,
+ min_node_weight, loss_uses_sum_reduction, weak_learner_type):
+ """Function that builds splits for a sparse feature column."""
+ return func(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy, class_id,
+ feature_column_id, l1_regularization, l2_regularization,
+ tree_complexity_regularization, min_node_weight,
+ is_multi_dimentional, loss_uses_sum_reduction,
+ weak_learner_type)
+
+ return f
+
+
+def _specialize_make_split_sparse(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
@function.Defun(
@@ -537,15 +576,17 @@ def _specialize_make_split(func, is_multi_dimentional):
return f
-make_dense_split_scalar = _specialize_make_split(_make_dense_split,
- is_multi_dimentional=False)
-make_dense_split_tensor = _specialize_make_split(_make_dense_split,
- is_multi_dimentional=True)
-make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
- is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
- is_multi_dimentional=True)
+make_dense_split_scalar = _specialize_make_split_dense(
+ _make_dense_split, is_multi_dimentional=False)
+
+make_dense_split_tensor = _specialize_make_split_dense(
+ _make_dense_split, is_multi_dimentional=True)
+
+make_sparse_split_scalar = _specialize_make_split_sparse(
+ _make_sparse_split, is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split_sparse(
+ _make_sparse_split, is_multi_dimentional=True)
@function.Defun(