aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/boosted_trees_ops.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 14:09:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 14:17:14 -0700
commit2aac0e887ca27d9818607cd52f28044cb7673c70 (patch)
treed93dff30e8e0458894dcf197358ffa200ea77ce7 /tensorflow/core/ops/boosted_trees_ops.cc
parentbdd84aa59d3bdedc42647711e401229f489c7d25 (diff)
- Adding ability to center bias as a first step of training gbdt
- Fixing non determinism in choosing a split when gains are the same. PiperOrigin-RevId: 203180755
Diffstat (limited to 'tensorflow/core/ops/boosted_trees_ops.cc')
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index edcdc4cb6a..01452b3e85 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -331,4 +331,27 @@ REGISTER_OP("BoostedTreesUpdateEnsemble")
return Status::OK();
});
+REGISTER_OP("BoostedTreesCenterBias")
+ .Input("tree_ensemble_handle: resource")
+ .Input("mean_gradients: float")
+ .Input("mean_hessians: float")
+ // Regularization-related.
+ .Input("l1: float")
+ .Input("l2: float")
+ .Output("continue_centering: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle gradients_shape;
+ shape_inference::ShapeHandle hessians_shape;
+ shape_inference::ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(gradients_shape, hessians_shape, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
} // namespace tensorflow