aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/boosted_trees_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/boosted_trees_ops.cc')
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index edcdc4cb6a..01452b3e85 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -331,4 +331,27 @@ REGISTER_OP("BoostedTreesUpdateEnsemble")
return Status::OK();
});
+REGISTER_OP("BoostedTreesCenterBias")
+ .Input("tree_ensemble_handle: resource")
+ .Input("mean_gradients: float")
+ .Input("mean_hessians: float")
+ // Regularization-related.
+ .Input("l1: float")
+ .Input("l2: float")
+ .Output("continue_centering: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle gradients_shape;
+ shape_inference::ShapeHandle hessians_shape;
+ shape_inference::ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape));
+ TF_RETURN_IF_ERROR(
+ c->Merge(gradients_shape, hessians_shape, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
} // namespace tensorflow