about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-08-27 12:36:16 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-27 12:41:02 -0700
commit: 9e27c8f01c4548e4cc7fe1a5015af1ec8e32e5d1 (patch)
tree: e36f4a31c6509154387f6bc52ea641005c707156 /tensorflow/contrib/boosted_trees
parent: 05285015795b374e4f71b24d21e31f6c59ebfc8e (diff)
Fixed a bug in the dense split handler ops.
PiperOrigin-RevId: 210412659
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r-- tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc | 17
-rw-r--r-- tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py | 43
2 files changed, 37 insertions, 23 deletions
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 3a48635319..d0fd39fa30 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <limits>
#include <memory>
#include <string>
#include <vector>
@@ -325,13 +326,21 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
}
float best_gain = std::numeric_limits<float>::lowest();
- int64 best_bucket_idx = 0;
+ int64 best_bucket_id = 0;
std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
- int64 current_bucket_id = 0;
+ int64 current_bucket_id = std::numeric_limits<int64>::max();
int64 last_bucket_id = -1;
+ // Find the lowest bucket id, this is going to be the first bucket id to
+ // try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (bucket_ids(start_index, 0) < current_bucket_id) {
+ current_bucket_id = bucket_ids(start_index, 0);
+ }
+ }
// Indexes offsets for each of the partitions that can be used to access
// gradients of a partition for a current bucket we consider.
std::vector<int> current_layer_offsets(num_elements, 0);
@@ -373,6 +382,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
best_gain = gain_of_split;
best_left_node_stats = current_left_node_stats;
best_right_node_stats = current_right_node_stats;
+ best_bucket_id = current_bucket_id;
}
current_bucket_id = next_bucket_id;
}
@@ -387,8 +397,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
oblivious_split_info.mutable_split_node()
->mutable_oblivious_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
- oblivious_dense_split->set_threshold(
- bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
+ oblivious_dense_split->set_threshold(bucket_boundaries(best_bucket_id));
(*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index 31043264a1..5532bd026a 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -186,11 +186,12 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
with self.test_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
- # i0 | (0.2, 0.12) | 1 | 2 |
- # i1 | (-0.5, 0.07) | 1 | 2 |
- # i2 | (1.2, 0.2) | 1 | 0 |
- # i3 | (4.0, 0.13) | 2 | 1 |
- dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+ # i0 | (0.2, 0.12) | 1 | 3 |
+ # i1 | (-0.5, 0.07) | 1 | 3 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ dense_column = array_ops.placeholder(
+ dtypes.float32, shape=(4, 1), name="dense_column")
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32)
@@ -230,24 +231,28 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
with ops.control_dependencies([update_1]):
are_splits_ready = split_handler.make_splits(
np.int64(0), np.int64(1), class_id)[0]
+ # Forcing the creation of four buckets.
+ are_splits_ready = sess.run(
+ [are_splits_ready],
+ feed_dict={dense_column: [[0.2], [0.62], [0.3], [0.52]]})[0]
- with ops.control_dependencies([are_splits_ready]):
- update_2 = split_handler.update_stats_sync(
- 1,
- partition_ids,
- gradients,
- hessians,
- empty_gradients,
- empty_hessians,
- example_weights,
- is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
split_handler.make_splits(np.int64(1), np.int64(2), class_id))
- are_splits_ready, are_splits_ready2, partitions, gains, splits = (
- sess.run([
- are_splits_ready, are_splits_ready2, partitions, gains, splits
- ]))
+ # Only using the last three buckets.
+ are_splits_ready2, partitions, gains, splits = (
+ sess.run(
+ [are_splits_ready2, partitions, gains, splits],
+ feed_dict={dense_column: [[0.62], [0.62], [0.3], [0.52]]}))
# During the first iteration, inequality split handlers are not going to
# have any splits. Make sure that we return not_ready in that case.