aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/ops
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-21 08:53:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-21 08:59:25 -0700
commitf92eef788dff0e629cb4c408ce4b00530f152d4f (patch)
treeca7a4cd9f1303c75d63c4bb13d8705ab5e8cbbf3 /tensorflow/contrib/boosted_trees/ops
parent8c99f891d0aa359509163cfcf16fc0177fafc3e8 (diff)
Migrate kernels to boosted_trees.
PiperOrigin-RevId: 159698656
Diffstat (limited to 'tensorflow/contrib/boosted_trees/ops')
-rw-r--r--tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc44
-rw-r--r--tensorflow/contrib/boosted_trees/ops/model_ops.cc114
-rw-r--r--tensorflow/contrib/boosted_trees/ops/prediction_ops.cc132
-rw-r--r--tensorflow/contrib/boosted_trees/ops/quantile_ops.cc263
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc197
-rw-r--r--tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc475
-rw-r--r--tensorflow/contrib/boosted_trees/ops/training_ops.cc120
7 files changed, 1345 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc b/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc
new file mode 100644
index 0000000000..b5ea5e7849
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc
@@ -0,0 +1,44 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("AddTreesToEnsemble")
+    .Input("tree_ensemble_handle: resource")
+    .Input("ensemble_to_add: string")
+    .Input("feature_column_usage_counts_handle: Ref(int64)")
+    .Input("feature_column_usage_counts_to_add: int64")
+    .Input("feature_column_gains_handle: Ref(float)")
+    .Input("feature_column_gains_to_add: float")
+    .Input("drop_out_tree_indices_weights: float")
+    .Input("learning_rate: float")
+    .SetShapeFn(shape_inference::NoOutputs)  // The op declares no outputs.
+    .Doc(R"doc(
+Synchronously adds a tree ensemble to an existing tree ensemble variable.
+tree_ensemble_handle: Handle to the ensemble variable.
+ensemble_to_add: Serialized DecisionTreeConfig proto of the tree.
+feature_column_usage_counts_handle: Handle to the feature column usage counts variable.
+feature_column_usage_counts_to_add: Rank 1 Tensor holding feature column usage counts to add.
+feature_column_gains_handle: Handle to the feature column gains variable.
+feature_column_gains_to_add: Rank 1 Tensor holding feature column gains to add.
+drop_out_tree_indices_weights: Rank 2 Tensor containing dropped trees indices
+and original weights of those trees during prediction.
+learning_rate: The learning rate that the tuner found for this iteration.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/ops/model_ops.cc b/tensorflow/contrib/boosted_trees/ops/model_ops.cc
new file mode 100644
index 0000000000..c490c765cf
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/model_ops.cc
@@ -0,0 +1,114 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace gtflow {
+
+REGISTER_RESOURCE_HANDLE_OP(DecisionTreeEnsembleResource);  // Registers the handle op for this resource type (macro presumably from resource_mgr.h).
+
+REGISTER_OP("TreeEnsembleIsInitializedOp")
+    .Input("tree_ensemble_handle: resource")
+    .Output("is_initialized: bool")
+    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
+      shape_inference::ShapeHandle handle_shape;  // Only used for validation.
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 0, &handle_shape));
+      ctx->set_output(0, ctx->Scalar());  // is_initialized is a scalar.
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Checks whether a tree ensemble has been initialized.
+)doc");
+
+REGISTER_OP("CreateTreeEnsembleVariable")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("tree_ensemble_config: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle unused_input;  // All 3 inputs are scalars.
+      for (int i = 0; i < 3; ++i) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused_input));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Creates a tree ensemble model for the given resource handle.
+
+tree_ensemble_handle: Handle to the tree ensemble resource to be created.
+stamp_token: Token to use as the initial value of the resource stamp.
+tree_ensemble_config: Serialized proto of the tree ensemble.
+)doc");
+
+REGISTER_OP("TreeEnsembleStampToken")
+    .Input("tree_ensemble_handle: resource")
+    .Output("stamp_token: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
+      shape_inference::ShapeHandle handle_shape;  // Only used for validation.
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 0, &handle_shape));
+      ctx->set_output(0, ctx->Scalar());  // stamp_token is a scalar.
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Retrieves the tree ensemble resource stamp token.
+
+tree_ensemble_handle: Handle to the tree ensemble.
+stamp_token: Stamp token of the tree ensemble resource.
+)doc");
+
+REGISTER_OP("TreeEnsembleSerialize")
+    .Input("tree_ensemble_handle: resource")
+    .Output("stamp_token: int64")
+    .Output("tree_ensemble_config: string")
+    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
+      shape_inference::ShapeHandle handle_shape;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 0, &handle_shape));
+      // Both the stamp token and the serialized config are scalars.
+      for (int out = 0; out < 2; ++out) ctx->set_output(out, ctx->Scalar());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Serializes the tree ensemble to a proto.
+
+tree_ensemble_handle: Handle to the tree ensemble.
+stamp_token: Stamp token of the tree ensemble resource.
+tree_ensemble_config: Serialized proto of the ensemble.
+)doc");
+
+REGISTER_OP("TreeEnsembleDeserialize")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("tree_ensemble_config: string")
+    .SetShapeFn([](shape_inference::InferenceContext* ctx) {
+      shape_inference::ShapeHandle scalar_shape;  // All 3 inputs are scalars.
+      for (int idx = 0; idx < 3; ++idx) {
+        TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(idx), 0, &scalar_shape));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Deserializes a serialized tree ensemble config and replaces current tree
+ensemble.
+
+tree_ensemble_handle: Handle to the tree ensemble.
+stamp_token: Token to use as the new value of the resource stamp.
+tree_ensemble_config: Serialized proto of the ensemble.
+)doc");
+
+} // namespace gtflow
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
new file mode 100644
index 0000000000..8effb6f98f
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc
@@ -0,0 +1,132 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+using tensorflow::boosted_trees::learner::LearnerConfig;
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+
+static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) {
+  // Derives the output shapes from the serialized "learner_config" attr.
+  string serialized_config;
+  // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
+  c->GetAttr("learner_config", &serialized_config).IgnoreError();
+  LearnerConfig config;
+  ParseProtoUnlimited(&config, serialized_config);
+  // TREE_PER_CLASS yields num_classes - 1 columns; otherwise num_classes.
+  const bool one_fewer_column =
+      config.multi_class_strategy() == LearnerConfig::TREE_PER_CLASS;
+  const auto num_columns =
+      one_fewer_column ? config.num_classes() - 1 : config.num_classes();
+  // Outputs 0 and 1 are matrices with unknown row count; output 2 is a vector.
+  c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, num_columns)});
+  c->set_output(1, {c->Matrix(InferenceContext::kUnknownDim, num_columns)});
+  c->set_output(2, {c->Vector(InferenceContext::kUnknownDim)});
+  return Status::OK();
+}
+
+REGISTER_OP("GradientTreesPrediction")
+ .Attr("learner_config: string")
+ .Attr("num_dense_float_features: int >= 0")
+ .Attr("num_sparse_float_features: int >= 0")
+ .Attr("num_sparse_int_features: int >= 0")
+ .Attr("use_locking: bool = false")
+ .Attr("apply_dropout: bool")
+ .Attr("apply_averaging: bool")
+ .Attr("center_bias: bool")  // NOTE(review): not described in the Doc text below.
+ .Input("tree_ensemble_handle: resource")
+ .Input("seed: int64")
+ .Input("dense_float_features: num_dense_float_features * float")
+ .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
+ .Input("sparse_float_feature_values: num_sparse_float_features * float")
+ .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
+ .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
+ .Input("sparse_int_feature_values: num_sparse_int_features * int64")
+ .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
+ .Output("predictions: float")
+ .Output("no_dropout_predictions: float")
+ .Output("drop_out_tree_indices_weights: float")
+ .SetShapeFn(ApplyGradientTreesPredictionShapeFn)  // Output shapes are derived from the learner_config attr.
+ .Doc(R"doc(
+Runs multiple additive regression forests predictors on input instances
+and computes the final prediction for each class.
+
+learner_config: Config for the learner of type LearnerConfig proto. Prediction
+ops for now uses only LearningRateDropoutDrivenConfig config from the learner.
+num_dense_float_features: Number of dense float features.
+num_sparse_float_features: Number of sparse float features.
+num_sparse_int_features: Number of sparse int features.
+use_locking: Whether to use locking.
+seed: random seed to be used for dropout.
+apply_dropout: whether to apply dropout during prediction.
+apply_averaging: whether averaging of tree ensembles should take place. If set
+to true, will be based on AveragingConfig from learner_config.
+tree_ensemble_handle: The handle to the tree ensemble.
+dense_float_features: Rank 2 Tensors containing dense float feature values.
+sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
+sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
+sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
+sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
+sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
+sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
+predictions: Rank 2 Tensor containing predictions per example per class.
+no_dropout_predictions: The same as predictions, but using all trees (even
+those that were dropped due to dropout).
+drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices
+and original weights of those trees during prediction.
+)doc");
+
+REGISTER_OP("GradientTreesPartitionExamples")
+    .Attr("num_dense_float_features: int >= 0")
+    .Attr("num_sparse_float_features: int >= 0")
+    .Attr("num_sparse_int_features: int >= 0")
+    .Attr("use_locking: bool = false")
+    .Input("tree_ensemble_handle: resource")
+    .Input("dense_float_features: num_dense_float_features * float")
+    .Input("sparse_float_feature_indices: num_sparse_float_features * int64")
+    .Input("sparse_float_feature_values: num_sparse_float_features * float")
+    .Input("sparse_float_feature_shapes: num_sparse_float_features * int64")
+    .Input("sparse_int_feature_indices: num_sparse_int_features * int64")
+    .Input("sparse_int_feature_values: num_sparse_int_features * int64")
+    .Input("sparse_int_feature_shapes: num_sparse_int_features * int64")
+    .Output("partition_ids: int32")
+    .SetShapeFn([](InferenceContext* ctx) {
+      const auto ids_shape = ctx->Vector(InferenceContext::kUnknownDim);
+      return ctx->set_output("partition_ids", {ids_shape});
+    })
+    .Doc(R"doc(
+Splits input examples into the leaves of the tree.
+
+num_dense_float_features: Number of dense float features.
+num_sparse_float_features: Number of sparse float features.
+num_sparse_int_features: Number of sparse int features.
+use_locking: Whether to use locking.
+tree_ensemble_handle: The handle to the tree ensemble.
+dense_float_features: Rank 2 Tensors containing dense float feature values.
+sparse_float_feature_indices: Rank 2 Tensors containing sparse float indices.
+sparse_float_feature_values: Rank 1 Tensors containing sparse float values.
+sparse_float_feature_shapes: Rank 1 Tensors containing sparse float shapes.
+sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices.
+sparse_int_feature_values: Rank 1 Tensors containing sparse int values.
+sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes.
+partition_ids: Rank 1 Tensor containing partition ids per example.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
new file mode 100644
index 0000000000..c778a02fb6
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
@@ -0,0 +1,263 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace gtflow {
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_RESOURCE_HANDLE_OP(QuantileStreamResource);  // Registers the handle op for this resource type (macro presumably from resource_mgr.h).
+
+REGISTER_OP("QuantileAccumulatorIsInitialized")  // NOTE(review): the Doc text does not describe the input or the output.
+ .Input("quantile_accumulator_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn(shape_inference::ScalarShape)  // Shared shape fn; presumably marks the output as a scalar.
+ .Doc(R"doc(
+Checks whether a quantile accumulator has been initialized.
+)doc");
+
+REGISTER_OP("CreateQuantileAccumulator")
+    .Attr("container: string = ''")
+    .Attr("shared_name: string = ''")
+    .Attr("max_elements: int = 1099511627776")  // 1 << 40
+    .Attr("epsilon: float")
+    .Attr("num_quantiles: int")
+    .Input("quantile_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .SetShapeFn([](InferenceContext* ctx) {
+      ShapeHandle scalar_shape;  // Only used for validation.
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 0, &scalar_shape));
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(1), 0, &scalar_shape));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Creates a stateful accumulator for quantile summaries.
+
+epsilon: Error bound on the quantile summary.
+num_quantiles: Number of buckets that we create from the data.
+stamp_token: Token to use as the initial value of the resource stamp.
+quantile_accumulator_handle: The handle to the accumulator.
+)doc");
+
+REGISTER_OP("QuantileAccumulatorAddSummaries")
+    .Attr("num_resource_handles: int >= 1")
+    .Input("quantile_accumulator_handles: num_resource_handles * resource")
+    .Input("stamp_token: int64")
+    .Input("summaries: num_resource_handles * string")
+    .SetShapeFn([](InferenceContext* ctx) {
+      int handle_count;
+      TF_RETURN_IF_ERROR(ctx->GetAttr("num_resource_handles", &handle_count));
+      // The handles, the stamp token and the summaries are all scalars.
+      const int total_inputs = 2 * handle_count + 1;
+      ShapeHandle scalar_shape;
+      for (int idx = 0; idx < total_inputs; ++idx) {
+        TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(idx), 0, &scalar_shape));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Adds each quantile summary to its stream.
+
+quantile_accumulator_handles: The handles to the quantile stream resources.
+stamp_token: Stamp token to validate the Read/Write operation.
+summaries: A list of serialized QuantileSummaryState.
+)doc");
+
+REGISTER_OP("QuantileAccumulatorGetBuckets")
+    .Attr("num_resource_handles: int >= 1")
+    .Input("quantile_accumulator_handles: num_resource_handles * resource")
+    .Input("stamp_token: int64")
+    .Output("are_buckets_ready: num_resource_handles * bool")
+    .Output("buckets: num_resource_handles * float")
+    .SetShapeFn([](InferenceContext* ctx) {
+      int handle_count;
+      TF_RETURN_IF_ERROR(ctx->GetAttr("num_resource_handles", &handle_count));
+      // Readiness flags come first, followed by one bucket vector per handle.
+      for (int idx = 0; idx < handle_count; ++idx) {
+        ctx->set_output(idx, ctx->Scalar());
+        ctx->set_output(handle_count + idx, ctx->Vector(ctx->UnknownDim()));
+      }
+      return Status::OK();
+    })
+
+    .Doc(R"doc(
+Returns quantile buckets created during previous flush of the accumulator.
+
+quantile_accumulator_handles: The handles to the quantile stream resources.
+stamp_token: Stamp token to validate the Read/Write operation.
+are_buckets_ready: Whether the buckets are ready or not.
+buckets: Output quantile summary representing boundaries with "num_quantile"
+ elements.
+)doc");
+
+REGISTER_OP("QuantileAccumulatorFlush")  // NOTE(review): no SetShapeFn is registered for this op.
+ .Input("quantile_accumulator_handle: resource")
+ .Input("stamp_token: int64")
+ .Input("next_stamp_token: int64")
+ .Doc(R"doc(
+Resets quantile summary streams for each column with a new token.
+
+quantile_accumulator_handle: The handle to the accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+next_stamp_token: Stamp token to be used for the next iteration.
+)doc");
+
+REGISTER_OP("QuantileAccumulatorSerialize")  // NOTE(review): no SetShapeFn is registered for this op.
+ .Input("quantile_accumulator_handle: resource")
+ .Output("stamp_token: int64")
+ .Output("stream_state: string")
+ .Output("are_buckets_ready: bool")
+ .Output("buckets: float")
+ .Doc(R"doc(
+Serializes the state of the given resource.
+
+quantile_accumulator_handle: The handle to the accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+stream_state: A serialized QuantileStreamState.
+are_buckets_ready: Whether the buckets are ready or not.
+buckets: Output quantile buckets representing boundaries with "num_quantile"
+ elements.
+)doc");
+
+REGISTER_OP("QuantileAccumulatorDeserialize")  // NOTE(review): no SetShapeFn is registered for this op.
+    .Input("quantile_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("stream_state: string")
+    .Input("are_buckets_ready: bool")
+    .Input("buckets: float")
+    .Doc(R"doc(
+Deserializes the given state into the given resource.
+
+quantile_accumulator_handle: The handle to the accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+stream_state: A serialized QuantileStreamState.
+are_buckets_ready: Whether the buckets are ready or not.
+buckets: Input quantile buckets representing boundaries with "num_quantile"
+ elements.
+)doc");
+
+REGISTER_OP("MakeQuantileSummaries")
+    .Attr("num_dense_features: int >= 0")
+    .Attr("num_sparse_features: int >= 0")
+    .Attr("epsilon: float")
+    .Input("dense_float_features: num_dense_features * float")
+    .Input("sparse_float_feature_indices: num_sparse_features * int64")
+    .Input("sparse_float_feature_values: num_sparse_features * float")
+    .Input("sparse_float_feature_shapes: num_sparse_features * int64")
+    .Input("example_weights: float")
+    .Output("dense_summaries: num_dense_features * string")
+    .Output("sparse_summaries: num_sparse_features * string")
+    .SetShapeFn([](InferenceContext* ctx) {
+      int dense_count;
+      TF_RETURN_IF_ERROR(ctx->GetAttr("num_dense_features", &dense_count));
+      int sparse_count;
+      TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse_features", &sparse_count));
+      // Dense summaries occupy the first outputs and the sparse summaries
+      // follow them; every summary is a scalar.
+      const int total_outputs = dense_count + sparse_count;
+      for (int out = 0; out < total_outputs; ++out) {
+        ctx->set_output(out, ctx->Scalar());
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Creates a summary for the given features.
+
+num_dense_features: Number of dense feature groups to compute quantiles on.
+num_sparse_features: Number of sparse feature groups to compute quantiles on.
+epsilon: Error bound on the computed summary.
+dense_float_features: A list of vectors which contains dense values.
+sparse_float_feature_indices: List of rank 2 tensors containing the sparse float
+feature indices.
+sparse_float_feature_values: List of rank 1 tensors containing the sparse float
+feature values.
+sparse_float_feature_shapes: List of rank 1 tensors containing the shape of the
+float feature.
+example_weights: Rank 1 tensor containing the example weight tensor.
+dense_summaries: A list of serialized QuantileSummaryState for dense columns.
+sparse_summaries: A list of serialized QuantileSummaryState for sparse columns.
+)doc");
+
+REGISTER_OP("QuantileBuckets")  // NOTE(review): no SetShapeFn is registered for this op.
+ .Attr("num_dense_features: int >= 0")
+ .Attr("num_sparse_features: int >= 0")
+ .Attr("dense_config: list(string)")
+ .Attr("sparse_config: list(string)")
+ .Input("dense_float_features: num_dense_features * float")
+ .Input("sparse_float_feature_indices: num_sparse_features * int64")
+ .Input("sparse_float_feature_values: num_sparse_features * float")
+ .Input("sparse_float_feature_shapes: num_sparse_features * int64")
+ .Input("example_weights: float")
+ .Output("dense_buckets: num_dense_features * float")
+ .Output("sparse_buckets: num_sparse_features * float")
+ .Doc(R"doc(
+Computes quantile buckets for a given list of dense and sparse features with
+given example weights.
+
+num_dense_features: Number of dense feature groups to compute quantiles on.
+num_sparse_features: Number of sparse feature groups to compute quantiles on.
+dense_config: Config for computing buckets for dense values.
+Each entry is QuantileConfig proto.
+sparse_config: Config for computing buckets for sparse feature values.
+Each entry is QuantileConfig proto.
+dense_float_features: A list of vectors which contains dense values.
+sparse_float_feature_indices: List of rank 2 tensors containing the sparse float
+feature indices.
+sparse_float_feature_values: List of rank 1 tensors containing the sparse float
+feature values.
+sparse_float_feature_shapes: List of rank 1 tensors containing the shape of the
+float feature.
+example_weights: Rank 1 tensor containing the example weight tensor.
+dense_buckets: Output quantile summary for each dense float tensor
+representing boundaries each with "num_quantile" elements.
+sparse_buckets: Output quantile summary for each sparse float value tensor
+representing boundaries each with "num_quantile" elements.
+)doc");
+
+REGISTER_OP("Quantiles")  // NOTE(review): no SetShapeFn is registered for this op.
+    .Attr("num_dense_features: int >= 0")
+    .Attr("num_sparse_features: int >= 0")
+    .Input("dense_values: num_dense_features * float")
+    .Input("sparse_values: num_sparse_features * float")
+    .Input("dense_buckets: num_dense_features * float")
+    .Input("sparse_buckets: num_sparse_features * float")
+    .Output("dense_quantiles: num_dense_features * int32")
+    .Output("sparse_quantiles: num_sparse_features * int32")
+    .Doc(R"doc(
+Computes quantiles for each of a given list of dense and sparse feature values
+using the given buckets.
+
+num_dense_features: Number of dense feature groups to generate quantiles for.
+num_sparse_features: Number of sparse feature groups to generate quantiles for.
+dense_values: List of rank 1 tensors containing the dense values.
+sparse_values: List of rank 1 tensors containing the sparse feature values.
+dense_buckets: Quantile summary for each of the dense float tensor.
+sparse_buckets: Quantile summary for each of the sparse feature float tensor.
+dense_quantiles: Rank 1 tensors representing associated quantiles for each of
+dense float tensors.
+sparse_quantiles: Rank 1 tensors representing associated quantiles for each of
+the sparse feature tensors.
+)doc");
+
+} // namespace gtflow
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
new file mode 100644
index 0000000000..d32507127c
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -0,0 +1,197 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::DimensionHandle;
+
+REGISTER_OP("BuildDenseInequalitySplits")
+    .Attr("feature_column_group_id: int")
+    .Attr("l1_regularization: float")
+    .Attr("l2_regularization: float")
+    .Attr("tree_complexity_regularization: float")
+    .Attr("min_node_weight: float")
+    .Input("num_minibatches: int64")
+    .Input("partition_ids: int32")
+    .Input("bucket_ids: int64")
+    .Input("gradients: float32")
+    .Input("hessians: float32")
+    .Input("bucket_boundaries: float32")
+    .Output("output_partition_ids: int32")
+    .Output("gains: float32")
+    .Output("split_infos: string")
+    .SetShapeFn([](InferenceContext* c) {
+      DimensionHandle unused_dim;
+      ShapeHandle unused_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape));
+      // partition_ids, bucket_ids, gradients and hessians must agree on dim 0.
+      ShapeHandle partition_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape));
+      ShapeHandle bucket_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(bucket_ids_shape, 0), &unused_dim));
+      ShapeHandle gradients_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(gradients_shape, 0), &unused_dim));
+      ShapeHandle hessians_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(hessians_shape, 0), &unused_dim));
+      ShapeHandle bucket_boundaries_shape;  // Input 5; input 4 is hessians.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &bucket_boundaries_shape));
+      c->set_output(0, c->Vector(c->UnknownDim()));
+      c->set_output(1, c->Vector(c->UnknownDim()));
+      c->set_output(2, c->Vector(c->UnknownDim()));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Find the split that has the best gain for the accumulated stats.
+
+num_minibatches: A scalar, the number of times per example gradients & hessians
+ were accumulated. The stats are divided by this to get per example stats.
+partition_ids: A rank 1 tensor of partition IDs.
+bucket_ids: A rank 1 tensor of bucket IDs.
+gradients: A rank 1 tensor of gradients.
+hessians: A rank 1 tensor of hessians.
+bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
+ for.
+gains: A rank 1 tensor, for the computed gain for the created splits.
+split_infos: A rank 1 tensor of serialized protos which contains the
+ `SplitInfo`s.
+)doc");
+
+REGISTER_OP("BuildSparseInequalitySplits")
+    .Attr("feature_column_group_id: int")
+    .Attr("bias_feature_id: int")
+    .Attr("l1_regularization: float")
+    .Attr("l2_regularization: float")
+    .Attr("tree_complexity_regularization: float")
+    .Attr("min_node_weight: float")
+    .Input("num_minibatches: int64")
+    .Input("partition_ids: int32")
+    .Input("bucket_ids: int64")
+    .Input("gradients: float32")
+    .Input("hessians: float32")
+    .Input("bucket_boundaries: float32")
+    .Output("output_partition_ids: int32")
+    .Output("gains: float32")
+    .Output("split_infos: string")
+    .SetShapeFn([](InferenceContext* c) {
+      DimensionHandle unused_dim;
+      ShapeHandle unused_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_shape));
+      // partition_ids, bucket_ids, gradients and hessians must agree on dim 0.
+      ShapeHandle partition_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape));
+      ShapeHandle bucket_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(bucket_ids_shape, 0), &unused_dim));
+      ShapeHandle gradients_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(gradients_shape, 0), &unused_dim));
+      ShapeHandle hessians_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(4), 1, &hessians_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(hessians_shape, 0), &unused_dim));
+      ShapeHandle bucket_boundaries_shape;  // Input 5; input 4 is hessians.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &bucket_boundaries_shape));
+      c->set_output(0, c->Vector(c->UnknownDim()));
+      c->set_output(1, c->Vector(c->UnknownDim()));
+      c->set_output(2, c->Vector(c->UnknownDim()));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Find the split that has the best gain for the accumulated stats.
+
+num_minibatches: A scalar, the number of times per example gradients & hessians
+ were accumulated. The stats are divided by this to get per example stats.
+partition_ids: A rank 1 tensor of partition IDs.
+bucket_ids: A rank 1 tensor of bucket IDs.
+gradients: A rank 1 tensor of gradients.
+hessians: A rank 1 tensor of hessians.
+bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization.
+output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
+ for.
+gains: A rank 1 tensor, for the computed gain for the created splits.
+split_infos: A rank 1 tensor of serialized protos which contains the
+ `SplitInfo`s.
+)doc");
+
+REGISTER_OP("BuildCategoricalEqualitySplits")
+    .Attr("feature_column_group_id: int")
+    .Attr("bias_feature_id: int")
+    .Attr("l1_regularization: float")
+    .Attr("l2_regularization: float")
+    .Attr("tree_complexity_regularization: float")
+    .Attr("min_node_weight: float")
+    .Input("num_minibatches: int64")
+    .Input("partition_ids: int32")
+    .Input("feature_ids: int64")
+    .Input("gradients: float32")
+    .Input("hessians: float32")
+    .Output("output_partition_ids: int32")
+    .Output("gains: float32")
+    .Output("split_infos: string")
+    .SetShapeFn([](InferenceContext* ctx) {
+      DimensionHandle merged_dim;
+      ShapeHandle scalar_shape;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(0), 0, &scalar_shape));
+      // partition_ids, feature_ids, gradients and hessians must agree on dim 0.
+      ShapeHandle partitions;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(1), 1, &partitions));
+      ShapeHandle feature_ids;
+      TF_RETURN_IF_ERROR(ctx->WithRank(ctx->input(2), 1, &feature_ids));
+      TF_RETURN_IF_ERROR(ctx->Merge(ctx->Dim(partitions, 0),
+                                    ctx->Dim(feature_ids, 0), &merged_dim));
+      ShapeHandle gradients;
+      TF_RETURN_IF_ERROR(ctx->WithRankAtLeast(ctx->input(3), 1, &gradients));
+      TF_RETURN_IF_ERROR(ctx->Merge(ctx->Dim(partitions, 0),
+                                    ctx->Dim(gradients, 0), &merged_dim));
+      ShapeHandle hessians;
+      TF_RETURN_IF_ERROR(ctx->WithRankAtLeast(ctx->input(4), 1, &hessians));
+      TF_RETURN_IF_ERROR(ctx->Merge(ctx->Dim(partitions, 0),
+                                    ctx->Dim(hessians, 0), &merged_dim));
+      ctx->set_output(0, ctx->Vector(ctx->UnknownDim()));
+      ctx->set_output(1, ctx->Vector(ctx->UnknownDim()));
+      ctx->set_output(2, ctx->Vector(ctx->UnknownDim()));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Find the split that has the best gain for the accumulated stats.
+
+num_minibatches: A scalar, the number of times per example gradients & hessians
+ were accumulated. The stats are divided by this to get per example stats.
+partition_ids: A rank 1 tensor of partition IDs.
+feature_ids: A rank 1 tensor of feature IDs.
+gradients: A rank 1 tensor of gradients.
+hessians: A rank 1 tensor of hessians.
+output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
+ for.
+gains: A rank 1 tensor, for the computed gain for the created splits.
+split_infos: A rank 1 tensor of serialized protos which contains the
+ `SplitInfo`s.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc
new file mode 100644
index 0000000000..9fa5ec9a83
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc
@@ -0,0 +1,475 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace gtflow {
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::DimensionHandle;
+
+// Scalar version of the stats accumulator ops: gradients and hessians are
+// rank 1, holding one value per <partition_id, feature_id> slot (see the
+// shape functions below).
+REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorScalarResource);
+
+REGISTER_OP("StatsAccumulatorScalarIsInitialized")
+    .Input("stats_accumulator_handle: resource")
+    .Output("is_initialized: bool")
+    // The only output is a scalar.
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(R"doc(
+Checks whether a stats accumulator has been initialized.
+
+stats_accumulator_handle: handle to the stats accumulator.
+is_initialized: A scalar, true if the stats accumulator has been initialized.
+)doc");
+
+REGISTER_OP("CreateStatsAccumulatorScalar")
+    .Input("stats_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .SetShapeFn([](InferenceContext* c) {
+      // Both inputs (the resource handle, then stamp_token) must be scalars.
+      // The op has no outputs, so there is nothing else to set.
+      ShapeHandle scalar_shape;
+      for (int input_idx = 0; input_idx < 2; ++input_idx) {
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(input_idx), 0, &scalar_shape));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Creates a scalar stats accumulator.
+
+stats_accumulator_handle: handle to the stats accumulator.
+stamp_token: Token to use as the initial value of the resource stamp.
+)doc");
+
+REGISTER_OP("StatsAccumulatorScalarAdd")
+    .Attr("num_resource_handles: int >= 1")
+    .Input("stats_accumulator_handles: num_resource_handles * resource")
+    .Input("stamp_token: int64")
+    .Input("partition_ids: num_resource_handles * int32")
+    .Input("feature_ids: num_resource_handles * int64")
+    .Input("gradients: num_resource_handles * float")
+    .Input("hessians: num_resource_handles * float")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_resource_handles;
+      TF_RETURN_IF_ERROR(
+          c->GetAttr("num_resource_handles", &num_resource_handles));
+      ShapeHandle unused_input;
+      DimensionHandle unused_dim;
+      // Flattened input layout: N handles, one stamp_token, then N each of
+      // partition_ids, feature_ids, gradients and hessians
+      // (N = num_resource_handles).
+      // stamp_token is a single scalar shared by every handle, so validate it
+      // once instead of on each loop iteration.
+      TF_RETURN_IF_ERROR(
+          c->WithRank(c->input(num_resource_handles), 0, &unused_input));
+      for (int i = 0; i < num_resource_handles; ++i) {
+        // Each resource handle is a scalar.
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused_input));
+        // The i-th partition_ids, feature_ids, gradients and hessians must all
+        // be vectors of the same length.
+        ShapeHandle partition_ids_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(num_resource_handles + i + 1),
+                                       1, &partition_ids_shape));
+        ShapeHandle feature_ids_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(
+            c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                    c->Dim(feature_ids_shape, 0), &unused_dim));
+        ShapeHandle gradients_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(
+            c->input(num_resource_handles * 3 + i + 1), 1, &gradients_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                    c->Dim(gradients_shape, 0), &unused_dim));
+        ShapeHandle hessians_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(
+            c->input(num_resource_handles * 4 + i + 1), 1, &hessians_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                    c->Dim(hessians_shape, 0), &unused_dim));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Updates the scalar stats accumulator.
+
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+stats_accumulator_handles: A list of handles to the stats accumulator.
+partition_ids: A list of vectors of partition_ids.
+feature_ids: A list of vectors of feature_ids.
+gradients: A list of vectors of gradients for each slot in
+ <partition_id, feature_id>.
+hessians: A list of vectors of hessians for each slot in
+ <partition_id, feature_id>.
+)doc");
+
+REGISTER_OP("StatsAccumulatorScalarFlush")
+    .Input("stats_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("next_stamp_token: int64")
+    .Output("num_updates: int64")
+    .Output("output_partition_ids: int32")
+    .Output("output_feature_ids: int64")
+    .Output("output_gradients: float")
+    .Output("output_hessians: float")
+    .SetShapeFn([](InferenceContext* c) {
+      // The handle, stamp_token and next_stamp_token are all scalars.
+      ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      // num_updates is a scalar; the remaining outputs are vectors whose
+      // length is not known at graph-construction time.
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Vector(c->UnknownDim()));
+      c->set_output(2, c->Vector(c->UnknownDim()));
+      c->set_output(3, c->Vector(c->UnknownDim()));
+      c->set_output(4, c->Vector(c->UnknownDim()));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Flushes the scalar stats accumulator to output and resets the internal state.
+
+stats_accumulator_handle: handle to the stats accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+next_stamp_token: Stamp token for the next iteration.
+num_updates: Number of times stats were added to this accumulator since last
+ flush.
+output_partition_ids: A vector of partition_ids for the slots.
+output_feature_ids: A vector of feature_ids for the slots.
+output_gradients: A vector of gradients, with a value for each slot
+ in <output_partition_id, output_feature_id>.
+output_hessians: A vector of hessians, with a value for each slot
+ in <output_partition_id, output_feature_id>.
+)doc");
+
+REGISTER_OP("StatsAccumulatorScalarDeserialize")
+    .Input("stats_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("num_updates: int64")
+    .Input("partition_ids: int32")
+    .Input("feature_ids: int64")
+    .Input("gradients: float")
+    .Input("hessians: float")
+    .SetShapeFn([](InferenceContext* c) {
+      // Inputs 0..2 (stats_accumulator_handle, stamp_token, num_updates) are
+      // scalars.
+      ShapeHandle ignored_shape;
+      for (int input_idx = 0; input_idx < 3; ++input_idx) {
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(input_idx), 0, &ignored_shape));
+      }
+      // partition_ids is a vector; its length is the reference that
+      // feature_ids, gradients and hessians (inputs 4..6) must agree with.
+      ShapeHandle partition_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape));
+      const DimensionHandle partition_ids_len = c->Dim(partition_ids_shape, 0);
+      for (int input_idx = 4; input_idx <= 6; ++input_idx) {
+        ShapeHandle vector_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(input_idx), 1, &vector_shape));
+        DimensionHandle ignored_dim;
+        TF_RETURN_IF_ERROR(c->Merge(partition_ids_len,
+                                    c->Dim(vector_shape, 0), &ignored_dim));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Resets the scalar stats accumulator with the serialized state.
+
+stats_accumulator_handle: handle to the stats accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+num_updates: Number of times stats were added to this accumulator since last
+ flush.
+partition_ids: A vector of partition_ids.
+feature_ids: A vector of feature_ids.
+gradients: A vector of gradients for each slot in <partition_id, feature_id>.
+hessians: A vector of hessians for each slot in <partition_id, feature_id>.
+)doc");
+
+REGISTER_OP("StatsAccumulatorScalarSerialize")
+    .Input("stats_accumulator_handle: resource")
+    .Output("stamp_token: int64")
+    .Output("num_updates: int64")
+    .Output("output_partition_ids: int32")
+    .Output("output_feature_ids: int64")
+    .Output("output_gradients: float")
+    .Output("output_hessians: float")
+    .SetShapeFn([](InferenceContext* c) {
+      // The resource handle is a scalar.
+      ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      // stamp_token
+      c->set_output(0, c->Scalar());
+      // num_updates
+      c->set_output(1, c->Scalar());
+      // ids, gradients and hessians: vectors of unknown length.
+      c->set_output(2, c->Vector(c->UnknownDim()));
+      c->set_output(3, c->Vector(c->UnknownDim()));
+      c->set_output(4, c->Vector(c->UnknownDim()));
+      c->set_output(5, c->Vector(c->UnknownDim()));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Serializes the scalar stats accumulator state.
+
+stats_accumulator_handle: handle to the stats accumulator.
+stamp_token: The current stamp token for the resource.
+num_updates: Number of times stats were added to this accumulator since last
+ flush.
+output_partition_ids: A vector of partition_ids for the slots.
+output_feature_ids: A vector of feature_ids for the slots.
+output_gradients: A vector of gradients, with a value for each slot
+ in <output_partition_id, output_feature_id>.
+output_hessians: A vector of hessians, with a value for each slot
+ in <output_partition_id, output_feature_id>.
+)doc");
+
+// NOTE(review): unlike most ops in this file, no SetShapeFn is registered
+// here -- confirm that is intentional.
+REGISTER_OP("StatsAccumulatorScalarMakeSummary")
+    .Input("partition_ids: int32")
+    .Input("feature_ids: int64")
+    .Input("gradients: float")
+    .Input("hessians: float")
+    .Output("output_partition_ids: int32")
+    .Output("output_feature_ids: int64")
+    .Output("output_gradients: float")
+    .Output("output_hessians: float")
+    .Doc(R"doc(
+Summarizes the stats by summing the <gradients, hessians> that are for the same
+<partition_id, feature_id>.
+
+partition_ids: A vector of partition_ids.
+feature_ids: A vector of feature_ids.
+gradients: A vector of gradients for each slot in <partition_id, feature_id>.
+hessians: A vector of hessians for each slot in <partition_id, feature_id>.
+output_partition_ids: A vector of partition_ids for the slots.
+output_feature_ids: A vector of feature_ids for the slots.
+output_gradients: A vector of gradients, with a value for each slot
+ in <output_partition_id, output_feature_id>.
+output_hessians: A vector of hessians, with a value for each slot
+ in <output_partition_id, output_feature_id>.
+)doc");
+
+// Tensor version of the stats accumulator ops: gradients and hessians are
+// tensors of rank >= 2 whose first dimension indexes the
+// <partition_id, feature_id> slots (see the shape functions below).
+REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorTensorResource);
+
+REGISTER_OP("StatsAccumulatorTensorIsInitialized")
+    .Input("stats_accumulator_handle: resource")
+    .Output("is_initialized: bool")
+    // The only output is a scalar.
+    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Doc(R"doc(
+Checks whether a tensor stats accumulator has been initialized.
+
+stats_accumulator_handle: handle to the tensor stats accumulator.
+is_initialized: A scalar, true if the tensor stats accumulator has been
+ initialized.
+)doc");
+
+REGISTER_OP("CreateStatsAccumulatorTensor")
+    .Input("stats_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("per_slot_gradient_shape: int64")
+    .Input("per_slot_hessian_shape: int64")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused_input;
+      // The resource handle is a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      // stamp_token is a scalar.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      // The per-slot gradient and hessian shapes are each given as a vector.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused_input));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Creates a tensor stats accumulator.
+
+stats_accumulator_handle: handle to the tensor stats accumulator to be created.
+stamp_token: Token to use as the initial value of the resource stamp.
+per_slot_gradient_shape: a vector that defines the shape of gradients.
+per_slot_hessian_shape: a vector that defines the shape of hessians.
+)doc");
+
+REGISTER_OP("StatsAccumulatorTensorAdd")
+    .Attr("num_resource_handles: int >= 1")
+    .Input("stats_accumulator_handles: num_resource_handles * resource")
+    .Input("stamp_token: int64")
+    .Input("partition_ids: num_resource_handles * int32")
+    .Input("feature_ids: num_resource_handles * int64")
+    .Input("gradients: num_resource_handles * float")
+    .Input("hessians: num_resource_handles * float")
+    .SetShapeFn([](InferenceContext* c) {
+      int num_resource_handles;
+      TF_RETURN_IF_ERROR(
+          c->GetAttr("num_resource_handles", &num_resource_handles));
+      ShapeHandle unused_input;
+      DimensionHandle unused_dim;
+      // Flattened input layout: N handles, one stamp_token, then N each of
+      // partition_ids, feature_ids, gradients and hessians
+      // (N = num_resource_handles).
+      // stamp_token is a single scalar shared by every handle, so validate it
+      // once instead of on each loop iteration.
+      TF_RETURN_IF_ERROR(
+          c->WithRank(c->input(num_resource_handles), 0, &unused_input));
+      for (int i = 0; i < num_resource_handles; ++i) {
+        // Each resource handle is a scalar.
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused_input));
+        // The i-th partition_ids and feature_ids are vectors of equal length.
+        ShapeHandle partition_ids_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(c->input(num_resource_handles + i + 1),
+                                       1, &partition_ids_shape));
+        ShapeHandle feature_ids_shape;
+        TF_RETURN_IF_ERROR(c->WithRank(
+            c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                    c->Dim(feature_ids_shape, 0), &unused_dim));
+        // The i-th gradients and hessians have rank >= 2, and their first
+        // dimension must match the length of partition_ids.
+        ShapeHandle gradients_shape;
+        TF_RETURN_IF_ERROR(c->WithRankAtLeast(
+            c->input(num_resource_handles * 3 + i + 1), 2, &gradients_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                    c->Dim(gradients_shape, 0), &unused_dim));
+        ShapeHandle hessians_shape;
+        TF_RETURN_IF_ERROR(c->WithRankAtLeast(
+            c->input(num_resource_handles * 4 + i + 1), 2, &hessians_shape));
+        TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                    c->Dim(hessians_shape, 0), &unused_dim));
+      }
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Updates the tensor stats accumulator.
+
+stats_accumulator_handles: A list of handles to the stats accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+partition_ids: A list of vectors of partition_ids.
+feature_ids: A list of vectors of feature_ids.
+gradients: A list of tensors of gradients, each of rank at least 2; the first
+ dimension matches slots in <partition_id, feature_id>.
+hessians: A list of tensors of hessians, each of rank at least 2; the first
+ dimension matches slots in <partition_id, feature_id>.
+)doc");
+
+REGISTER_OP("StatsAccumulatorTensorFlush")
+    .Input("stats_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("next_stamp_token: int64")
+    .Output("num_updates: int64")
+    .Output("output_partition_ids: int32")
+    .Output("output_feature_ids: int64")
+    .Output("output_gradients: float")
+    .Output("output_hessians: float")
+    .SetShapeFn([](InferenceContext* c) {
+      // The handle, stamp_token and next_stamp_token are all scalars.
+      ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      // num_updates
+      c->set_output(0, c->Scalar());
+      // partition and feature ids are documented as vectors, and the
+      // Add/Deserialize ops that take them require rank 1, so declare them as
+      // rank-1 outputs rather than [?, 1] matrices.
+      c->set_output(1, c->Vector(c->UnknownDim()));
+      c->set_output(2, c->Vector(c->UnknownDim()));
+      // Gradient and hessian shapes are not declared here.
+      c->set_output(3, c->UnknownShape());
+      c->set_output(4, c->UnknownShape());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Flushes the stats accumulator to output and resets the internal state.
+
+stats_accumulator_handle: handle to the tensor stats accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+next_stamp_token: Stamp token to be used for the next iteration.
+num_updates: Number of times stats were added to this accumulator since last
+ flush.
+output_partition_ids: A vector of partition_ids for the slots.
+output_feature_ids: A vector of feature_ids for the slots.
+output_gradients: A tensor of gradients, first dimension matches slots
+ in <partition_id, feature_id>.
+output_hessians: A tensor of hessians, first dimension matches slots
+ in <partition_id, feature_id>.
+)doc");
+
+REGISTER_OP("StatsAccumulatorTensorDeserialize")
+    .Input("stats_accumulator_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("num_updates: int64")
+    .Input("partition_ids: int32")
+    .Input("feature_ids: int64")
+    .Input("gradients: float")
+    .Input("hessians: float")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused_input;
+      DimensionHandle unused_dim;
+      // stats_accumulator_handle
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      // stamp_token
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      // num_updates
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      // partition_ids and feature_ids are vectors of equal length.
+      ShapeHandle partition_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape));
+      ShapeHandle feature_ids_shape;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(feature_ids_shape, 0), &unused_dim));
+      // gradients and hessians have rank >= 2, and their first dimension must
+      // match the length of partition_ids.
+      ShapeHandle gradients_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(5), 2, &gradients_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(gradients_shape, 0), &unused_dim));
+      ShapeHandle hessians_shape;
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(6), 2, &hessians_shape));
+      TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0),
+                                  c->Dim(hessians_shape, 0), &unused_dim));
+
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Resets the tensor stats accumulator with the serialized state.
+
+stats_accumulator_handle: handle to the tensor stats accumulator.
+stamp_token: Stamp token for Read/Write operations.
+ Any operation with a mismatching token will be dropped.
+num_updates: Number of times stats were added to this accumulator since last
+ flush.
+partition_ids: A vector of partition_ids.
+feature_ids: A vector of feature_ids.
+gradients: A tensor of gradients of rank at least 2, first dimension matches
+ slots in <partition_id, feature_id>.
+hessians: A tensor of hessians of rank at least 2, first dimension matches
+ slots in <partition_id, feature_id>.
+)doc");
+
+REGISTER_OP("StatsAccumulatorTensorSerialize")
+    .Input("stats_accumulator_handle: resource")
+    .Output("stamp_token: int64")
+    .Output("num_updates: int64")
+    .Output("output_partition_ids: int32")
+    .Output("output_feature_ids: int64")
+    .Output("output_gradients: float")
+    .Output("output_hessians: float")
+    .SetShapeFn([](InferenceContext* c) {
+      // The resource handle is a scalar.
+      ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      // stamp_token
+      c->set_output(0, c->Scalar());
+      // num_updates
+      c->set_output(1, c->Scalar());
+      // partition and feature ids are documented as vectors, and
+      // StatsAccumulatorTensorDeserialize requires rank 1 for them, so declare
+      // them as rank-1 outputs rather than [?, 1] matrices.
+      c->set_output(2, c->Vector(c->UnknownDim()));
+      c->set_output(3, c->Vector(c->UnknownDim()));
+      // Gradient and hessian shapes are not declared here.
+      c->set_output(4, c->UnknownShape());
+      c->set_output(5, c->UnknownShape());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Serializes the tensor stats accumulator state.
+
+stats_accumulator_handle: handle to the tensor stats accumulator.
+stamp_token: The current stamp token for the resource.
+num_updates: Number of times stats were added to this accumulator since last
+ flush.
+output_partition_ids: A vector of partition_ids for the slots.
+output_feature_ids: A vector of feature_ids for the slots.
+output_gradients: A tensor of gradients, first dimension matches slots
+ in <partition_id, feature_id>.
+output_hessians: A tensor of hessians, first dimension matches slots
+ in <partition_id, feature_id>.
+)doc");
+
+// Registers the summary op for the tensor stats accumulator.
+// NOTE(review): unlike most ops in this file, no SetShapeFn is registered
+// here -- confirm that is intentional.
+// NOTE(review): the doc below calls the gradients/hessians inputs "vectors",
+// while StatsAccumulatorTensorAdd requires rank >= 2 for its gradients and
+// hessians -- confirm which wording is intended.
+REGISTER_OP("StatsAccumulatorTensorMakeSummary")
+ .Input("partition_ids: int32")
+ .Input("feature_ids: int64")
+ .Input("gradients: float")
+ .Input("hessians: float")
+ .Output("output_partition_ids: int32")
+ .Output("output_feature_ids: int64")
+ .Output("output_gradients: float")
+ .Output("output_hessians: float")
+ .Doc(R"doc(
+Summarizes the stats by summing the <gradients, hessians> that are for the same
+<partition_id, feature_id>.
+
+partition_ids: A vector of partition_ids.
+feature_ids: A vector of feature_ids.
+gradients: A vector of gradients for each slot in <partition_id, feature_id>.
+hessians: A vector of hessians for each slot in <partition_id, feature_id>.
+output_partition_ids: A vector of partition_ids for the slots.
+output_feature_ids: A vector of feature_ids for the slots.
+output_gradients: A tensor of gradients, first dimension matches slots
+ in <partition_id, feature_id>.
+output_hessians: A tensor of hessians, first dimension matches slots
+ in <partition_id, feature_id>.
+)doc");
+} // namespace gtflow
+} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/ops/training_ops.cc b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
new file mode 100644
index 0000000000..d2debbe03d
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/ops/training_ops.cc
@@ -0,0 +1,120 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace gtflow {
+
+REGISTER_OP("CenterTreeEnsembleBias")
+    .Attr("learner_config: string")
+    .Attr("centering_epsilon: float = 0.01")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("next_stamp_token: int64")
+    .Input("delta_updates: float")
+    .Output("continue_centering: bool")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle ignored_shape;
+      // tree_ensemble_handle, stamp_token and next_stamp_token are scalars.
+      for (int input_idx = 0; input_idx < 3; ++input_idx) {
+        TF_RETURN_IF_ERROR(
+            c->WithRank(c->input(input_idx), 0, &ignored_shape));
+      }
+      // delta_updates is a vector.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &ignored_shape));
+      // continue_centering is a scalar.
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Centers the tree ensemble bias before adding trees based on feature splits.
+
+learner_config: Config for the learner of type LearnerConfig proto.
+tree_ensemble_handle: Handle to the ensemble variable.
+stamp_token: Stamp token for validating operation consistency.
+next_stamp_token: Stamp token to be used for the next iteration.
+delta_updates: Rank 1 Tensor containing delta updates per bias dimension.
+continue_centering: Scalar indicating whether more centering is needed.
+)doc");
+
+REGISTER_OP("GrowTreeEnsemble")
+    .Attr("learner_config: string")
+    .Attr("num_handlers: int >= 0")
+    .Attr("center_bias: bool")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Input("next_stamp_token: int64")
+    .Input("learning_rate: float")
+    .Input("dropout_seed: int64")
+    .Input("partition_ids: num_handlers * int32")
+    .Input("gains: num_handlers * float")
+    .Input("splits: num_handlers * string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      // Only the five leading inputs are validated, each as a scalar; the
+      // per-handler list inputs are not checked here and the op has no
+      // outputs.
+      shape_inference::ShapeHandle unused_input;
+      // tree_ensemble_handle
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      // stamp_token
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      // next_stamp_token
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+      // learning_rate
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused_input));
+      // Dropout seed.
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_input));
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Grows the tree ensemble by either adding a layer to the last tree being grown
+or by starting a new tree.
+
+learner_config: Config for the learner of type LearnerConfig proto.
+num_handlers: Number of handlers generating candidates.
+tree_ensemble_handle: Handle to the ensemble variable.
+stamp_token: Stamp token for validating operation consistency.
+next_stamp_token: Stamp token to be used for the next iteration.
+learning_rate: Scalar learning rate.
+dropout_seed: Scalar seed used for dropout.
+partition_ids: List of Rank 1 Tensors containing partition Id per candidate.
+gains: List of Rank 1 Tensors containing gains per candidate.
+splits: List of Rank 1 Tensors containing serialized SplitInfo protos per candidate.
+)doc");
+
+REGISTER_OP("TreeEnsembleStats")
+    .Input("tree_ensemble_handle: resource")
+    .Input("stamp_token: int64")
+    .Output("num_trees: int64")
+    .Output("num_layers: int64")
+    .Output("active_tree: int64")
+    .Output("active_layer: int64")
+    .Output("attempted_trees: int64")
+    .Output("attempted_layers: int64")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      // Both inputs (the resource handle and stamp_token) are scalars.
+      shape_inference::ShapeHandle unused_input;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+      // All six outputs are scalars.
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      c->set_output(3, c->Scalar());
+      c->set_output(4, c->Scalar());
+      c->set_output(5, c->Scalar());
+      return Status::OK();
+    })
+    .Doc(R"doc(
+Retrieves stats related to the tree ensemble.
+
+tree_ensemble_handle: Handle to the ensemble variable.
+stamp_token: Stamp token for validating operation consistency.
+num_trees: Scalar indicating the number of finalized trees in the ensemble.
+num_layers: Scalar indicating the number of layers in the ensemble.
+active_tree: Scalar indicating the active tree being trained.
+active_layer: Scalar indicating the active layer being trained.
+attempted_trees: Scalar indicating the number of trees that were attempted to
+ be built.
+attempted_layers: Scalar indicating the number of layers that were attempted to
+ be built.
+)doc");
+
+} // namespace gtflow
+} // namespace tensorflow