-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py       230
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py   34
2 files changed, 150 insertions, 114 deletions
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 8225318b70..409a2d8f46 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -243,45 +243,74 @@ class DenseSplitHandler(InequalitySplitHandler):
def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state."""
- # Get the bucket boundaries
- are_splits_ready, buckets = (
- self._quantile_accumulator.get_buckets(stamp_token))
- # After we receive the boundaries from previous iteration we can flush
- # the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = self._quantile_accumulator.flush(
- stamp_token=stamp_token, next_stamp_token=next_stamp_token)
-
- # Get the aggregated gradients and hessians per <partition_id, feature_id>
- # pair.
- # In order to distribute the computation on all the PSs we use the PS that
- # had the stats accumulator on.
- with ops.device(None):
- with ops.device(self._stats_accumulator.resource().device):
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- self._stats_accumulator.flush(stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- are_splits_ready = control_flow_ops.with_dependencies(
- [flush_quantiles, partition_ids], are_splits_ready)
-
- partition_ids, gains, split_infos = (
- split_handler_ops.build_dense_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=self._feature_column_group_id,
- l1_regularization=self._l1_regularization,
- l2_regularization=self._l2_regularization,
- tree_complexity_regularization=self.
- _tree_complexity_regularization,
- min_node_weight=self._min_node_weight,
- multiclass_strategy=self._multiclass_strategy))
- return (are_splits_ready, partition_ids, gains, split_infos)
+ if (self._gradient_shape == tensor_shape.scalar() and
+ self._hessian_shape == tensor_shape.scalar()):
+ handler = make_dense_split_scalar
+ else:
+ handler = make_dense_split_tensor
+
+ are_splits_ready, partition_ids, gains, split_infos = (
+ handler(self._quantile_accumulator.resource(),
+ self._stats_accumulator.resource(), stamp_token,
+ next_stamp_token, self._multiclass_strategy, class_id,
+ self._feature_column_group_id, self._l1_regularization,
+ self._l2_regularization, self._tree_complexity_regularization,
+ self._min_node_weight))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _make_dense_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a dense feature column."""
+ # Get the bucket boundaries
+ are_splits_ready, buckets = (
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
+ # quantile_accumulator_get_buckets returns a list of results per handle that
+ # we pass to it. In this case we're getting results just for one resource.
+ are_splits_ready = are_splits_ready[0]
+ buckets = buckets[0]
+
+  # After we receive the boundaries from the previous iteration we can flush
+  # the quantile accumulator.
+ with ops.control_dependencies([buckets]):
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
+ if is_multi_dimentional:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
+ are_splits_ready = array_ops.identity(are_splits_ready)
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_dense_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
class SparseSplitHandler(InequalitySplitHandler):
@@ -399,63 +428,64 @@ class SparseSplitHandler(InequalitySplitHandler):
return are_splits_ready, partition_ids, gains, split_infos
-def _specialize_sparse_split(is_multi_dimentional):
+def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
+ stamp_token, next_stamp_token, multiclass_strategy,
+ class_id, feature_column_id, l1_regularization,
+ l2_regularization, tree_complexity_regularization,
+ min_node_weight, is_multi_dimentional):
+ """Function that builds splits for a sparse feature column."""
+ # Get the bucket boundaries
+ are_splits_ready, buckets = (
+ gen_quantile_ops.quantile_accumulator_get_buckets(
+ quantile_accumulator_handles=[quantile_accumulator_handle],
+ stamp_token=stamp_token))
+ # quantile_accumulator_get_buckets returns a list of results per handle that
+ # we pass to it. In this case we're getting results just for one resource.
+ are_splits_ready = are_splits_ready[0]
+ buckets = buckets[0]
+
+  # After we receive the boundaries from the previous iteration we can flush
+  # the quantile accumulator.
+ with ops.control_dependencies([buckets]):
+ flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
+ quantile_accumulator_handle=quantile_accumulator_handle,
+ stamp_token=stamp_token,
+ next_stamp_token=next_stamp_token)
+
+ if is_multi_dimentional:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+ else:
+ num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
+ gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
+ stats_accumulator_handle, stamp_token, next_stamp_token))
+
+ # Put quantile and stats accumulator flushing in the dependency path.
+ with ops.control_dependencies([flush_quantiles, partition_ids]):
+ are_splits_ready = array_ops.identity(are_splits_ready)
+ partition_ids, gains, split_infos = (
+ split_handler_ops.build_sparse_inequality_splits(
+ num_minibatches=num_minibatches,
+ bucket_boundaries=buckets,
+ partition_ids=partition_ids,
+ bucket_ids=bucket_ids,
+ gradients=gradients,
+ hessians=hessians,
+ class_id=class_id,
+ feature_column_group_id=feature_column_id,
+ l1_regularization=l1_regularization,
+ l2_regularization=l2_regularization,
+ tree_complexity_regularization=tree_complexity_regularization,
+ min_node_weight=min_node_weight,
+ bias_feature_id=_BIAS_FEATURE_ID,
+ multiclass_strategy=multiclass_strategy))
+ return are_splits_ready, partition_ids, gains, split_infos
+
+
+def _specialize_make_split(func, is_multi_dimentional):
"""Builds a specialized version of the function."""
- def _make_sparse_split(quantile_accumulator_handle, stats_accumulator_handle,
- stamp_token, next_stamp_token, multiclass_strategy,
- class_id, feature_column_id, l1_regularization,
- l2_regularization, tree_complexity_regularization,
- min_node_weight, is_multi_dimentional):
- """Function that builds splits for a sparse feature column."""
- # Get the bucket boundaries
- are_splits_ready, buckets = (
- gen_quantile_ops.quantile_accumulator_get_buckets(
- quantile_accumulator_handles=[quantile_accumulator_handle],
- stamp_token=stamp_token))
- # quantile_accumulator_get_buckets returns a list of results per handle that
- # we pass to it. In this case we're getting results just for one resource.
- are_splits_ready = are_splits_ready[0]
- buckets = buckets[0]
-
- # After we receive the boundaries from previous iteration we can flush
- # the quantile accumulator.
- with ops.control_dependencies([buckets]):
- flush_quantiles = gen_quantile_ops.quantile_accumulator_flush(
- quantile_accumulator_handle=quantile_accumulator_handle,
- stamp_token=stamp_token,
- next_stamp_token=next_stamp_token)
-
- if is_multi_dimentional:
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- gen_stats_accumulator_ops.stats_accumulator_tensor_flush(
- stats_accumulator_handle, stamp_token, next_stamp_token))
- else:
- num_minibatches, partition_ids, bucket_ids, gradients, hessians = (
- gen_stats_accumulator_ops.stats_accumulator_scalar_flush(
- stats_accumulator_handle, stamp_token, next_stamp_token))
-
- # Put quantile and stats accumulator flushing in the dependency path.
- with ops.control_dependencies([flush_quantiles, partition_ids]):
- are_splits_ready = array_ops.identity(are_splits_ready)
- partition_ids, gains, split_infos = (
- split_handler_ops.build_sparse_inequality_splits(
- num_minibatches=num_minibatches,
- bucket_boundaries=buckets,
- partition_ids=partition_ids,
- bucket_ids=bucket_ids,
- gradients=gradients,
- hessians=hessians,
- class_id=class_id,
- feature_column_group_id=feature_column_id,
- l1_regularization=l1_regularization,
- l2_regularization=l2_regularization,
- tree_complexity_regularization=tree_complexity_regularization,
- min_node_weight=min_node_weight,
- bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=multiclass_strategy))
- return are_splits_ready, partition_ids, gains, split_infos
-
@function.Defun(
dtypes.resource,
dtypes.resource,
@@ -474,7 +504,7 @@ def _specialize_sparse_split(is_multi_dimentional):
l1_regularization, l2_regularization, tree_complexity_regularization,
min_node_weight):
"""Function that builds splits for a sparse feature column."""
- return _make_sparse_split(
+ return func(
quantile_accumulator_handle, stats_accumulator_handle, stamp_token,
next_stamp_token, multiclass_strategy, class_id, feature_column_id,
l1_regularization, l2_regularization, tree_complexity_regularization,
@@ -482,9 +512,15 @@ def _specialize_sparse_split(is_multi_dimentional):
return f
+make_dense_split_scalar = _specialize_make_split(_make_dense_split,
+ is_multi_dimentional=False)
+make_dense_split_tensor = _specialize_make_split(_make_dense_split,
+ is_multi_dimentional=True)
-make_sparse_split_scalar = _specialize_sparse_split(is_multi_dimentional=False)
-make_sparse_split_tensor = _specialize_sparse_split(is_multi_dimentional=True)
+make_sparse_split_scalar = _specialize_make_split(_make_sparse_split,
+ is_multi_dimentional=False)
+make_sparse_split_tensor = _specialize_make_split(_make_sparse_split,
+ is_multi_dimentional=True)
@function.Defun(
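
# A minimal, self-contained sketch (plain Python, no TensorFlow) of the
# specialization pattern introduced above: _make_dense_split and
# _make_sparse_split share one body, and _specialize_make_split bakes the
# is_multi_dimentional flag into module-level callables so each variant can
# later be wrapped in its own function.Defun graph. All names below are
# illustrative only, not the real ops.

def _make_split(accumulator_handle, is_multi_dimensional):
  """Shared builder; picks the flush variant based on the flag."""
  flush_op = "tensor_flush" if is_multi_dimensional else "scalar_flush"
  return "split(%s) via %s" % (accumulator_handle, flush_op)


def _specialize_make_split(func, is_multi_dimensional):
  """Returns a wrapper with the dimensionality flag fixed, mirroring the diff."""
  def specialized(accumulator_handle):
    return func(accumulator_handle, is_multi_dimensional)
  return specialized


make_split_scalar = _specialize_make_split(_make_split, is_multi_dimensional=False)
make_split_tensor = _specialize_make_split(_make_split, is_multi_dimensional=True)

print(make_split_scalar("acc:0"))  # split(acc:0) via scalar_flush
print(make_split_tensor("acc:0"))  # split(acc:0) via tensor_flush
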
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index c081a3f2c4..2f2c230211 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -67,9 +67,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessian_shape = tensor_shape.scalar()
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
@@ -203,10 +203,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessian_shape = tensor_shape.TensorShape([2, 2])
split_handler = ordinal_split_handler.DenseSplitHandler(
- l1_regularization=0,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=3,
feature_column_group_id=0,
@@ -291,10 +291,10 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
hessian_shape = tensor_shape.TensorShape([2])
split_handler = ordinal_split_handler.DenseSplitHandler(
- l1_regularization=0,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l1_regularization=0.,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=3,
feature_column_group_id=0,
@@ -376,9 +376,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
- tree_complexity_regularization=0,
- min_node_weight=0,
+ l2_regularization=1.,
+ tree_complexity_regularization=0.,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
@@ -451,9 +451,9 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
+ l2_regularization=1.,
tree_complexity_regularization=0.5,
- min_node_weight=0,
+ min_node_weight=0.,
epsilon=0.001,
num_quantiles=10,
feature_column_group_id=0,
@@ -585,7 +585,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
split_handler = ordinal_split_handler.DenseSplitHandler(
l1_regularization=0.1,
- l2_regularization=1,
+ l2_regularization=1.,
tree_complexity_regularization=0.5,
min_node_weight=1.5,
epsilon=0.001,
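
# A short, hedged illustration of why the tests above now pass float literals
# (1., 0.) instead of bare ints: TensorFlow infers int32 for Python integer
# constants and float32 for float literals, so the float form presumably keeps
# the regularization arguments compatible with float-typed graph-function
# signatures. Assumes a TensorFlow 1.x environment with tf importable.
import tensorflow as tf

print(tf.convert_to_tensor(0).dtype)    # <dtype: 'int32'>
print(tf.convert_to_tensor(0.).dtype)   # <dtype: 'float32'>
print(tf.convert_to_tensor(1.).dtype)   # <dtype: 'float32'>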