aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-21 11:25:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-21 11:29:41 -0700
commit 8e40f1adcbc94c2c21dccaa557604a917fd86f22 (patch)
tree f238bf35900465c0b7c668c0f5b492f7b12590e3
parent 41f3f76970726fe4ec2cd9e485a04e6f072a3bce (diff)
Migrate ops for new version of TensorForest.
PiperOrigin-RevId: 159718610
-rwxr-xr-x tensorflow/contrib/BUILD 4
-rw-r--r-- tensorflow/contrib/cmake/tf_core_ops.cmake 2
-rwxr-xr-x tensorflow/contrib/cmake/tf_python.cmake 4
-rw-r--r-- tensorflow/contrib/tensor_forest/BUILD 293
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/model_ops.cc 299
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/stats_ops.cc 564
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/v4/BUILD 26
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc 142
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h 81
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc 140
-rw-r--r-- tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h 69
-rw-r--r-- tensorflow/contrib/tensor_forest/ops/model_ops.cc 135
-rw-r--r-- tensorflow/contrib/tensor_forest/ops/stats_ops.cc 146
-rw-r--r-- tensorflow/contrib/tensor_forest/python/__init__.py 2
-rw-r--r-- tensorflow/contrib/tensor_forest/python/ops/model_ops.py 124
-rw-r--r-- tensorflow/contrib/tensor_forest/python/ops/stats_ops.py 114
16 files changed, 1887 insertions, 258 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index f852eded1e..dfcdf3991c 100755
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -93,6 +93,8 @@ cc_library(
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels",
"//tensorflow/contrib/seq2seq:beam_search_ops_kernels",
+ "//tensorflow/contrib/tensor_forest:model_ops_kernels",
+ "//tensorflow/contrib/tensor_forest:stats_ops_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
"//tensorflow/contrib/text:all_kernels",
],
@@ -110,6 +112,8 @@ cc_library(
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/seq2seq:beam_search_ops_op_lib",
+ "//tensorflow/contrib/tensor_forest:model_ops_op_lib",
+ "//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index f1b01250a4..a9defc1139 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -85,6 +85,8 @@ GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib
GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc")
+GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_model "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/model_ops.cc")
+GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/stats_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 84a1302344..d17fcf6456 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -659,6 +659,10 @@ GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/hybrid/ops/gen_training_ops.py)
+GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_model_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_model_ops.py)
+GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_stats_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_stats_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops"
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 7b5f9472e7..1ca2d7596b 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -28,37 +28,35 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
+# ---------------------------------- V2 ops ------------------------------------------#
filegroup(
- name = "custom_op_sources",
- srcs = glob(
- [
- "kernels/*.cc",
- "ops/*.cc",
- ],
- exclude = [
- "kernels/*_test.cc",
- "kernels/tree_utils.cc",
- ],
- ),
+ name = "v2_op_sources",
+ srcs = [
+ "kernels/best_splits_op.cc",
+ "kernels/count_extremely_random_stats_op.cc",
+ "kernels/finished_nodes_op.cc",
+ "kernels/grow_tree_op.cc",
+ "kernels/reinterpret_string_to_float_op.cc",
+ "kernels/sample_inputs_op.cc",
+ "kernels/scatter_add_ndim_op.cc",
+ "kernels/tree_predictions_op.cc",
+ "kernels/update_fertile_slots_op.cc",
+ ],
)
filegroup(
- name = "custom_op_headers",
- srcs = glob(
- [
- "kernels/*.h",
- ],
- exclude = [
- "kernels/data_spec.h",
- "kernels/tree_utils.h",
- ],
- ),
+ name = "v2_op_defs",
+ srcs = [
+ "ops/tensor_forest_ops.cc",
+ ],
)
cc_library(
- name = "all_ops",
- srcs = [":custom_op_sources"],
- hdrs = [":custom_op_headers"],
+ name = "v2_ops",
+ srcs = [
+ ":v2_op_defs",
+ ":v2_op_sources",
+ ],
deps = [
":tree_utils",
"//tensorflow/core:framework_headers_lib",
@@ -105,16 +103,8 @@ tf_gen_op_wrapper_py(
tf_custom_op_library(
name = "python/ops/_tensor_forest_ops.so",
srcs = [
- "kernels/best_splits_op.cc",
- "kernels/count_extremely_random_stats_op.cc",
- "kernels/finished_nodes_op.cc",
- "kernels/grow_tree_op.cc",
- "kernels/reinterpret_string_to_float_op.cc",
- "kernels/sample_inputs_op.cc",
- "kernels/scatter_add_ndim_op.cc",
- "kernels/tree_predictions_op.cc",
- "kernels/update_fertile_slots_op.cc",
- "ops/tensor_forest_ops.cc",
+ ":v2_op_defs",
+ ":v2_op_sources",
],
deps = [":tree_utils"],
)
@@ -131,7 +121,9 @@ py_library(
":constants",
":data_ops_py",
":eval_metrics",
+ ":model_ops_py",
":random_forest",
+ ":stats_ops_py",
":tensor_forest_ops_py",
":tensor_forest_py",
],
@@ -140,21 +132,11 @@ py_library(
tf_kernel_library(
name = "tensor_forest_kernels",
srcs = [
- "kernels/best_splits_op.cc",
- "kernels/count_extremely_random_stats_op.cc",
- "kernels/finished_nodes_op.cc",
- "kernels/grow_tree_op.cc",
- "kernels/reinterpret_string_to_float_op.cc",
- "kernels/sample_inputs_op.cc",
- "kernels/scatter_add_ndim_op.cc",
- "kernels/tree_predictions_op.cc",
- "kernels/update_fertile_slots_op.cc",
+ ":v2_op_sources",
],
deps = [
":tree_utils",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:framework_headers_lib",
"//tensorflow/core/kernels:bounds_check",
],
)
@@ -181,6 +163,192 @@ tf_custom_op_py_library(
],
)
+cc_test(
+ name = "tensor_forest_ops_test",
+ size = "small",
+ srcs = [
+ "kernels/tensor_forest_ops_test.cc",
+ ":v2_op_defs",
+ ":v2_op_sources",
+ ],
+ deps = [
+ ":tree_utils",
+ "//tensorflow/core",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//third_party/eigen3",
+ ],
+)
+
+# -------------------------------------- V4 ops ------------------------------- #
+cc_library(
+ name = "tensor_forest_v4_kernels",
+ deps = [
+ ":model_ops_kernels",
+ ":stats_ops_kernels",
+ ],
+)
+
+cc_library(
+ name = "tensor_forest_v4_ops_op_lib",
+ deps = [
+ ":model_ops_op_lib",
+ ":stats_ops_op_lib",
+ ],
+)
+
+py_library(
+ name = "tensor_forest_v4_ops_py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":model_ops_py",
+ ":stats_ops_py",
+ ],
+)
+
+# Model Ops.
+cc_library(
+ name = "model_ops_lib",
+ srcs = ["kernels/model_ops.cc"],
+ deps = [
+ "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc",
+ "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc",
+ "//tensorflow/contrib/tensor_forest:tree_utils",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:input_data",
+ "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["model_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_model_ops_py",
+ out = "python/ops/gen_model_ops.py",
+ deps = [":model_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "model_ops_kernels",
+ deps = [
+ ":model_ops_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "python/ops/_model_ops.so",
+ deps = [
+ ":model_ops_lib",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "model_ops_py",
+ srcs = ["python/ops/model_ops.py"],
+ dso = ["python/ops/_model_ops.so"],
+ kernels = [
+ ":model_ops_kernels",
+ ":model_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_model_ops_py",
+ ":stats_ops_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+# Stats Ops.
+cc_library(
+ name = "stats_ops_lib",
+ srcs = ["kernels/stats_ops.cc"],
+ deps = [
+ "//tensorflow/contrib/tensor_forest:tree_utils",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:fertile-stats-resource",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:input_data",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:input_target",
+ "//tensorflow/contrib/tensor_forest/kernels/v4:params",
+ "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = 1,
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["stats_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_stats_ops_py",
+ out = "python/ops/gen_stats_ops.py",
+ deps = [":stats_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "stats_ops_kernels",
+ deps = [
+ ":stats_ops_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_library(
+ name = "python/ops/_stats_ops.so",
+ deps = [
+ ":stats_ops_lib",
+ ],
+)
+
+tf_custom_op_py_library(
+ name = "stats_ops_py",
+ srcs = ["python/ops/stats_ops.py"],
+ dso = ["python/ops/_stats_ops.so"],
+ kernels = [
+ ":stats_ops_kernels",
+ ":stats_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_stats_ops_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
+# ---------------------------------- Common libs ------------------------ #
+cc_library(
+ name = "tree_utils",
+ srcs = ["kernels/tree_utils.cc"],
+ hdrs = [
+ "kernels/data_spec.h",
+ "kernels/tree_utils.h",
+ ],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf//:protobuf_headers",
+ ],
+)
+
+# --------------------------------- Python -------------------------------- #
+
py_library(
name = "eval_metrics",
srcs = ["client/eval_metrics.py"],
@@ -220,20 +388,6 @@ py_library(
],
)
-cc_library(
- name = "tree_utils",
- srcs = ["kernels/tree_utils.cc"],
- hdrs = [
- "kernels/data_spec.h",
- "kernels/tree_utils.h",
- ],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf//:protobuf_headers",
- ],
-)
-
py_test(
name = "best_splits_op_test",
size = "small",
@@ -380,25 +534,6 @@ py_test(
],
)
-cc_test(
- name = "tensor_forest_ops_test",
- size = "small",
- srcs = [
- "kernels/tensor_forest_ops_test.cc",
- ":custom_op_sources",
- ],
- deps = [
- ":tree_utils",
- "//tensorflow/core",
- "//tensorflow/core:framework",
- "//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- "//third_party/eigen3",
- ],
-)
-
py_library(
name = "random_forest",
srcs = ["client/random_forest.py"],
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
new file mode 100644
index 0000000000..195221a48e
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -0,0 +1,299 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
+#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
+#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
+#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_handle.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+// Kernel that builds a DecisionTreeResource from a serialized tree config and
+// registers it under the resource handle passed as input 0.
+class CreateTreeVariableOp : public OpKernel {
+ public:
+  // Reads the "params" attribute and parses it into param_proto_. The result
+  // of the parse is not checked.
+  explicit CreateTreeVariableOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string params_attr;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &params_attr));
+    ParseProtoUnlimited(&param_proto_, params_attr);
+  }
+
+  // Parses the scalar "tree_config" input into a fresh resource and creates
+  // it. An ALREADY_EXISTS status from CreateResource is tolerated; any other
+  // failure is reported on the context.
+  void Compute(OpKernelContext* context) override {
+    const Tensor* config_tensor;
+    OP_REQUIRES_OK(context, context->input("tree_config", &config_tensor));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(config_tensor->shape()),
+                errors::InvalidArgument("Tree config must be a scalar."));
+
+    auto* tree = new DecisionTreeResource();
+    const bool parsed = ParseProtoUnlimited(
+        tree->mutable_decision_tree(), config_tensor->scalar<string>()());
+    if (!parsed) {
+      // Release the never-registered resource before reporting the error.
+      tree->Unref();
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument("Unable to parse tree config."));
+    }
+
+    tree->MaybeInitialize();
+
+    // Creating a resource that already exists is not treated as an error.
+    const auto create_status =
+        CreateResource(context, HandleFromInput(context, 0), tree);
+    const bool acceptable =
+        create_status.ok() ||
+        create_status.code() == tensorflow::error::ALREADY_EXISTS;
+    OP_REQUIRES(context, acceptable, create_status);
+  }
+
+ private:
+  // Parsed "params" attribute; Compute in this class does not read it.
+  TensorForestParams param_proto_;
+};
+
+// Op for serializing a model.
+class TreeSerializeOp : public OpKernel {
+ public:
+  explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  // Looks up the DecisionTreeResource behind input 0 and writes its model,
+  // serialized to a string, into scalar output 0.
+  void Compute(OpKernelContext* context) override {
+    DecisionTreeResource* decision_tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &decision_tree_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in decision-tree-resource.h.
+    core::ScopedUnref unref_me(decision_tree_resource);
+    mutex_lock l(*decision_tree_resource->get_mutex());
+    Tensor* output_config_t = nullptr;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(0, TensorShape(), &output_config_t));
+    output_config_t->scalar<string>()() =
+        decision_tree_resource->decision_tree().SerializeAsString();
+  }
+};
+
+// Op for deserializing a tree variable from a checkpoint.
+class TreeDeserializeOp : public OpKernel {
+ public:
+  // Reads the "params" attribute and parses it into param_proto_. The result
+  // of the parse is not checked.
+  explicit TreeDeserializeOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+  }
+
+  // Replaces the contents of the DecisionTreeResource behind input 0 with the
+  // model parsed from the scalar "tree_config" input.
+  void Compute(OpKernelContext* context) override {
+    DecisionTreeResource* decision_tree_resource;
+    auto handle = HandleFromInput(context, 0);
+    OP_REQUIRES_OK(context, LookupResource(context, handle,
+                                           &decision_tree_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in decision-tree-resource.h.
+    core::ScopedUnref unref_me(decision_tree_resource);
+    mutex_lock l(*decision_tree_resource->get_mutex());
+
+    const Tensor* tree_config_t;
+    OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
+                errors::InvalidArgument("Tree config must be a scalar."));
+    // Deallocate all the previous objects on the resource.
+    // NOTE(review): the reset happens before the new config is parsed, so a
+    // parse failure below leaves the resource without its previous model.
+    decision_tree_resource->Reset();
+    decision_trees::Model* config =
+        decision_tree_resource->mutable_decision_tree();
+    OP_REQUIRES(context,
+                ParseProtoUnlimited(config, tree_config_t->scalar<string>()()),
+                errors::InvalidArgument("Unable to parse tree config."));
+    decision_tree_resource->MaybeInitialize();
+  }
+
+ private:
+  // Parsed "params" attribute; Compute in this class does not read it.
+  TensorForestParams param_proto_;
+};
+
+// Op for getting tree size.
+class TreeSizeOp : public OpKernel {
+ public:
+  explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  // Writes the number of nodes in the tree held by the DecisionTreeResource
+  // behind input 0 into scalar int32 output 0.
+  void Compute(OpKernelContext* context) override {
+    DecisionTreeResource* decision_tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &decision_tree_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in decision-tree-resource.h.
+    core::ScopedUnref unref_me(decision_tree_resource);
+    mutex_lock l(*decision_tree_resource->get_mutex());
+    Tensor* output_t = nullptr;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(0, TensorShape(), &output_t));
+    output_t->scalar<int32>()() =
+        decision_tree_resource->decision_tree().decision_tree().nodes_size();
+  }
+};
+
+
+// Op for tree inference.
+class TreePredictionsV4Op : public OpKernel {
+ public:
+  // Parses the "params" and "input_spec" attributes, then builds the data set
+  // wrapper and the leaf model operator used by Compute. The results of the
+  // two parses are not checked.
+  explicit TreePredictionsV4Op(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+
+    string serialized_proto;
+    OP_REQUIRES_OK(context, context->GetAttr(
+        "input_spec", &serialized_proto));
+    input_spec_.ParseFromString(serialized_proto);
+
+    data_set_ =
+        std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
+
+    model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
+  }
+
+  // Routes every input example to a leaf of the tree behind input 0 and
+  // writes one row of num_outputs() values per example into output 0.
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(1);
+    const Tensor& sparse_input_indices = context->input(2);
+    const Tensor& sparse_input_values = context->input(3);
+
+    // NOTE(review): data_set_ is a member of the kernel and is updated here
+    // before the resource mutex is taken; confirm Compute is never run
+    // concurrently on the same kernel instance.
+    data_set_->set_input_tensors(input_data, sparse_input_indices,
+                                 sparse_input_values);
+
+    DecisionTreeResource* decision_tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &decision_tree_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in decision-tree-resource.h.
+    core::ScopedUnref unref_me(decision_tree_resource);
+    mutex_lock l(*decision_tree_resource->get_mutex());
+
+    // Both bounds are fixed for the duration of this call, so read them once.
+    const auto num_data = data_set_->NumItems();
+    const auto num_outputs = param_proto_.num_outputs();
+
+    Tensor* output_predictions = nullptr;
+    TensorShape output_shape;
+    output_shape.AddDim(num_data);
+    output_shape.AddDim(num_outputs);
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, output_shape,
+                                            &output_predictions));
+
+    auto out = output_predictions->tensor<float, 2>();
+    for (int i = 0; i < num_data; ++i) {
+      const int32 leaf_id =
+          decision_tree_resource->TraverseTree(data_set_, i, nullptr);
+      const decision_trees::Leaf& leaf =
+          decision_tree_resource->get_leaf(leaf_id);
+      for (int j = 0; j < num_outputs; ++j) {
+        const float count = model_op_->GetOutputValue(leaf, j);
+        out(i, j) = count;
+      }
+    }
+  }
+
+ private:
+  tensorforest::TensorForestDataSpec input_spec_;
+  std::unique_ptr<TensorDataSet> data_set_;
+  std::unique_ptr<LeafModelOperator> model_op_;
+  TensorForestParams param_proto_;
+};
+
+// Op for getting feature usage counts.
+class FeatureUsageCountsOp : public OpKernel {
+ public:
+  // Reads the "params" attribute and parses it into param_proto_. The result
+  // of the parse is not checked.
+  explicit FeatureUsageCountsOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+  }
+
+  // Emits an int32 vector of length num_features() whose k-th entry is the
+  // number of binary-node tests in the tree (input 0) that reference feature
+  // k. Custom nodes and unknown custom tests are skipped with a warning, as
+  // are feature ids that are not integers inside the output range.
+  void Compute(OpKernelContext* context) override {
+    DecisionTreeResource* decision_tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &decision_tree_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in decision-tree-resource.h.
+    core::ScopedUnref unref_me(decision_tree_resource);
+    mutex_lock l(*decision_tree_resource->get_mutex());
+
+    const auto& tree = decision_tree_resource->decision_tree();
+
+    Tensor* output_counts = nullptr;
+    TensorShape output_shape;
+    output_shape.AddDim(param_proto_.num_features());
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, output_shape, &output_counts));
+
+    auto counts = output_counts->unaligned_flat<int32>();
+    counts.setZero();
+
+    // Increments the counter for the feature named by `id`. The id comes from
+    // the stored model, so it is validated: a failed parse would leave the
+    // index uninitialized and an out-of-range index would write outside the
+    // output buffer.
+    const auto num_slots = counts.size();
+    auto count_feature = [&counts, num_slots](const string& id) {
+      int32 feat;
+      if (!safe_strto32(id, &feat) || feat < 0 || feat >= num_slots) {
+        LOG(WARNING) << "Ignoring invalid feature id: " << id;
+        return;
+      }
+      ++counts(feat);
+    };
+
+    for (const auto& node : tree.decision_tree().nodes()) {
+      if (node.has_custom_node_type()) {
+        LOG(WARNING) << "Can't count feature usage for custom nodes.";
+      } else if (node.has_binary_node()) {
+        const auto& bnode = node.binary_node();
+        if (bnode.has_custom_left_child_test()) {
+          decision_trees::MatchingValuesTest test;
+          if (!bnode.custom_left_child_test().UnpackTo(&test)) {
+            LOG(WARNING) << "Unknown custom child test";
+            continue;
+          }
+          count_feature(test.feature_id().id().value());
+        } else {
+          const auto& test = bnode.inequality_left_child_test();
+          if (test.has_feature_id()) {
+            count_feature(test.feature_id().id().value());
+          } else if (test.has_oblique()) {
+            // An oblique test references several features; count each one.
+            for (const auto& featid : test.oblique().features()) {
+              count_feature(featid.id().value());
+            }
+          }
+        }
+      }
+    }
+  }
+
+ private:
+  TensorForestParams param_proto_;
+};
+
+
+// Kernel registrations for the ops implemented in this file. Every
+// REGISTER_KERNEL_BUILDER below targets DEVICE_CPU only.
+REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeResource);
+
+// Uses the generic IsResourceInitialized kernel, specialized for the tree
+// resource, rather than a class defined in this file.
+REGISTER_KERNEL_BUILDER(Name("TreeIsInitializedOp").Device(DEVICE_CPU),
+                        IsResourceInitialized<DecisionTreeResource>);
+
+REGISTER_KERNEL_BUILDER(Name("CreateTreeVariable").Device(DEVICE_CPU),
+                        CreateTreeVariableOp);
+
+REGISTER_KERNEL_BUILDER(Name("TreeSerialize").Device(DEVICE_CPU),
+                        TreeSerializeOp);
+
+REGISTER_KERNEL_BUILDER(Name("TreeDeserialize").Device(DEVICE_CPU),
+                        TreeDeserializeOp);
+
+REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU),
+                        TreeSizeOp);
+
+REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU),
+                        TreePredictionsV4Op);
+
+REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU),
+                        FeatureUsageCountsOp);
+
+} // namespace tensorforest
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
new file mode 100644
index 0000000000..7442469507
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -0,0 +1,564 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <queue>
+
+#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
+#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_handle.pb.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+using gtl::FindOrNull;
+
+// Kernel that builds a FertileStatsResource from a serialized stats config
+// and registers it under the resource handle passed as input 0.
+class CreateFertileStatsVariableOp : public OpKernel {
+ public:
+  // Reads the "params" attribute and parses it into param_proto_. The result
+  // of the parse is not checked.
+  explicit CreateFertileStatsVariableOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string params_attr;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &params_attr));
+    ParseProtoUnlimited(&param_proto_, params_attr);
+  }
+
+  // Parses the scalar "stats_config" input, loads it into a fresh resource
+  // and creates it. An ALREADY_EXISTS status from CreateResource is
+  // tolerated; any other failure is reported on the context.
+  void Compute(OpKernelContext* context) override {
+    const Tensor* config_tensor;
+    OP_REQUIRES_OK(context, context->input("stats_config", &config_tensor));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(config_tensor->shape()),
+                errors::InvalidArgument("Stats config must be a scalar."));
+
+    auto* stats_resource = new FertileStatsResource(param_proto_);
+    FertileStats parsed_stats;
+    const bool parsed = ParseProtoUnlimited(
+        &parsed_stats, config_tensor->scalar<string>()());
+    if (!parsed) {
+      // Release the never-registered resource before reporting the error.
+      stats_resource->Unref();
+      OP_REQUIRES(context, false,
+                  errors::InvalidArgument("Unable to parse stats config."));
+    }
+
+    stats_resource->ExtractFromProto(parsed_stats);
+    stats_resource->MaybeInitialize();
+
+    // Creating a resource that already exists is not treated as an error.
+    const auto create_status =
+        CreateResource(context, HandleFromInput(context, 0), stats_resource);
+    const bool acceptable =
+        create_status.ok() ||
+        create_status.code() == tensorflow::error::ALREADY_EXISTS;
+    OP_REQUIRES(context, acceptable, create_status);
+  }
+
+ private:
+  // Parsed "params" attribute, handed to each FertileStatsResource created.
+  TensorForestParams param_proto_;
+};
+
+// Op for serializing fertile stats.
+class FertileStatsSerializeOp : public OpKernel {
+ public:
+  // Reads the "params" attribute and parses it into param_proto_. The result
+  // of the parse is not checked.
+  explicit FertileStatsSerializeOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+  }
+
+  // Packs the FertileStatsResource behind input 0 into a FertileStats proto
+  // and writes its serialized form into scalar output 0.
+  void Compute(OpKernelContext* context) override {
+    FertileStatsResource* fertile_stats_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &fertile_stats_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in fertile-stats-resource.h.
+    core::ScopedUnref unref_me(fertile_stats_resource);
+    mutex_lock l(*fertile_stats_resource->get_mutex());
+    Tensor* output_config_t = nullptr;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(0, TensorShape(), &output_config_t));
+
+    FertileStats stats;
+    fertile_stats_resource->PackToProto(&stats);
+    output_config_t->scalar<string>()() = stats.SerializeAsString();
+  }
+
+ private:
+  // Parsed "params" attribute; Compute in this class does not read it.
+  TensorForestParams param_proto_;
+};
+
+// Op for deserializing a stats variable from a checkpoint.
+class FertileStatsDeserializeOp : public OpKernel {
+ public:
+  // Reads the "params" attribute and parses it into param_proto_. The result
+  // of the parse is not checked.
+  explicit FertileStatsDeserializeOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+  }
+
+  // Replaces the contents of the FertileStatsResource behind input 0 with the
+  // stats parsed from the scalar "stats_config" input. If the input cannot be
+  // parsed the resource is left untouched and an InvalidArgument error is
+  // reported.
+  void Compute(OpKernelContext* context) override {
+    FertileStatsResource* fertile_stats_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &fertile_stats_resource));
+    // The unref guard is declared before the lock so that it is destroyed
+    // after it: the mutex is released first and only then is the reference
+    // dropped. With the opposite order a final Unref() could destroy the
+    // resource while its mutex is still held.
+    // NOTE(review): assumes get_mutex() points at a mutex owned by the
+    // resource -- confirm in fertile-stats-resource.h.
+    core::ScopedUnref unref_me(fertile_stats_resource);
+    mutex_lock l(*fertile_stats_resource->get_mutex());
+
+    const Tensor* stats_config_t;
+    OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
+    OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()),
+                errors::InvalidArgument("Stats config must be a scalar."));
+
+    // Parse into a local proto first so that a malformed input does not wipe
+    // the stats currently held by the resource.
+    FertileStats stats;
+    OP_REQUIRES(context,
+                ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()()),
+                errors::InvalidArgument("Unable to parse stats config."));
+
+    // Deallocate all the previous objects on the resource.
+    fertile_stats_resource->Reset();
+    fertile_stats_resource->ExtractFromProto(stats);
+    fertile_stats_resource->MaybeInitialize();
+  }
+
+ private:
+  // Parsed "params" attribute; Compute in this class does not read it.
+  TensorForestParams param_proto_;
+};
+
+// Runs every example in [start, end) through the tree and stores the leaf id
+// and the reported depth at the example's index in the two output vectors.
+// Both vectors must already be large enough to be indexed up to end - 1.
+void TraverseTree(const DecisionTreeResource* tree_resource,
+                  const std::unique_ptr<TensorDataSet>& data, int32 start,
+                  int32 end, std::vector<int32>* leaf_ids,
+                  std::vector<int32>* leaf_depths) {
+  std::vector<int32>& ids_out = *leaf_ids;
+  std::vector<int32>& depths_out = *leaf_depths;
+  for (int example = start; example < end; ++example) {
+    int32 depth;
+    ids_out[example] = tree_resource->TraverseTree(data, example, &depth);
+    depths_out[example] = depth;
+  }
+}
+
+// Adds the examples in [start, end) to the stats of the leaves they landed
+// in. Each leaf has its own lock: on the first pass an example whose leaf is
+// busy is parked in a queue instead of blocking, and once the range has been
+// walked the parked examples are drained with blocking lock acquisitions.
+// Leaves reported as finished are recorded in ready_to_split under set_lock.
+void UpdateStats(FertileStatsResource* fertile_stats_resource,
+                 const std::unique_ptr<TensorDataSet>& data,
+                 const Tensor& input_labels, const Tensor& input_weights,
+                 int num_targets, const std::vector<int32>& leaf_ids,
+                 const std::vector<int32>& leaf_depths,
+                 std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
+                 mutex* set_lock, int32 start, int32 end,
+                 std::unordered_set<int32>* ready_to_split) {
+  const auto labels = input_labels.unaligned_flat<float>();
+  const auto weights = input_weights.unaligned_flat<float>();
+  // Parked work items, each holding (leaf_id, leaf_depth, example_id).
+  std::queue<std::tuple<int32, int32, int32>> deferred;
+
+  int32 next = start;
+  TensorInputTarget target(&labels, &weights, input_labels, num_targets);
+  while (next < end || !deferred.empty()) {
+    int32 leaf_id;
+    int32 leaf_depth;
+    int32 example_id;
+    // The range is exhausted first; only then are parked items revisited.
+    const bool from_queue = next >= end;
+    if (from_queue) {
+      std::tie(leaf_id, leaf_depth, example_id) = deferred.front();
+      deferred.pop();
+    } else {
+      example_id = next;
+      leaf_id = leaf_ids[example_id];
+      leaf_depth = leaf_depths[example_id];
+      ++next;
+    }
+
+    // NOTE(review): operator[] would insert a null entry for an unknown
+    // leaf_id; presumably the caller pre-populates `locks` -- confirm.
+    const std::unique_ptr<mutex>& leaf_lock = (*locks)[leaf_id];
+    if (from_queue) {
+      // Second visit: wait for the leaf rather than parking it again.
+      leaf_lock->lock();
+    } else if (!leaf_lock->try_lock()) {
+      deferred.emplace(leaf_id, leaf_depth, example_id);
+      continue;
+    }
+
+    bool is_finished;
+    fertile_stats_resource->AddExampleToStatsAndInitialize(
+        data, &target, {example_id}, leaf_id, leaf_depth,
+        &is_finished);
+    leaf_lock->unlock();
+    if (is_finished) {
+      mutex_lock guard(*set_lock);
+      ready_to_split->insert(leaf_id);
+    }
+  }
+}
+
+// Updates the leaves at positions [start, end) of the `leaf_examples`
+// iteration order, adding all of a leaf's examples in a single call.
+//
+// The depth passed for a leaf is the depth recorded for its first example
+// (`leaf_depths[examples[0]]`), so every entry of `leaf_examples` must hold at
+// least one example. Leaves reported as finished are added to
+// `ready_to_split` under `set_lock`. `tree_resource` is not used by this
+// function.
+void UpdateStatsCollated(
+    FertileStatsResource* fertile_stats_resource,
+    DecisionTreeResource* tree_resource,
+    const std::unique_ptr<TensorDataSet>& data, const Tensor& input_labels,
+    const Tensor& input_weights, int num_targets,
+    const std::unordered_map<int32, std::vector<int>>& leaf_examples,
+    const std::vector<int32>& leaf_depths, mutex* set_lock, int32 start,
+    int32 end, std::unordered_set<int32>* ready_to_split) {
+  const auto labels = input_labels.unaligned_flat<float>();
+  const auto weights = input_weights.unaligned_flat<float>();
+
+  TensorInputTarget target(&labels, &weights, input_labels, num_targets);
+  auto it = leaf_examples.begin();
+  std::advance(it, start);
+  // Derive the end iterator from `it` instead of walking from begin() a
+  // second time; the map only offers forward iteration, so each step counts.
+  auto end_it = it;
+  std::advance(end_it, end - start);
+  for (; it != end_it; ++it) {
+    const int32 leaf_id = it->first;
+    bool is_finished = false;
+    fertile_stats_resource->AddExampleToStatsAndInitialize(
+        data, &target, it->second, leaf_id, leaf_depths[it->second[0]],
+        &is_finished);
+    if (is_finished) {
+      mutex_lock l(*set_lock);
+      ready_to_split->insert(leaf_id);
+    }
+  }
+}
+
+// Op for traversing the tree with each example, accumulating statistics, and
+// outputting node ids that are ready to split.
+//
+// Attrs read at construction: "params" (serialized TensorForestParams),
+// "random_seed" and "input_spec" (serialized TensorForestDataSpec).
+// Inputs read in Compute: 0 = tree resource handle, 1 = fertile stats
+// resource handle, 2 = dense data, 3 and 4 = sparse indices and values,
+// 6 = labels, 7 = weights.
+// Output 0 is a 1-D int32 tensor holding the ids of the leaves that are ready
+// to split, in the iteration order of an unordered set.
+class ProcessInputOp : public OpKernel {
+ public:
+  explicit ProcessInputOp(OpKernelConstruction* context) : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+
+    OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
+
+    string serialized_proto;
+    OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
+    input_spec_.ParseFromString(serialized_proto);
+
+    data_set_ = std::unique_ptr<TensorDataSet>(
+        new TensorDataSet(input_spec_, random_seed_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    const Tensor& input_data = context->input(2);
+    const Tensor& sparse_input_indices = context->input(3);
+    const Tensor& sparse_input_values = context->input(4);
+    const Tensor& input_labels = context->input(6);
+    const Tensor& input_weights = context->input(7);
+
+    data_set_->set_input_tensors(input_data, sparse_input_indices,
+                                 sparse_input_values);
+
+    // Hand each looked-up reference to a ScopedUnref straight away. That way
+    // a failure of the second lookup does not leak the first reference, and
+    // the unrefs run only after the mutex_locks declared below (which are
+    // destroyed first) have released the resources' mutexes.
+    FertileStatsResource* fertile_stats_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
+                                           &fertile_stats_resource));
+    core::ScopedUnref unref_stats(fertile_stats_resource);
+    DecisionTreeResource* tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &tree_resource));
+    core::ScopedUnref unref_tree(tree_resource);
+
+    mutex_lock l1(*fertile_stats_resource->get_mutex());
+    mutex_lock l2(*tree_resource->get_mutex());
+
+    const int32 num_data = data_set_->NumItems();
+    auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+    int num_threads = worker_threads->num_threads;
+
+    // First find the leaf ids for each example.
+    std::vector<int32> leaf_ids(num_data);
+
+    // The depth of the leaf for example i.
+    std::vector<int32> leaf_depths(num_data);
+
+    const int64 costPerTraverse = 500;
+    auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data](
+                        int64 start, int64 end) {
+      CHECK(start <= end);
+      CHECK(end <= num_data);
+      TraverseTree(tree_resource, data_set_, static_cast<int32>(start),
+                   static_cast<int32>(end), &leaf_ids, &leaf_depths);
+    };
+    Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
+          traverse);
+
+    // Two update strategies, selected by collate_examples():
+    //  - collated: group the examples by leaf and shard over leaves;
+    //  - otherwise: create one mutex per leaf and shard over examples, which
+    //    spreads work uniformly across threads while the mutexes protect
+    //    access to each leaf.
+    std::unordered_map<int, std::unique_ptr<mutex>> locks;
+    std::unordered_map<int32, std::vector<int>> leaf_examples;
+    if (param_proto_.collate_examples()) {
+      for (int i = 0; i < num_data; ++i) {
+        leaf_examples[leaf_ids[i]].push_back(i);
+      }
+    } else {
+      for (const int32 id : leaf_ids) {
+        if (FindOrNull(locks, id) == nullptr) {
+          // TODO(gilberth): Consider using a memory pool for these.
+          locks[id] = std::unique_ptr<mutex>(new mutex);
+        }
+      }
+    }
+
+    // Only meaningful in the collated case; leaf_examples is empty otherwise.
+    const int32 num_leaves = leaf_examples.size();
+    const int32 label_dim =
+        input_labels.shape().dims() <= 1
+            ? 0
+            : static_cast<int>(input_labels.shape().dim_size(1));
+    const int32 num_targets =
+        param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
+
+    // Ids of leaves that can split.
+    std::unordered_set<int32> ready_to_split;
+    mutex set_lock;
+
+    // TODO(gilberth): This is a rough approximation based on measurements
+    // from a digits run on local desktop. Heuristics might be necessary
+    // if it really matters that much.
+    const int64 costPerUpdate = 1000;
+    auto update = [this, &input_labels, &input_weights, &leaf_ids, &leaf_depths,
+                   &num_targets, fertile_stats_resource, &locks, &set_lock,
+                   &ready_to_split, num_data](int64 start, int64 end) {
+      CHECK(start <= end);
+      CHECK(end <= num_data);
+      UpdateStats(fertile_stats_resource, data_set_, input_labels,
+                  input_weights, num_targets, leaf_ids, leaf_depths, &locks,
+                  &set_lock, static_cast<int32>(start), static_cast<int32>(end),
+                  &ready_to_split);
+    };
+
+    auto update_collated = [this, &input_labels, &input_weights, &leaf_ids,
+                            &num_targets, &leaf_depths, fertile_stats_resource,
+                            tree_resource, &leaf_examples, &set_lock,
+                            &ready_to_split,
+                            num_leaves](int64 start, int64 end) {
+      CHECK(start <= end);
+      CHECK(end <= num_leaves);
+      UpdateStatsCollated(
+          fertile_stats_resource, tree_resource, data_set_, input_labels,
+          input_weights, num_targets, leaf_examples, leaf_depths, &set_lock,
+          static_cast<int32>(start), static_cast<int32>(end), &ready_to_split);
+    };
+
+    if (param_proto_.collate_examples()) {
+      Shard(num_threads, worker_threads->workers, num_leaves, costPerUpdate,
+            update_collated);
+    } else {
+      Shard(num_threads, worker_threads->workers, num_data, costPerUpdate,
+            update);
+    }
+
+    Tensor* output_finished_t = nullptr;
+    TensorShape output_shape;
+    output_shape.AddDim(ready_to_split.size());
+    OP_REQUIRES_OK(
+        context, context->allocate_output(0, output_shape, &output_finished_t));
+    auto output = output_finished_t->unaligned_flat<int32>();
+    std::copy(ready_to_split.begin(), ready_to_split.end(), output.data());
+  }
+
+ private:
+  int32 random_seed_;
+  tensorforest::TensorForestDataSpec input_spec_;
+  std::unique_ptr<TensorDataSet> data_set_;
+  TensorForestParams param_proto_;
+};
+
+
+// Op for growing finished nodes.
+//
+// Inputs read in Compute: 0 = tree resource handle, 1 = fertile stats
+// resource handle, 2 = int32 tensor of node ids that finished accumulating
+// stats. For each such node the best split is looked up; if one is found the
+// node is split and its children are allocated, otherwise the node's split
+// stats are reset. Processing stops once the tree holds max_nodes() nodes.
+class GrowTreeOp : public OpKernel {
+ public:
+  explicit GrowTreeOp(OpKernelConstruction* context) : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Hand each looked-up reference to a ScopedUnref straight away. That way
+    // a failure of the second lookup does not leak the first reference, and
+    // the unrefs run only after the mutex_locks declared below (which are
+    // destroyed first) have released the resources' mutexes.
+    FertileStatsResource* fertile_stats_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
+                                           &fertile_stats_resource));
+    core::ScopedUnref unref_stats(fertile_stats_resource);
+    DecisionTreeResource* tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &tree_resource));
+    core::ScopedUnref unref_tree(tree_resource);
+
+    mutex_lock l1(*fertile_stats_resource->get_mutex());
+    mutex_lock l2(*tree_resource->get_mutex());
+
+    const Tensor& finished_nodes = context->input(2);
+
+    const auto finished = finished_nodes.unaligned_flat<int32>();
+
+    const int32 num_nodes =
+        static_cast<int32>(finished_nodes.shape().dim_size(0));
+
+    // TODO(gilberth): distribute this work over a number of threads.
+    for (int i = 0;
+         i < num_nodes &&
+         tree_resource->decision_tree().decision_tree().nodes_size() <
+             param_proto_.max_nodes();
+         ++i) {
+      const int32 node = finished(i);
+      std::unique_ptr<SplitCandidate> best(new SplitCandidate);
+      int32 parent_depth;
+      bool found =
+          fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
+      if (found) {
+        std::vector<int32> new_children;
+        tree_resource->SplitNode(node, best.get(), &new_children);
+        fertile_stats_resource->Allocate(parent_depth, new_children);
+        // Seed the two new leaves with the stats of the chosen split's sides.
+        fertile_stats_resource->set_leaf_stat(best->left_stats(),
+                                              new_children[0]);
+        fertile_stats_resource->set_leaf_stat(best->right_stats(),
+                                              new_children[1]);
+        // We are done with best, so it is now safe to clear node.
+        fertile_stats_resource->Clear(node);
+        CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
+      } else {  // reset
+        fertile_stats_resource->ResetSplitStats(node, parent_depth);
+      }
+    }
+  }
+
+ private:
+  // Not referenced by any method of this class.
+  tensorforest::TensorForestDataSpec input_spec_;
+  TensorForestParams param_proto_;
+};
+
+// Exports the model for `leaf_stats` into `leaf` via `leaf_op` and, for
+// classification, divides every exported value by the leaf's weight sum.
+//
+// Regression leaves are left exactly as exported. A classification leaf whose
+// weight sum is not positive is logged and left unscaled. When
+// `drop_final_class` is set, the last entry of a dense vector leaf is removed
+// after scaling. A leaf that is neither a dense nor a sparse vector is fatal.
+void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression,
+                  bool drop_final_class,
+                  const std::unique_ptr<LeafModelOperator>& leaf_op,
+                  decision_trees::Leaf* leaf) {
+  leaf_op->ExportModel(leaf_stats, leaf);
+
+  // TODO(thomaswc): Move the rest of this into ExportModel.
+
+  // regression models are already stored in leaf in normalized form.
+  if (is_regression) return;
+
+  const float total_weight = leaf_stats.weight_sum();
+  if (total_weight <= 0.0) {
+    LOG(WARNING) << "Leaf with sum " << total_weight
+                 << " has stats " << leaf->ShortDebugString();
+    return;
+  }
+
+  if (leaf->has_vector()) {
+    auto* dense = leaf->mutable_vector();
+    const int num_values = dense->value_size();
+    for (int idx = 0; idx < num_values; ++idx) {
+      auto* entry = dense->mutable_value(idx);
+      entry->set_float_value(entry->float_value() / total_weight);
+    }
+    if (drop_final_class) {
+      dense->mutable_value()->RemoveLast();
+    }
+  } else if (leaf->has_sparse_vector()) {
+    auto* sparse = leaf->mutable_sparse_vector()->mutable_sparse_value();
+    for (auto& kv : *sparse) {
+      kv.second.set_float_value(kv.second.float_value() / total_weight);
+    }
+  } else {
+    LOG(FATAL) << "Unknown leaf type in " << leaf->DebugString();
+  }
+}
+
+// Op for finalizing a tree at the end of training.
+//
+// Inputs read in Compute: 0 = tree resource handle, 1 = fertile stats
+// resource handle. Every node of the tree that has a leaf is passed to
+// FinalizeLeaf together with the stats stored for that node index.
+class FinalizeTreeOp : public OpKernel {
+ public:
+  explicit FinalizeTreeOp(OpKernelConstruction* context) : OpKernel(context) {
+    string serialized_params;
+    OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+    ParseProtoUnlimited(&param_proto_, serialized_params);
+
+    model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Hand each looked-up reference to a ScopedUnref straight away. That way
+    // a failure of the second lookup does not leak the first reference, and
+    // the unrefs run only after the mutex_locks declared below (which are
+    // destroyed first) have released the resources' mutexes.
+    DecisionTreeResource* tree_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+                                           &tree_resource));
+    core::ScopedUnref unref_me(tree_resource);
+    FertileStatsResource* fertile_stats_resource;
+    OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
+                                           &fertile_stats_resource));
+    core::ScopedUnref unref_stats(fertile_stats_resource);
+
+    // Stats mutex first, then tree mutex.
+    mutex_lock l1(*fertile_stats_resource->get_mutex());
+    mutex_lock l2(*tree_resource->get_mutex());
+
+    // TODO(thomaswc): Add threads
+    int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
+    for (int i = 0; i < num_nodes; i++) {
+      auto* node = tree_resource->mutable_decision_tree()
+                       ->mutable_decision_tree()->mutable_nodes(i);
+      if (node->has_leaf()) {
+        const auto& leaf_stats = fertile_stats_resource->leaf_stat(i);
+        FinalizeLeaf(leaf_stats, param_proto_.is_regression(),
+                     param_proto_.drop_final_class(), model_op_,
+                     node->mutable_leaf());
+      }
+    }
+  }
+
+ private:
+  std::unique_ptr<LeafModelOperator> model_op_;
+  TensorForestParams param_proto_;
+};
+
+// Kernel registrations for the fertile-stats ops defined in this file. All of
+// them are registered for DEVICE_CPU only.
+REGISTER_RESOURCE_HANDLE_KERNEL(FertileStatsResource);
+
+REGISTER_KERNEL_BUILDER(Name("FertileStatsIsInitializedOp").Device(DEVICE_CPU),
+ IsResourceInitialized<FertileStatsResource>);
+
+REGISTER_KERNEL_BUILDER(Name("CreateFertileStatsVariable").Device(DEVICE_CPU),
+ CreateFertileStatsVariableOp);
+
+REGISTER_KERNEL_BUILDER(Name("FertileStatsSerialize").Device(DEVICE_CPU),
+ FertileStatsSerializeOp);
+
+REGISTER_KERNEL_BUILDER(Name("FertileStatsDeserialize").Device(DEVICE_CPU),
+ FertileStatsDeserializeOp);
+
+// The two ops below are registered under "V4"-suffixed op names.
+REGISTER_KERNEL_BUILDER(Name("ProcessInputV4").Device(DEVICE_CPU),
+ ProcessInputOp);
+
+REGISTER_KERNEL_BUILDER(Name("GrowTreeV4").Device(DEVICE_CPU),
+ GrowTreeOp);
+
+REGISTER_KERNEL_BUILDER(Name("FinalizeTree").Device(DEVICE_CPU),
+ FinalizeTreeOp);
+
+} // namespace tensorforest
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
index 0542508a8e..a9d8093d13 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD
@@ -40,9 +40,7 @@ cc_library(
":split_collection_operators",
"//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc",
"//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc",
- "//tensorflow/core:framework",
"//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:lib",
],
)
@@ -111,7 +109,7 @@ cc_library(
"//tensorflow/contrib/tensor_forest:tree_utils",
"//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc",
"//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc",
- "//tensorflow/core:lib",
+ "//tensorflow/core:framework_headers_lib",
],
)
@@ -153,7 +151,7 @@ cc_library(
":input_data",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc",
- "//tensorflow/core:lib",
+ "//tensorflow/core:framework_headers_lib",
],
)
@@ -175,12 +173,32 @@ cc_library(
srcs = ["split_collection_operators.cc"],
hdrs = ["split_collection_operators.h"],
deps = [
+ ":grow_stats",
+ ":input_data",
+ ":input_target",
+ ":leaf_model_operators",
+ ":params",
+ ":stat_utils",
+ "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc",
+ "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc",
+ "//tensorflow/contrib/tensor_forest:tree_utils",
+ "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc",
+ "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "graph_collection_operator",
+ srcs = ["graph_collection_operator.cc"],
+ hdrs = ["graph_collection_operator.h"],
+ deps = [
":candidate_graph_runner",
":grow_stats",
":input_data",
":input_target",
":leaf_model_operators",
":params",
+ ":split_collection_operators",
":stat_utils",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc",
"//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc",
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc
new file mode 100644
index 0000000000..2c925b5dd7
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc
@@ -0,0 +1,142 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h"
+
+#include <cfloat>
+
+#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+// Registers GraphRunnerSplitCollectionOperator under GRAPH_RUNNER_COLLECTION.
+// NOTE(review): the registration runs during static initialization of this
+// translation unit and writes to SplitCollectionOperatorFactory::factories_,
+// which is defined in split_collection_operators.cc. C++ does not order
+// dynamic initialization across translation units — confirm the map is
+// constructed before this object.
+REGISTER_SPLIT_COLLECTION(GRAPH_RUNNER_COLLECTION,
+ GraphRunnerSplitCollectionOperator);
+
+// Returns a freshly allocated SimpleStats built from this operator's params
+// and `depth`. `node_id` is not used.
+std::unique_ptr<GrowStats> GraphRunnerSplitCollectionOperator::CreateGrowStats(
+    int32 node_id, int32 depth) const {
+  std::unique_ptr<GrowStats> grow_stats(new SimpleStats(params_, depth));
+  return grow_stats;
+}
+
+// Returns the key under which the runner for candidate `split_id` of node
+// `node_id` is stored: node_id * num_splits_to_consider_ + split_id.
+int64 GraphRunnerSplitCollectionOperator::UniqueId(int32 node_id,
+                                                   int32 split_id) const {
+  // Widen before multiplying: with two 32-bit operands the product would be
+  // computed in 32 bits and could overflow before being converted to int64.
+  return static_cast<int64>(node_id) * num_splits_to_consider_ + split_id;
+}
+
+// Picks the candidate split of `node_id` with the lowest SplitScore().
+//
+// Always writes the slot's depth to `*depth`. Returns false, leaving `*best`
+// untouched, when no candidate scores below FLT_MAX (including when the slot
+// has no candidates). Otherwise resets `*best` and fills in its split and its
+// left/right stats from the winning runner, then returns true.
+bool GraphRunnerSplitCollectionOperator::BestSplit(int32 node_id,
+                                                   SplitCandidate* best,
+                                                   int32* depth) const {
+  float min_score = FLT_MAX;
+  int best_index = -1;
+  auto* slot = stats_.at(node_id).get();
+  *depth = slot->depth();
+  for (int i = 0; i < slot->num_splits(); ++i) {
+    // TODO(gilberth): Support uselessness.
+    // at() rather than operator[]: a missing runner must not be silently
+    // inserted as a null pointer and then dereferenced.
+    const auto& runner = runners_.at(UniqueId(node_id, i));
+    const float split_score = runner->SplitScore();
+    if (split_score < min_score) {
+      min_score = split_score;
+      best_index = i;
+    }
+  }
+
+  // This could happen if all the splits are useless.
+  if (best_index < 0) {
+    return false;
+  }
+
+  // Fill in split info and left/right stats to initialize models with.
+  *best = SplitCandidate();
+  const auto& runner = runners_.at(UniqueId(node_id, best_index));
+  runner->GetLeftStats(best->mutable_left_stats());
+  runner->GetRightStats(best->mutable_right_stats());
+  runner->GetSplit(best->mutable_split());
+  return true;
+}
+
+// Feeds `examples` to every candidate runner of `node_id` and then to the
+// slot's own stats.
+//
+// `target` must be a TensorInputTarget; anything else fails the CHECK below.
+// The example ids are copied into a 1-D int32 tensor that is handed to each
+// runner together with the original data and target tensors.
+void GraphRunnerSplitCollectionOperator::AddExample(
+    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
+    const std::vector<int>& examples, int32 node_id) const {
+  // Build input Tensors.
+  int size = examples.size();
+  Tensor examples_t(tensorflow::DT_INT32, TensorShape({size}));
+  auto ex_data = examples_t.flat<int32>();
+  std::copy(examples.begin(), examples.end(), ex_data.data());
+
+  const TensorInputTarget* tensor_target =
+      dynamic_cast<const TensorInputTarget*>(target);
+  CHECK_NOTNULL(tensor_target);
+
+  const Tensor& data_t = input_data->original_tensor();
+  const Tensor& target_t = tensor_target->original_tensor();
+
+  // Add to candidates.
+  auto* slot = stats_.at(node_id).get();
+  for (int i = 0; i < slot->num_splits(); ++i) {
+    // at() rather than operator[]: a missing runner must not be silently
+    // inserted as a null pointer and then dereferenced.
+    const auto& runner = runners_.at(UniqueId(node_id, i));
+    runner->AddExample(data_t, target_t, examples_t);
+  }
+
+  // Update simple weight sums so we know when we're done.
+  for (int example : examples) {
+    slot->AddExample(input_data, target, example);
+  }
+}
+
+// Adds one more candidate split to the slot of `node_id` and creates and
+// initializes the CandidateGraphRunner that backs it.
+//
+// The candidate is an oblique LESS_OR_EQUAL test with threshold 0 over
+// features_per_node_ features drawn with RandomSample for `example`. If any
+// sampled feature is not of type kDataFloat, an error is logged and nothing
+// is added.
+void GraphRunnerSplitCollectionOperator::
+    CreateAndInitializeCandidateWithExample(
+        const std::unique_ptr<TensorDataSet>& input_data, int example,
+        int32 node_id) const {
+  auto* slot = stats_.at(node_id).get();
+  // The new candidate takes the next free split index of the slot.
+  const int next_split = slot->num_splits();
+  const int64 runner_key = UniqueId(node_id, next_split);
+
+  decision_trees::BinaryNode split;
+
+  decision_trees::InequalityTest* test =
+      split.mutable_inequality_left_child_test();
+  auto* oblique = test->mutable_oblique();
+  for (int feature = 0; feature < features_per_node_; ++feature) {
+    float bias;
+    int type;
+    // This is really just a way to select a list of random features.
+    // Also a way to warn the user that categoricals don't make sense here.
+    input_data->RandomSample(example, oblique->add_features(), &bias, &type);
+
+    if (type != kDataFloat) {
+      LOG(ERROR) << "Categorical features not supported with this system.";
+      return;
+    }
+
+    test->set_type(decision_trees::InequalityTest::LESS_OR_EQUAL);
+
+    // The comparison bias is assumed to be zero.
+    test->mutable_threshold()->set_float_value(0);
+  }
+
+  slot->AddSplit(split);
+
+  auto& runner = runners_[runner_key];
+  runner.reset(new CandidateGraphRunner(graph_dir_, split));
+  runner->Init();
+}
+
+// Clears the base-class slot for `node_id`, then drops the runner stored for
+// each of its num_splits_to_consider_ possible split ids.
+void GraphRunnerSplitCollectionOperator::ClearSlot(int32 node_id) {
+  SplitCollectionOperator::ClearSlot(node_id);
+  for (int split_id = 0; split_id < num_splits_to_consider_; ++split_id) {
+    const int64 runner_key = UniqueId(node_id, split_id);
+    runners_.erase(runner_key);
+  }
+}
+
+} // namespace tensorforest
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h
new file mode 100644
index 0000000000..9b18e3e969
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h
@@ -0,0 +1,81 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
+
+#include <vector>
+#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
+#include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h"
+#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
+#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+// Holds split candidates that are trained by running any TF graph.
+//
+// Every candidate split of a node is backed by a CandidateGraphRunner kept in
+// runners_ under the key UniqueId(node_id, split_id).
+class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator {
+ public:
+  // Reads num_splits_to_consider from `params`. Aborts via LOG(FATAL) when
+  // that parameter has no value set at all.
+  explicit GraphRunnerSplitCollectionOperator(const TensorForestParams& params)
+      : SplitCollectionOperator(params) {
+    if (params.num_splits_to_consider().ParamType_case() ==
+        DepthDependentParam::PARAMTYPE_NOT_SET) {
+      LOG(FATAL) << "GRAPH_RUNNER_COLLECTION must specify a constant value for "
+                 << " num_splits_to_consider";
+    } else {
+      num_splits_to_consider_ =
+          params.num_splits_to_consider().constant_value();
+    }
+  }
+
+  std::unique_ptr<GrowStats> CreateGrowStats(int32 node_id,
+                                             int32 depth) const override;
+
+  // Updates the slot's candidates with the new example.
+  // Assumes slot has been initialized.
+  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
+                  const InputTarget* target, const std::vector<int>& examples,
+                  int32 node_id) const override;
+
+  // Create a new candidate and initialize it with the given example.
+  void CreateAndInitializeCandidateWithExample(
+      const std::unique_ptr<TensorDataSet>& input_data, int example,
+      int32 node_id) const override;
+
+  bool BestSplit(int32 node_id, SplitCandidate* best,
+                 int32* depth) const override;
+
+  void ClearSlot(int32 node_id) override;
+
+ protected:
+  // Key into runners_ for candidate `split_id` of node `node_id`.
+  int64 UniqueId(int32 node_id, int32 split_id) const;
+
+  // Mutable because the const AddExample/Create... overrides update it.
+  mutable std::unordered_map<int64, std::unique_ptr<CandidateGraphRunner>>
+      runners_;
+  // Number of features sampled for each new candidate. Nothing in this class
+  // assigns it, so give it a defined value instead of leaving it
+  // indeterminate when it is read as a loop bound.
+  // NOTE(review): presumably this should be populated from `params` — confirm.
+  int features_per_node_ = 0;
+  // Passed to every CandidateGraphRunner. Not assigned by this class, so it
+  // is the empty string unless a subclass sets it.
+  string graph_dir_;
+  // Must have a constant value because of how we make unique ids right now.
+  int32 num_splits_to_consider_;
+};
+
+} // namespace tensorforest
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
index ddf4be8799..c207c0859d 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
@@ -23,22 +23,20 @@
namespace tensorflow {
namespace tensorforest {
+// Registry of creators keyed by collection type. It is defined before the
+// registration below on purpose: the registered creator's constructor writes
+// into this map, so the map has to be constructed first within this
+// translation unit.
+std::unordered_map<int, CollectionCreator*>
+ SplitCollectionOperatorFactory::factories_; // NOLINT
+// Registers the base SplitCollectionOperator under COLLECTION_BASIC.
+REGISTER_SPLIT_COLLECTION(COLLECTION_BASIC, SplitCollectionOperator);
+
+// Creates the SplitCollectionOperator registered for
+// params.collection_type(). Logs an error and returns nullptr when no creator
+// has been registered in factories_ for that type.
std::unique_ptr<SplitCollectionOperator>
SplitCollectionOperatorFactory::CreateSplitCollectionOperator(
const TensorForestParams& params) {
- switch (params.collection_type()) {
- case COLLECTION_BASIC:
- return std::unique_ptr<SplitCollectionOperator>(
- new SplitCollectionOperator(params));
-
- case GRAPH_RUNNER_COLLECTION:
- return std::unique_ptr<SplitCollectionOperator>(
- new GraphRunnerSplitCollectionOperator(params));
-
- default:
- LOG(ERROR) << "Unknown split collection operator: "
- << params.collection_type();
- return nullptr;
+ auto it = factories_.find(params.collection_type());
+ if (it == factories_.end()) {
+ LOG(ERROR) << "Unknown split collection operator: "
+ << params.collection_type();
+ return nullptr;
+ } else {
+ return it->second->Create(params);
}
}
@@ -137,121 +135,5 @@ bool SplitCollectionOperator::BestSplit(int32 node_id,
return slot->BestSplit(best);
}
-// -------------------------------- GraphRunner ------------------ //
-
-std::unique_ptr<GrowStats> GraphRunnerSplitCollectionOperator::CreateGrowStats(
- int32 node_id, int32 depth) const {
- return std::unique_ptr<GrowStats>(new SimpleStats(params_, depth));
-}
-
-int64 GraphRunnerSplitCollectionOperator::UniqueId(int32 node_id,
- int32 split_id) const {
- return node_id * num_splits_to_consider_ + split_id;
-}
-
-bool GraphRunnerSplitCollectionOperator::BestSplit(int32 node_id,
- SplitCandidate* best,
- int32* depth) const {
- float min_score = FLT_MAX;
- int best_index = -1;
- auto* slot = stats_.at(node_id).get();
- *depth = slot->depth();
- for (int i = 0; i < slot->num_splits(); ++i) {
- // TODO(gilberth): Support uselessness.
- auto& runner = runners_[UniqueId(node_id, i)];
- const float split_score = runner->SplitScore();
- if (split_score < min_score) {
- min_score = split_score;
- best_index = i;
- }
- }
-
- // This could happen if all the splits are useless.
- if (best_index < 0) {
- return false;
- }
-
- // Fill in split info and left/right stats to initialize models with.
- *best = SplitCandidate();
- auto& runner = runners_[UniqueId(node_id, best_index)];
- runner->GetLeftStats(best->mutable_left_stats());
- runner->GetRightStats(best->mutable_right_stats());
- runner->GetSplit(best->mutable_split());
- return true;
-}
-
-void GraphRunnerSplitCollectionOperator::AddExample(
- const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
- const std::vector<int>& examples, int32 node_id) const {
- // Build input Tensors.
- int size = examples.size();
- Tensor examples_t(tensorflow::DT_INT32, TensorShape({size}));
- auto ex_data = examples_t.flat<int32>();
- std::copy(examples.begin(), examples.end(), ex_data.data());
-
- const TensorInputTarget* tensor_target =
- dynamic_cast<const TensorInputTarget*>(target);
- CHECK_NOTNULL(tensor_target);
-
- const Tensor& data_t = input_data->original_tensor();
- const Tensor& target_t = tensor_target->original_tensor();
-
- // Add to candidates.
- auto* slot = stats_.at(node_id).get();
- for (int i = 0; i < slot->num_splits(); ++i) {
- auto& runner = runners_[UniqueId(node_id, i)];
- runner->AddExample(data_t, target_t, examples_t);
- }
-
- // Update simple weight sums so we know when we're done.
- for (int example : examples) {
- slot->AddExample(input_data, target, example);
- }
-}
-
-void GraphRunnerSplitCollectionOperator::
- CreateAndInitializeCandidateWithExample(
- const std::unique_ptr<TensorDataSet>& input_data, int example,
- int32 node_id) const {
- auto* slot = stats_.at(node_id).get();
- int cand_num = slot->num_splits();
- const int64 unique_id = UniqueId(node_id, cand_num);
-
- decision_trees::BinaryNode split;
-
- decision_trees::InequalityTest* test =
- split.mutable_inequality_left_child_test();
- auto* oblique = test->mutable_oblique();
- for (int i = 0; i < features_per_node_; ++i) {
- float bias;
- int type;
- // This is really just a way to select a list of random features.
- // Also a way to warn the user that categoricals don't make sense here.
- input_data->RandomSample(example, oblique->add_features(), &bias, &type);
-
- if (type == kDataFloat) {
- test->set_type(decision_trees::InequalityTest::LESS_OR_EQUAL);
-
- // The comparison bias is assumed to be zero.
- test->mutable_threshold()->set_float_value(0);
- } else {
- LOG(ERROR) << "Categorical features not supported with this system.";
- return;
- }
- }
-
- slot->AddSplit(split);
-
- runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split));
- runners_[unique_id]->Init();
-}
-
-void GraphRunnerSplitCollectionOperator::ClearSlot(int32 node_id) {
- SplitCollectionOperator::ClearSlot(node_id);
- for (int i = 0; i < num_splits_to_consider_; ++i) {
- runners_.erase(UniqueId(node_id, i));
- }
-}
-
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
index d0ea33612a..81d820a6b2 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
@@ -17,7 +17,6 @@
#include <vector>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
-#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
@@ -86,60 +85,38 @@ class SplitCollectionOperator {
std::unordered_map<int32, std::unique_ptr<GrowStats>> stats_;
};
-
-class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator {
+class CollectionCreator {
public:
- explicit GraphRunnerSplitCollectionOperator(const TensorForestParams& params)
- : SplitCollectionOperator(params) {
- if (params.num_splits_to_consider().ParamType_case() ==
- DepthDependentParam::PARAMTYPE_NOT_SET) {
- LOG(FATAL) << "GRAPH_RUNNER_COLLECTION must specify a constant value for "
- << " num_splits_to_consider";
- } else {
- num_splits_to_consider_ =
- params.num_splits_to_consider().constant_value();
- }
- }
-
- std::unique_ptr<GrowStats> CreateGrowStats(int32 node_id,
- int32 depth) const override;
-
- // Updates the slot's candidates with the new example.
- // Assumes slot has been initialized.
- void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
- const InputTarget* target, const std::vector<int>& examples,
- int32 node_id) const override;
-
- // Create a new candidate and initialize it with the given example.
- void CreateAndInitializeCandidateWithExample(
- const std::unique_ptr<TensorDataSet>& input_data, int example,
- int32 node_id) const override;
-
- bool BestSplit(int32 node_id, SplitCandidate* best,
- int32* depth) const override;
-
- void ClearSlot(int32 node_id) override;
-
- protected:
- int64 UniqueId(int32 node_id, int32 split_id) const;
-
- mutable std::unordered_map<int64, std::unique_ptr<CandidateGraphRunner>>
- runners_;
- int features_per_node_;
- string graph_dir_;
- // Must have a constant value because of how we make unique ids right now.
- int32 num_splits_to_consider_;
+ virtual std::unique_ptr<SplitCollectionOperator> Create(
+ const TensorForestParams& params) = 0;
+ virtual ~CollectionCreator() {}
};
-// Creates a type of SplitCollectionOperator depending on the type passed,
-// which is SplitCollectionType in fertile_stats.proto.
-// Can create a SplitCollectionOperator itself, known as "basic".
+// Creates a SplitCollectionOperator for the collection type named in the
+// params by delegating to the CollectionCreator registered for that type.
class SplitCollectionOperatorFactory {
public:
static std::unique_ptr<SplitCollectionOperator> CreateSplitCollectionOperator(
const TensorForestParams& params);
+
+ // Registered creators, keyed by collection type. The pointers are stored
+ // as-is; nothing in this class deletes them.
+ static std::unordered_map<int, CollectionCreator*> factories_;
+};
+
+// CollectionCreator that registers itself in
+// SplitCollectionOperatorFactory::factories_ on construction and creates
+// instances of T, which must be constructible from a TensorForestParams.
+template <typename T>
+class AnyCollectionCreator : public CollectionCreator {
+ public:
+ // Stores a pointer to this creator in the factory registry under `type`.
+ // The registry keeps a raw pointer, so this object must outlive its use.
+ explicit AnyCollectionCreator(SplitCollectionType type) {
+ SplitCollectionOperatorFactory::factories_[type] = this;
+ }
+ // Returns a newly allocated T constructed from `params`.
+ std::unique_ptr<SplitCollectionOperator> Create(
+ const TensorForestParams& params) override {
+ return std::unique_ptr<SplitCollectionOperator>(new T(params));
+ }
};
+// Defines an AnyCollectionCreator<cls> named `creator` in an anonymous
+// namespace, constructed with `name`. The variable name is fixed, so using the
+// macro twice in one translation unit defines `creator` twice.
+// NOTE(review): the constructor writes to the static factories_ map during
+// static initialization — confirm the map is constructed before any creator.
+#define REGISTER_SPLIT_COLLECTION(name, cls) \
+ namespace { \
+ AnyCollectionCreator<cls> creator(name); \
+ }
+
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
new file mode 100644
index 0000000000..1c3c2153a6
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
@@ -0,0 +1,135 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+
+namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+
+namespace tensorforest {
+
+// Handle op for the tree resource; the Python wrapper in this change calls it
+// as gen_model_ops.decision_tree_resource_handle_op.
+REGISTER_RESOURCE_HANDLE_OP(DecisionTreeResource);
+
+// Takes only the tree handle and has a single bool output, shaped by
+// ScalarShape. The op doc below does not describe the individual arguments.
+REGISTER_OP("TreeIsInitializedOp")
+ .Input("tree_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Checks whether a tree has been initialized.
+)doc");
+
+// The handle is an input here and the op declares no outputs, so the op doc
+// must not claim that a handle is returned.
+REGISTER_OP("CreateTreeVariable")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("tree_config: string")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Creates a tree model for the given resource handle.
+
+params: A serialized TensorForestParams proto.
+tree_handle: handle to the tree resource to be created.
+tree_config: Serialized proto of the tree.
+)doc");
+
+// Unlike TreeDeserialize, this op takes no `params` attr. Its single string
+// output is shaped by ScalarShape.
+REGISTER_OP("TreeSerialize")
+ .Input("tree_handle: resource")
+ .Output("tree_config: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Serializes the tree to a proto.
+
+tree_handle: The handle to the tree.
+tree_config: Serialized proto of the tree.
+)doc");
+
+// Counterpart of TreeSerialize: consumes the serialized tree as an input and
+// declares no outputs.
+REGISTER_OP("TreeDeserialize")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("tree_config: string")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Deserializes a serialized tree config and replaces current tree.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+tree_config: Serialized proto of the tree.
+)doc");
+
+// Takes only the tree handle; the int32 output is shaped by ScalarShape.
+REGISTER_OP("TreeSize")
+ .Input("tree_handle: resource")
+ .Output("tree_size: int32")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Outputs the size of the tree, including leaves.
+
+tree_handle: The handle to the tree.
+tree_size: Size scalar.
+)doc");
+
+// NOTE(review): the `input_spec` attr is not described in the op doc below.
+REGISTER_OP("TreePredictionsV4")
+ .Attr("input_spec: string")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("input_data: float")
+ .Input("sparse_input_indices: int64")
+ .Input("sparse_input_values: float")
+ .Input("sparse_input_shape: int64")
+ .Output("predictions: float")
+ .SetShapeFn([](InferenceContext* c) {
+ // The row count stays unknown unless it can be read from input_data.
+ DimensionHandle num_points = c->UnknownDim();
+
+ // Input 1 is input_data; use its first dimension when its rank is known
+ // and greater than zero.
+ if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0) {
+ num_points = c->Dim(c->input(1), 0);
+ }
+
+ // Output is a matrix with num_points rows and an unknown column count.
+ c->set_output(0, c->Matrix(num_points, c->UnknownDim()));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the predictions for the given input data.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+ gives the j-th feature of the i-th input.
+sparse_input_indices: The indices tensor from the SparseTensor input.
+sparse_input_values: The values tensor from the SparseTensor input.
+sparse_input_shape: The shape tensor from the SparseTensor input.
+predictions: `predictions[i][j]` is the probability that input i is class j.
+)doc");
+
+// Output length is not known at graph-construction time.
+REGISTER_OP("FeatureUsageCounts")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Output("feature_counts: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ // 1-d output with an unknown number of elements.
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the number of times each feature was used in a split.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+feature_counts: `feature_counts[i]` is the number of times feature i was used
+ in a split.
+)doc");
+
+} // namespace tensorforest
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
new file mode 100644
index 0000000000..48e91e3466
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
@@ -0,0 +1,146 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+using shape_inference::InferenceContext;
+
+namespace tensorforest {
+
+
+// Handle op for the stats resource; the Python wrapper in this change calls it
+// as gen_stats_ops.fertile_stats_resource_handle_op.
+REGISTER_RESOURCE_HANDLE_OP(FertileStatsResource);
+
+// Takes only the stats handle and has a single bool output, shaped by
+// ScalarShape. The op doc below does not describe the individual arguments.
+REGISTER_OP("FertileStatsIsInitializedOp")
+ .Input("stats_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Checks whether a stats has been initialized.
+)doc");
+
+// The handle is an input here and the op declares no outputs, so the op doc
+// must not claim that a handle is returned.
+REGISTER_OP("CreateFertileStatsVariable")
+ .Attr("params: string")
+ .Input("stats_handle: resource")
+ .Input("stats_config: string")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Creates a stats model for the given resource handle.
+
+params: A serialized TensorForestParams proto.
+stats_handle: handle to the stats resource to be created.
+stats_config: Serialized proto of the stats.
+)doc");
+
+// Unlike TreeSerialize in model_ops.cc, this op also takes the `params` attr.
+// Its single string output is shaped by ScalarShape.
+REGISTER_OP("FertileStatsSerialize")
+ .Attr("params: string")
+ .Input("stats_handle: resource")
+ .Output("stats_config: string")
+ .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+ .Doc(R"doc(
+Serializes the stats to a proto.
+
+params: A serialized TensorForestParams proto.
+stats_handle: The handle to the stats.
+stats_config: Serialized proto of the stats.
+)doc");
+
+// Counterpart of FertileStatsSerialize: consumes the serialized stats as an
+// input and declares no outputs.
+REGISTER_OP("FertileStatsDeserialize")
+ .Attr("params: string")
+ .Input("stats_handle: resource")
+ .Input("stats_config: string")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Deserializes a serialized stats config and replaces current stats.
+
+params: A serialized TensorForestParams proto.
+stats_handle: The handle to the stats.
+stats_config: Serialized proto of the stats.
+)doc");
+
+// NOTE(review): the third input is spelled `finshed_nodes` (sic), while
+// ProcessInputV4 names its output `finished_nodes`. The name is part of the
+// op's interface, so it is left unchanged here — confirm before renaming.
+REGISTER_OP("GrowTreeV4")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("stats_handle: resource")
+ .Input("finshed_nodes: int32")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Grows the tree for finished nodes and allocates waiting nodes.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+stats_handle: The handle to the stats.
+finshed_nodes: A 1-d Tensor of finished node ids from ProcessInput.
+)doc");
+
+// NOTE(review): the `random_seed` and `input_spec` attrs are not described in
+// the op doc below.
+REGISTER_OP("ProcessInputV4")
+ .Attr("random_seed: int")
+ .Attr("input_spec: string")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("stats_handle: resource")
+ .Input("input_data: float")
+ .Input("sparse_input_indices: int64")
+ .Input("sparse_input_values: float")
+ .Input("sparse_input_shape: int64")
+ .Input("input_labels: float")
+ .Input("input_weights: float")
+ .Output("finished_nodes: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ // 1-d output with an unknown number of elements.
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Add labels to stats after traversing the tree for each example.
+
+Outputs node ids that are finished.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+stats_handle: The handle to the stats.
+input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+ gives the j-th feature of the i-th input.
+sparse_input_indices: The indices tensor from the SparseTensor input.
+sparse_input_values: The values tensor from the SparseTensor input.
+sparse_input_shape: The shape tensor from the SparseTensor input.
+input_labels: The training batch's labels as a 1 or 2-d tensor.
+ 'input_labels[i][j]' gives the j-th label/target for the i-th input.
+input_weights: The training batch's example weights as a 1-d tensor.
+ 'input_weights[i]' gives the weight for the i-th input.
+finished_nodes: A 1-d tensor of node ids that have finished and are ready to
+ grow.
+)doc");
+
+// The op declares no outputs, so it uses the shared NoOutputs shape function
+// like the other output-less ops in this file instead of a hand-written
+// lambda that only returns OK.
+REGISTER_OP("FinalizeTree")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("stats_handle: resource")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Puts the Leaf models inside the tree into their final form.
+
+If drop_final_class is true, the per-class probability prediction of the
+last class is not stored in the leaf models.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+stats_handle: The handle to the stats.
+)doc");
+} // namespace tensorforest
+} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/python/__init__.py b/tensorflow/contrib/tensor_forest/python/__init__.py
index 0d41d3500d..b2ca89877a 100644
--- a/tensorflow/contrib/tensor_forest/python/__init__.py
+++ b/tensorflow/contrib/tensor_forest/python/__init__.py
@@ -21,4 +21,6 @@ from __future__ import print_function
from tensorflow.contrib.tensor_forest.python import constants
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.contrib.tensor_forest.python.ops import data_ops
+from tensorflow.contrib.tensor_forest.python.ops import model_ops
+from tensorflow.contrib.tensor_forest.python.ops import stats_ops
from tensorflow.contrib.tensor_forest.python.ops import tensor_forest_ops
diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
new file mode 100644
index 0000000000..4c7218305b
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Model ops python wrappers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops
+from tensorflow.contrib.tensor_forest.python.ops import stats_ops
+
+# pylint: disable=unused-import
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
+# pylint: enable=unused-import
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.training import saver
+
+
+# Load the `_model_ops.so` op library; its path is resolved by resource_loader.
+_model_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_model_ops.so"))
+
+
+# Mark these ops as not differentiable.
+# NOTE(review): "TreeVariable" does not match any REGISTER_OP name in
+# ops/model_ops.cc from this change — confirm the entry is intended.
+ops.NotDifferentiable("TreeVariable")
+ops.NotDifferentiable("TreeSerialize")
+ops.NotDifferentiable("TreeDeserialize")
+ops.NotDifferentiable("TreeSize")
+ops.NotDifferentiable("TreePredictionsV4")
+ops.NotDifferentiable("FeatureUsageCounts")
+
+
+class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for TreeVariable.
+
+ Saving serializes the tree with `tree_serialize`; restoring passes the saved
+ tensor to `tree_deserialize` after the create op has run.
+ """
+
+ def __init__(self, params, tree_handle, stats_handle, create_op, name):
+ """Creates a TreeVariableSavable object.
+
+ Args:
+ params: A TensorForestParams object; its `serialized_params_proto`
+ attribute is passed to the ops.
+ tree_handle: handle to the tree variable.
+ stats_handle: handle to the stats variable, or None. When not None,
+ `finalize_tree` is added as a control dependency of serialization.
+ create_op: the op to initialize the variable.
+ name: the name to save the tree variable under.
+ """
+ self.params = params
+ # Ops that must run before the tree is serialized.
+ deps = []
+ if stats_handle is not None:
+ deps.append(stats_ops.finalize_tree(
+ tree_handle, stats_handle,
+ params=params.serialized_params_proto))
+ with ops.control_dependencies(deps):
+ tensor = gen_model_ops.tree_serialize(tree_handle)
+ # slice_spec is useful for saving a slice from a variable.
+ # It's not meaningful for the tree variable, so we just pass an empty value.
+ slice_spec = ""
+ specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name),]
+ super(TreeVariableSavable,
+ self).__init__(tree_handle, specs, name)
+ self._tree_handle = tree_handle
+ self._create_op = create_op
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ """Restores the associated tree from 'restored_tensors'.
+
+ Args:
+ restored_tensors: the tensors that were loaded from a checkpoint. Only
+ the first element is used.
+ unused_restored_shapes: the shapes this object should conform to after
+ restore. Not meaningful for trees.
+
+ Returns:
+ The operation that restores the state of the tree variable.
+ """
+ # The create op must run before the tree is replaced.
+ with ops.control_dependencies([self._create_op]):
+ return gen_model_ops.tree_deserialize(
+ self._tree_handle,
+ restored_tensors[0],
+ params=self.params.serialized_params_proto)
+
+
+def tree_variable(params, tree_config, stats_handle, name, container=None):
+ r"""Creates a tree model and returns a handle to it.
+
+ Also adds a `TreeVariableSavable` for the tree to the SAVEABLE_OBJECTS
+ collection and registers the handle with `resources.register_resource`.
+
+ Args:
+ params: A TensorForestParams object; its `serialized_params_proto`
+ attribute is passed to the ops.
+ tree_config: A `Tensor` of type `string`. Serialized proto of the tree.
+ stats_handle: Resource handle to the stats object. May be None, in which
+ case the tree is saved without running `finalize_tree` first.
+ name: A name for the variable.
+ container: An optional `string`. Defaults to `""`.
+
+ Returns:
+ The resource handle to the tree.
+ """
+ # `name` is rebound to the value yielded by name_scope and is reused for the
+ # handle op and the checkpoint key below.
+ with ops.name_scope(name, "TreeVariable") as name:
+ resource_handle = gen_model_ops.decision_tree_resource_handle_op(
+ container, name, name=name)
+
+ create_op = gen_model_ops.create_tree_variable(
+ resource_handle,
+ tree_config,
+ params=params.serialized_params_proto)
+ is_initialized_op = gen_model_ops.tree_is_initialized_op(resource_handle)
+ # Adds the variable to the savable list.
+ saveable = TreeVariableSavable(params, resource_handle, stats_handle,
+ create_op,
+ "tree_checkpoint_{0}".format(name))
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ resources.register_resource(resource_handle, create_op, is_initialized_op)
+ return resource_handle
diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py
new file mode 100644
index 0000000000..be9f2e12b7
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py
@@ -0,0 +1,114 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stats ops python wrappers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.tensor_forest.python.ops import gen_stats_ops
+# pylint: disable=unused-import
+from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import finalize_tree
+from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import grow_tree_v4
+from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import process_input_v4
+# pylint: enable=unused-import
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import resources
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.training import saver
+
+
+# Load the `_stats_ops.so` op library; its path is resolved by resource_loader.
+_stats_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_stats_ops.so"))
+
+
+# Mark these ops as not differentiable.
+# NOTE(review): "FertileStatsVariable" does not match any REGISTER_OP name in
+# ops/stats_ops.cc from this change — confirm the entry is intended.
+ops.NotDifferentiable("FertileStatsVariable")
+ops.NotDifferentiable("FertileStatsSerialize")
+ops.NotDifferentiable("FertileStatsDeserialize")
+ops.NotDifferentiable("GrowTreeV4")
+ops.NotDifferentiable("ProcessInputV4")
+ops.NotDifferentiable("FinalizeTree")
+
+
+class FertileStatsVariableSavable(saver.BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for FertileStatsVariable.
+
+ Saving serializes the stats with `fertile_stats_serialize`; restoring passes
+ the saved tensor to `fertile_stats_deserialize` after the create op has run.
+ """
+
+ def __init__(self, params, stats_handle, create_op, name):
+ """Creates a FertileStatsVariableSavable object.
+
+ Args:
+ params: A TensorForestParams object; its `serialized_params_proto`
+ attribute is passed to the ops.
+ stats_handle: handle to the stats variable.
+ create_op: the op to initialize the variable.
+ name: the name to save the stats variable under.
+ """
+ self.params = params
+ tensor = gen_stats_ops.fertile_stats_serialize(
+ stats_handle, params=params.serialized_params_proto)
+ # slice_spec is useful for saving a slice from a variable.
+ # It's not meaningful for the stats variable, so we just pass an empty value.
+ slice_spec = ""
+ specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name),]
+ super(FertileStatsVariableSavable,
+ self).__init__(stats_handle, specs, name)
+ self._stats_handle = stats_handle
+ self._create_op = create_op
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ """Restores the associated stats from 'restored_tensors'.
+
+ Args:
+ restored_tensors: the tensors that were loaded from a checkpoint. Only
+ the first element is used.
+ unused_restored_shapes: the shapes this object should conform to after
+ restore. Not meaningful for stats.
+
+ Returns:
+ The operation that restores the state of the stats variable.
+ """
+ # The create op must run before the stats are replaced.
+ with ops.control_dependencies([self._create_op]):
+ return gen_stats_ops.fertile_stats_deserialize(
+ self._stats_handle, restored_tensors[0],
+ params=self.params.serialized_params_proto)
+
+
+def fertile_stats_variable(params, stats_config, name,
+ container=None):
+ r"""Creates a stats object and returns a handle to it.
+
+ Also adds a `FertileStatsVariableSavable` for the stats to the
+ SAVEABLE_OBJECTS collection and registers the handle with
+ `resources.register_resource`.
+
+ Args:
+ params: A TensorForestParams object; its `serialized_params_proto`
+ attribute is passed to the ops.
+ stats_config: A `Tensor` of type `string`. Serialized proto of the stats.
+ name: A name for the variable.
+ container: An optional `string`. Defaults to `""`.
+
+ Returns:
+ The resource handle to the stats.
+ """
+ # `name` is rebound to the value yielded by name_scope and is reused for the
+ # handle op and the checkpoint key below.
+ with ops.name_scope(name, "FertileStatsVariable") as name:
+ resource_handle = gen_stats_ops.fertile_stats_resource_handle_op(
+ container, name, name=name)
+
+ create_op = gen_stats_ops.create_fertile_stats_variable(
+ resource_handle, stats_config,
+ params=params.serialized_params_proto)
+ is_initialized_op = gen_stats_ops.fertile_stats_is_initialized_op(
+ resource_handle)
+ # Adds the variable to the savable list.
+ saveable = FertileStatsVariableSavable(params, resource_handle, create_op,
+ "stats_checkpoint_{0}".format(name))
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ resources.register_resource(resource_handle, create_op, is_initialized_op)
+ return resource_handle