author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-08 11:39:17 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-08-08 11:43:52 -0700
commit     2345997d328ef992d24c0182cc1b1f21bcc51f0d (patch)
tree       a3783852585638b9b8eea57ab9822d33dec4a185 /tensorflow/contrib/boosted_trees
parent     a1d3ebbe6f40768ea5ccf6beef9e905bce207b42 (diff)
Do not bucketize when buckets are empty. This handles situations where, for a rare feature, no buckets were found.
PiperOrigin-RevId: 207920196
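As a rough sketch of the guard this change introduces (not the actual op implementation; the `bucketize` helper and the numpy stand-in below are illustrative only), bucketization should be refused outright when a feature produced no bucket boundaries, since a lower_bound-style lookup over an empty range cannot yield a meaningful bucket id:

```python
import numpy as np

def bucketize(values, bucket_boundaries):
    """Toy stand-in for the QuantizeFeatures loop in quantile_ops.cc."""
    # Mirrors the new CHECK: a rare feature may have produced no bucket
    # boundaries at all, and bucketizing against an empty list is meaningless.
    if len(bucket_boundaries) == 0:
        raise ValueError("Got empty buckets")
    # lower_bound-style lookup: index of the first boundary >= value.
    return np.searchsorted(bucket_boundaries, values, side="left")

print(bucketize([0.3, 0.52], [0.1, 0.4, 0.6]))  # -> [1 2]
```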
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--  tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc                           2
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py           4
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py    4
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py       19
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py 102
-rw-r--r--  tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py           3
6 files changed, 125 insertions, 9 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
index 5b4be2f258..1375fddf2b 100644
--- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc
@@ -125,6 +125,8 @@ void QuantizeFeatures(
auto flat_values = values_tensor.flat<float>();
for (int64 instance = 0; instance < num_values; ++instance) {
const float value = flat_values(instance);
+ CHECK(!buckets_vector.empty())
+ << "Got empty buckets for feature " << feature_index;
auto bucket_iter =
std::lower_bound(buckets_vector.begin(), buckets_vector.end(), value);
if (bucket_iter == buckets_vector.end()) {
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py
index 1b7f59ea42..5d4819b0f1 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/base_split_handler.py
@@ -132,6 +132,10 @@ class BaseSplitHandler(object):
return control_flow_ops.group(update_1, *update_2[self])
@abc.abstractmethod
+ def reset(self, stamp_token, next_stamp_token):
+ """Resets the state maintained by the handler."""
+
+ @abc.abstractmethod
def make_splits(self, stamp_token, next_stamp_token, class_id):
"""Create the best split using the accumulated stats and flush the state.
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index bf686237ff..efe29216c2 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -202,3 +202,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
# always return ready.
are_splits_ready = constant_op.constant(True)
return (are_splits_ready, partition_ids, gains, split_infos)
+
+ def reset(self, stamp_token, next_stamp_token):
+ reset = self._stats_accumulator.flush(stamp_token, next_stamp_token)
+ return reset
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index df0bec1fe3..2559fe9913 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -79,6 +79,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+
_BIAS_FEATURE_ID = -1
# Pattern to remove all non alpha numeric from a string.
_PATTERN = re.compile(r"[\W_]+")
@@ -147,6 +148,11 @@ class InequalitySplitHandler(base_split_handler.BaseSplitHandler):
num_quantiles=num_quantiles,
name="QuantileAccumulator/{}".format(self._name))
+ def reset(self, stamp_token, next_stamp_token):
+ reset_1 = self._stats_accumulator.flush(stamp_token, next_stamp_token)
+ reset_2 = self._quantile_accumulator.flush(stamp_token, next_stamp_token)
+ return control_flow_ops.group([reset_1, reset_2])
+
class DenseSplitHandler(InequalitySplitHandler):
"""Computes stats and finds the best inequality splits on dense columns."""
@@ -264,6 +270,7 @@ class DenseSplitHandler(InequalitySplitHandler):
self._feature_column_group_id, self._l1_regularization,
self._l2_regularization, self._tree_complexity_regularization,
self._min_node_weight, self._loss_uses_sum_reduction))
+
return are_splits_ready, partition_ids, gains, split_infos
@@ -579,8 +586,10 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column,
example_partition_ids, feature_ids, gradients, hessians = (
control_flow_ops.cond(
- math_ops.logical_and(are_buckets_ready, is_active[0]),
- ready_inputs_fn, not_ready_inputs_fn))
+ math_ops.logical_and(
+ math_ops.logical_and(are_buckets_ready,
+ array_ops.size(quantile_buckets) > 0),
+ is_active[0]), ready_inputs_fn, not_ready_inputs_fn))
return (quantile_values, quantile_weights, example_partition_ids, feature_ids,
gradients, hessians)
@@ -674,8 +683,10 @@ def sparse_make_stats_update(
lambda: handler_not_active))
example_partition_ids, feature_ids, gradients, hessians = (
- control_flow_ops.cond(are_buckets_ready, quantiles_ready,
- quantiles_not_ready))
+ control_flow_ops.cond(
+ math_ops.logical_and(are_buckets_ready,
+ array_ops.size(quantile_buckets) > 0),
+ quantiles_ready, quantiles_not_ready))
return (quantile_indices, quantile_values, quantile_shape, quantile_weights,
example_partition_ids, feature_ids, gradients, hessians)
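The two `cond` changes above strengthen the readiness predicate: the "quantiles ready" branch now also requires the bucket-boundary tensor to be non-empty. Below is a minimal graph-mode sketch of that predicate using the public TF 1.x API (the real handlers use the internal `control_flow_ops`/`math_ops`/`array_ops` modules, and the branch bodies here are placeholders):

```python
import tensorflow as tf  # TF 1.x graph mode

quantile_buckets = tf.constant([], dtype=tf.float32)  # rare feature: no buckets
are_buckets_ready = tf.constant(True)

result = tf.cond(
    # Buckets must be both "ready" and non-empty before the ready branch runs.
    tf.logical_and(are_buckets_ready, tf.size(quantile_buckets) > 0),
    lambda: tf.constant("use quantized stats"),
    lambda: tf.constant("pass through empty stats"))

with tf.Session() as sess:
  print(sess.run(result))  # b'pass through empty stats'
```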
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index d59732cf92..5d82c4cae5 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -1072,8 +1072,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
with self.test_session() as sess:
# Batch is 4, 2 classes
- gradients = array_ops.constant(
- [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]])
+ gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
+ [4.0, -3]])
# 2x2 matrix for each instance
hessian_0 = [[0.12, 0.02], [0.3, 0.11]]
hessian_1 = [[0.07, -0.2], [-0.5, 0.2]]
@@ -1167,8 +1167,8 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
with self.test_session() as sess:
# Batch is 4, 2 classes
- gradients = array_ops.constant(
- [[0.2, 1.4], [-0.5, 0.1], [1.2, 3], [4.0, -3]])
+ gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
+ [4.0, -3]])
# Each hessian is a diagonal from a full hessian matrix.
hessian_0 = [0.12, 0.11]
hessian_1 = [0.07, 0.2]
@@ -1406,6 +1406,100 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(gains), 0)
self.assertEqual(len(splits), 0)
+ def testEmptyBuckets(self):
+ """Test that reproduces the case when quantile buckets were empty."""
+ with self.test_session() as sess:
+ sparse_column = array_ops.sparse_placeholder(dtypes.float32)
+
+ # We have two batches - at first, a sparse feature is empty.
+ empty_indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
+ empty_values = array_ops.constant([], dtype=dtypes.float32)
+ empty_sparse_column = sparse_tensor.SparseTensor(empty_indices,
+ empty_values, [4, 2])
+ empty_sparse_column = empty_sparse_column.eval(session=sess)
+
+ # For the second batch, the sparse feature is not empty.
+ non_empty_indices = array_ops.constant(
+ [[0, 0], [2, 1], [3, 2]], dtype=dtypes.int64, shape=[3, 2])
+ non_empty_values = array_ops.constant(
+ [0.52, 0.3, 0.52], dtype=dtypes.float32)
+ non_empty_sparse_column = sparse_tensor.SparseTensor(
+ non_empty_indices, non_empty_values, [4, 2])
+ non_empty_sparse_column = non_empty_sparse_column.eval(session=sess)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = ordinal_split_handler.SparseSplitHandler(
+ l1_regularization=0.0,
+ l2_regularization=2.0,
+ tree_complexity_regularization=0.0,
+ min_node_weight=0.0,
+ epsilon=0.01,
+ num_quantiles=2,
+ feature_column_group_id=0,
+ sparse_float_column=sparse_column,
+ init_stamp_token=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)
+ resources.initialize_resources(resources.shared_resources()).run()
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_1]):
+ are_splits_ready = split_handler.make_splits(
+ np.int64(0), np.int64(1), class_id)[0]
+
+ # First, calculate quantiles and try to update on an empty data for a
+ # feature.
+ are_splits_ready = (
+ sess.run(
+ are_splits_ready,
+ feed_dict={sparse_column: empty_sparse_column}))
+ self.assertFalse(are_splits_ready)
+
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ with ops.control_dependencies([update_2]):
+ are_splits_ready2, partitions, gains, splits = (
+ split_handler.make_splits(np.int64(1), np.int64(2), class_id))
+
+ # Now the feature in the second batch is not empty, but buckets
+ # calculated on the first batch are empty.
+ are_splits_ready2, partitions, gains, splits = (
+ sess.run(
+ [are_splits_ready2, partitions, gains, splits],
+ feed_dict={sparse_column: non_empty_sparse_column}))
+ self.assertFalse(are_splits_ready)
+ self.assertTrue(are_splits_ready2)
+ # Since the buckets were empty, we can't calculate the splits.
+ self.assertEqual(len(partitions), 0)
+ self.assertEqual(len(gains), 0)
+ self.assertEqual(len(splits), 0)
+
def testDegenerativeCase(self):
with self.test_session() as sess:
      # One data example only, one leaf and thus one quantile bucket. The same
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index ba5ef700c5..d0d1249bd6 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -51,6 +51,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import device_setter
+
# Key names for prediction dict.
ENSEMBLE_STAMP = "ensemble_stamp"
PREDICTIONS = "predictions"
@@ -898,7 +899,7 @@ class GradientBoostedDecisionTreeModel(object):
reset_ops = []
for handler in handlers:
- reset_ops.append(handler.make_splits(stamp_token, next_stamp_token, 0))
+ reset_ops.append(handler.reset(stamp_token, next_stamp_token))
if self._center_bias:
reset_ops.append(
bias_stats_accumulator.flush(stamp_token, next_stamp_token))
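Taken together, the Python changes introduce a dedicated `reset()` contract on the split handlers so that `gbdt_batch` can flush accumulator state without going through `make_splits()`. A minimal sketch of that contract, using hypothetical class and accumulator names rather than the real handler hierarchy:

```python
import abc


class SplitHandler(abc.ABC):
  """Toy version of BaseSplitHandler's new reset() contract."""

  @abc.abstractmethod
  def reset(self, stamp_token, next_stamp_token):
    """Resets accumulated state without computing any splits."""


class ToyInequalityHandler(SplitHandler):
  """Hypothetical handler holding a stats and a quantile accumulator."""

  def __init__(self, stats_accumulator, quantile_accumulator):
    self._stats_accumulator = stats_accumulator
    self._quantile_accumulator = quantile_accumulator

  def reset(self, stamp_token, next_stamp_token):
    # Flush both accumulators directly; previously the trainer had to call
    # make_splits() purely for its flushing side effect.
    self._stats_accumulator.flush(stamp_token, next_stamp_token)
    self._quantile_accumulator.flush(stamp_token, next_stamp_token)
```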