diff options
Diffstat (limited to 'tensorflow/core/ops/boosted_trees_ops.cc')
-rw-r--r-- | tensorflow/core/ops/boosted_trees_ops.cc | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index edcdc4cb6a..01452b3e85 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -331,4 +331,27 @@ REGISTER_OP("BoostedTreesUpdateEnsemble") return Status::OK(); }); +REGISTER_OP("BoostedTreesCenterBias") + .Input("tree_ensemble_handle: resource") + .Input("mean_gradients: float") + .Input("mean_hessians: float") + // Regularization-related. + .Input("l1: float") + .Input("l2: float") + .Output("continue_centering: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle gradients_shape; + shape_inference::ShapeHandle hessians_shape; + shape_inference::ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &gradients_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &hessians_shape)); + TF_RETURN_IF_ERROR( + c->Merge(gradients_shape, hessians_shape, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape)); + + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + } // namespace tensorflow |