diff options
Diffstat (limited to 'tensorflow/core/ops/boosted_trees_ops.cc')
-rw-r--r-- | tensorflow/core/ops/boosted_trees_ops.cc | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 01452b3e85..7c4184bff4 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -22,6 +22,10 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource); REGISTER_OP("IsBoostedTreesEnsembleInitialized") @@ -354,4 +358,125 @@ REGISTER_OP("BoostedTreesCenterBias") return Status::OK(); }); +REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource); + +REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized") + .Input("quantile_stream_resource_handle: resource") + .Output("is_initialized: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesCreateQuantileStreamResource") + .Attr("max_elements: int = 1099511627776") // 1 << 40 + .Input("quantile_stream_resource_handle: resource") + .Input("epsilon: float") + .Input("num_streams: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesMakeQuantileSummaries") + .Attr("num_features: int >= 0") + .Input("float_values: num_features * float") + .Input("example_weights: float") + .Input("epsilon: float") + .Output("summaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + ShapeHandle example_weights_shape; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(num_features), 1, &example_weights_shape)); + for (int i = 0; i < num_features; ++i) { + ShapeHandle feature_shape; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(example_weights_shape, 0), + &unused_dim)); + // the columns are value, weight, min_rank, max_rank. + c->set_output(i, c->MakeShape({c->UnknownDim(), 4})); + } + // epsilon must be a scalar. + ShapeHandle unused_input; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(num_features + 1), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries") + .Attr("num_features: int >= 0") + .Input("quantile_stream_resource_handle: resource") + .Input("summaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + // resource handle must be a scalar. + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + // each summary must be rank 2. + for (int i = 1; i < num_features + 1; i++) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input)); + } + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceFlush") + .Attr("generate_quantiles: bool = False") + .Input("quantile_stream_resource_handle: resource") + .Input("num_buckets: int64") + .SetShapeFn([](InferenceContext* c) { + // All the inputs are scalars. + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries") + .Attr("num_features: int >= 0") + .Input("quantile_stream_resource_handle: resource") + .Output("bucket_boundaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + shape_inference::ShapeHandle unused_input; + // resource handle must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + for (int i = 0; i < num_features; i++) { + c->set_output(i, c->Vector(c->UnknownDim())); + } + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesBucketize") + .Attr("num_features: int >= 0") + .Input("float_values: num_features * float") + .Input("bucket_boundaries: num_features * float") + .Output("buckets: num_features * int32") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + ShapeHandle feature_shape; + DimensionHandle unused_dim; + for (int i = 0; i < num_features; i++) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(c->input(0), 0), &unused_dim)); + } + // Bucketized result should have same dimension as input. + for (int i = 0; i < num_features; i++) { + c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1})); + } + return Status::OK(); + }); + } // namespace tensorflow |