diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-21 11:25:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-21 11:29:41 -0700 |
commit | 8e40f1adcbc94c2c21dccaa557604a917fd86f22 (patch) | |
tree | f238bf35900465c0b7c668c0f5b492f7b12590e3 | |
parent | 41f3f76970726fe4ec2cd9e485a04e6f072a3bce (diff) |
Migrate ops for new version of TensorForest.
PiperOrigin-RevId: 159718610
16 files changed, 1887 insertions, 258 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index f852eded1e..dfcdf3991c 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -93,6 +93,8 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", + "//tensorflow/contrib/tensor_forest:model_ops_kernels", + "//tensorflow/contrib/tensor_forest:stats_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", ], @@ -110,6 +112,8 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", + "//tensorflow/contrib/tensor_forest:model_ops_op_lib", + "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index f1b01250a4..a9defc1139 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -85,6 +85,8 @@ GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_model "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/model_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/stats_ops.cc") 
GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}") GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 84a1302344..d17fcf6456 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -659,6 +659,10 @@ GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_tensor_forest_ops.py) GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_hybrid_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/hybrid/ops/gen_training_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_model_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_model_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_tensor_forest_stats_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/tensor_forest/python/ops/gen_stats_ops.py) GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 7b5f9472e7..1ca2d7596b 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -28,37 +28,35 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +# ---------------------------------- V2 ops ------------------------------------------# filegroup( - name = "custom_op_sources", - srcs = glob( - [ - "kernels/*.cc", - "ops/*.cc", - ], - exclude = [ - "kernels/*_test.cc", - "kernels/tree_utils.cc", - ], - ), + name = "v2_op_sources", + srcs = [ + 
"kernels/best_splits_op.cc", + "kernels/count_extremely_random_stats_op.cc", + "kernels/finished_nodes_op.cc", + "kernels/grow_tree_op.cc", + "kernels/reinterpret_string_to_float_op.cc", + "kernels/sample_inputs_op.cc", + "kernels/scatter_add_ndim_op.cc", + "kernels/tree_predictions_op.cc", + "kernels/update_fertile_slots_op.cc", + ], ) filegroup( - name = "custom_op_headers", - srcs = glob( - [ - "kernels/*.h", - ], - exclude = [ - "kernels/data_spec.h", - "kernels/tree_utils.h", - ], - ), + name = "v2_op_defs", + srcs = [ + "ops/tensor_forest_ops.cc", + ], ) cc_library( - name = "all_ops", - srcs = [":custom_op_sources"], - hdrs = [":custom_op_headers"], + name = "v2_ops", + srcs = [ + ":v2_op_defs", + ":v2_op_sources", + ], deps = [ ":tree_utils", "//tensorflow/core:framework_headers_lib", @@ -105,16 +103,8 @@ tf_gen_op_wrapper_py( tf_custom_op_library( name = "python/ops/_tensor_forest_ops.so", srcs = [ - "kernels/best_splits_op.cc", - "kernels/count_extremely_random_stats_op.cc", - "kernels/finished_nodes_op.cc", - "kernels/grow_tree_op.cc", - "kernels/reinterpret_string_to_float_op.cc", - "kernels/sample_inputs_op.cc", - "kernels/scatter_add_ndim_op.cc", - "kernels/tree_predictions_op.cc", - "kernels/update_fertile_slots_op.cc", - "ops/tensor_forest_ops.cc", + ":v2_op_defs", + ":v2_op_sources", ], deps = [":tree_utils"], ) @@ -131,7 +121,9 @@ py_library( ":constants", ":data_ops_py", ":eval_metrics", + ":model_ops_py", ":random_forest", + ":stats_ops_py", ":tensor_forest_ops_py", ":tensor_forest_py", ], @@ -140,21 +132,11 @@ py_library( tf_kernel_library( name = "tensor_forest_kernels", srcs = [ - "kernels/best_splits_op.cc", - "kernels/count_extremely_random_stats_op.cc", - "kernels/finished_nodes_op.cc", - "kernels/grow_tree_op.cc", - "kernels/reinterpret_string_to_float_op.cc", - "kernels/sample_inputs_op.cc", - "kernels/scatter_add_ndim_op.cc", - "kernels/tree_predictions_op.cc", - "kernels/update_fertile_slots_op.cc", + ":v2_op_sources", ], deps = [ 
":tree_utils", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + "//tensorflow/core:framework_headers_lib", "//tensorflow/core/kernels:bounds_check", ], ) @@ -181,6 +163,192 @@ tf_custom_op_py_library( ], ) +cc_test( + name = "tensor_forest_ops_test", + size = "small", + srcs = [ + "kernels/tensor_forest_ops_test.cc", + ":v2_op_defs", + ":v2_op_sources", + ], + deps = [ + ":tree_utils", + "//tensorflow/core", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//third_party/eigen3", + ], +) + +# -------------------------------------- V4 ops ------------------------------- # +cc_library( + name = "tensor_forest_v4_kernels", + deps = [ + ":model_ops_kernels", + ":stats_ops_kernels", + ], +) + +cc_library( + name = "tensor_forest_v4_ops_op_lib", + deps = [ + ":model_ops_op_lib", + ":stats_ops_op_lib", + ], +) + +py_library( + name = "tensor_forest_v4_ops_py", + srcs_version = "PY2AND3", + deps = [ + ":model_ops_py", + ":stats_ops_py", + ], +) + +# Model Ops. 
+cc_library( + name = "model_ops_lib", + srcs = ["kernels/model_ops.cc"], + deps = [ + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", + "//tensorflow/contrib/tensor_forest:tree_utils", + "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource", + "//tensorflow/contrib/tensor_forest/kernels/v4:input_data", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +tf_gen_op_libs( + op_lib_names = ["model_ops"], +) + +tf_gen_op_wrapper_py( + name = "gen_model_ops_py", + out = "python/ops/gen_model_ops.py", + deps = [":model_ops_op_lib"], +) + +tf_kernel_library( + name = "model_ops_kernels", + deps = [ + ":model_ops_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], + alwayslink = 1, +) + +tf_custom_op_library( + name = "python/ops/_model_ops.so", + deps = [ + ":model_ops_lib", + ], +) + +tf_custom_op_py_library( + name = "model_ops_py", + srcs = ["python/ops/model_ops.py"], + dso = ["python/ops/_model_ops.so"], + kernels = [ + ":model_ops_kernels", + ":model_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_model_ops_py", + ":stats_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +# Stats Ops. 
+cc_library( + name = "stats_ops_lib", + srcs = ["kernels/stats_ops.cc"], + deps = [ + "//tensorflow/contrib/tensor_forest:tree_utils", + "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource", + "//tensorflow/contrib/tensor_forest/kernels/v4:fertile-stats-resource", + "//tensorflow/contrib/tensor_forest/kernels/v4:input_data", + "//tensorflow/contrib/tensor_forest/kernels/v4:input_target", + "//tensorflow/contrib/tensor_forest/kernels/v4:params", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + +tf_gen_op_libs( + op_lib_names = ["stats_ops"], +) + +tf_gen_op_wrapper_py( + name = "gen_stats_ops_py", + out = "python/ops/gen_stats_ops.py", + deps = [":stats_ops_op_lib"], +) + +tf_kernel_library( + name = "stats_ops_kernels", + deps = [ + ":stats_ops_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], + alwayslink = 1, +) + +tf_custom_op_library( + name = "python/ops/_stats_ops.so", + deps = [ + ":stats_ops_lib", + ], +) + +tf_custom_op_py_library( + name = "stats_ops_py", + srcs = ["python/ops/stats_ops.py"], + dso = ["python/ops/_stats_ops.so"], + kernels = [ + ":stats_ops_kernels", + ":stats_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_stats_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +# ---------------------------------- Common libs ------------------------ # +cc_library( + name = "tree_utils", + srcs = ["kernels/tree_utils.cc"], + hdrs = [ + "kernels/data_spec.h", + "kernels/tree_utils.h", + ], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf//:protobuf_headers", + ], +) + +# --------------------------------- Python -------------------------------- # + py_library( name = "eval_metrics", srcs = ["client/eval_metrics.py"], @@ -220,20 +388,6 @@ 
py_library( ], ) -cc_library( - name = "tree_utils", - srcs = ["kernels/tree_utils.cc"], - hdrs = [ - "kernels/data_spec.h", - "kernels/tree_utils.h", - ], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf//:protobuf_headers", - ], -) - py_test( name = "best_splits_op_test", size = "small", @@ -380,25 +534,6 @@ py_test( ], ) -cc_test( - name = "tensor_forest_ops_test", - size = "small", - srcs = [ - "kernels/tensor_forest_ops_test.cc", - ":custom_op_sources", - ], - deps = [ - ":tree_utils", - "//tensorflow/core", - "//tensorflow/core:framework", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//third_party/eigen3", - ], -) - py_library( name = "random_forest", srcs = ["client/random_forest.py"], diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc new file mode 100644 index 0000000000..195221a48e --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -0,0 +1,299 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_handle.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorforest { + +// Creates a tree variable. 
+class CreateTreeVariableOp : public OpKernel { + public: + explicit CreateTreeVariableOp(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + const Tensor* tree_config_t; + OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()), + errors::InvalidArgument("Tree config must be a scalar.")); + + auto* result = new DecisionTreeResource(); + if (!ParseProtoUnlimited(result->mutable_decision_tree(), + tree_config_t->scalar<string>()())) { + result->Unref(); + OP_REQUIRES(context, false, + errors::InvalidArgument("Unable to parse tree config.")); + } + + result->MaybeInitialize(); + + // Only create one, if one does not exist already. Report status for all + // other exceptions. + auto status = CreateResource(context, HandleFromInput(context, 0), result); + if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { + OP_REQUIRES(context, false, status); + } + } + + private: + TensorForestParams param_proto_; +}; + +// Op for serializing a model. 
+class TreeSerializeOp : public OpKernel { + public: + explicit TreeSerializeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + DecisionTreeResource* decision_tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + Tensor* output_config_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape(), &output_config_t)); + output_config_t->scalar<string>()() = + decision_tree_resource->decision_tree().SerializeAsString(); + } +}; + +// Op for deserializing a tree variable from a checkpoint. +class TreeDeserializeOp : public OpKernel { + public: + explicit TreeDeserializeOp(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + DecisionTreeResource* decision_tree_resource; + auto handle = HandleFromInput(context, 0); + OP_REQUIRES_OK(context, LookupResource(context, handle, + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + + const Tensor* tree_config_t; + OP_REQUIRES_OK(context, context->input("tree_config", &tree_config_t)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()), + errors::InvalidArgument("Tree config must be a scalar.")); + // Deallocate all the previous objects on the resource. 
+ decision_tree_resource->Reset(); + decision_trees::Model* config = + decision_tree_resource->mutable_decision_tree(); + OP_REQUIRES(context, + ParseProtoUnlimited(config, tree_config_t->scalar<string>()()), + errors::InvalidArgument("Unable to parse tree config.")); + decision_tree_resource->MaybeInitialize(); + } + + private: + TensorForestParams param_proto_; +}; + +// Op for getting tree size. +class TreeSizeOp : public OpKernel { + public: + explicit TreeSizeOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + DecisionTreeResource* decision_tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + Tensor* output_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape(), &output_t)); + output_t->scalar<int32>()() = + decision_tree_resource->decision_tree().decision_tree().nodes_size(); + } +}; + + +// Op for tree inference. 
+class TreePredictionsV4Op : public OpKernel { + public: + explicit TreePredictionsV4Op(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + + string serialized_proto; + OP_REQUIRES_OK(context, context->GetAttr( + "input_spec", &serialized_proto)); + input_spec_.ParseFromString(serialized_proto); + + data_set_ = + std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0)); + + model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_data = context->input(1); + const Tensor& sparse_input_indices = context->input(2); + const Tensor& sparse_input_values = context->input(3); + + data_set_->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values); + + DecisionTreeResource* decision_tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + + Tensor* output_predictions = nullptr; + TensorShape output_shape; + output_shape.AddDim(data_set_->NumItems()); + output_shape.AddDim(param_proto_.num_outputs()); + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, + &output_predictions)); + + auto out = output_predictions->tensor<float, 2>(); + for (int i = 0; i < data_set_->NumItems(); ++i) { + const int32 leaf_id = + decision_tree_resource->TraverseTree(data_set_, i, nullptr); + const decision_trees::Leaf& leaf = + decision_tree_resource->get_leaf(leaf_id); + for (int j = 0; j < param_proto_.num_outputs(); ++j) { + const float count = model_op_->GetOutputValue(leaf, j); + out(i, j) = count; + } + } + } + + private: + tensorforest::TensorForestDataSpec input_spec_; + std::unique_ptr<TensorDataSet> data_set_; + 
std::unique_ptr<LeafModelOperator> model_op_; + TensorForestParams param_proto_; +}; + +// Op for getting feature usage counts. +class FeatureUsageCountsOp : public OpKernel { + public: + explicit FeatureUsageCountsOp(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + DecisionTreeResource* decision_tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &decision_tree_resource)); + mutex_lock l(*decision_tree_resource->get_mutex()); + core::ScopedUnref unref_me(decision_tree_resource); + + + const auto& tree = decision_tree_resource->decision_tree(); + + Tensor* output_counts = nullptr; + TensorShape output_shape; + output_shape.AddDim(param_proto_.num_features()); + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &output_counts)); + + auto counts = output_counts->unaligned_flat<int32>(); + counts.setZero(); + + for (const auto& node : tree.decision_tree().nodes()) { + if (node.has_custom_node_type()) { + LOG(WARNING) << "Can't count feature usage for custom nodes."; + } else if (node.has_binary_node()) { + const auto& bnode = node.binary_node(); + if (bnode.has_custom_left_child_test()) { + decision_trees::MatchingValuesTest test; + if (!bnode.custom_left_child_test().UnpackTo(&test)) { + LOG(WARNING) << "Unknown custom child test"; + continue; + } + int32 feat; + safe_strto32(test.feature_id().id().value(), &feat); + ++counts(feat); + } else { + const auto& test = bnode.inequality_left_child_test(); + if (test.has_feature_id()) { + int32 feat; + safe_strto32(test.feature_id().id().value(), &feat); + ++counts(feat); + } else if (test.has_oblique()) { + for (const auto& featid : test.oblique().features()) { + int32 feat; + safe_strto32(featid.id().value(), &feat); + ++counts(feat); + } + } 
+ } + } + } + } + + private: + TensorForestParams param_proto_; +}; + + +REGISTER_RESOURCE_HANDLE_KERNEL(DecisionTreeResource); + +REGISTER_KERNEL_BUILDER(Name("TreeIsInitializedOp").Device(DEVICE_CPU), + IsResourceInitialized<DecisionTreeResource>); + +REGISTER_KERNEL_BUILDER(Name("CreateTreeVariable").Device(DEVICE_CPU), + CreateTreeVariableOp); + +REGISTER_KERNEL_BUILDER(Name("TreeSerialize").Device(DEVICE_CPU), + TreeSerializeOp); + +REGISTER_KERNEL_BUILDER(Name("TreeDeserialize").Device(DEVICE_CPU), + TreeDeserializeOp); + +REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), + TreeSizeOp); + +REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU), + TreePredictionsV4Op); + +REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU), + FeatureUsageCountsOp); + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc new file mode 100644 index 0000000000..7442469507 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -0,0 +1,564 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include <queue> + +#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_handle.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +namespace tensorforest { + +using gtl::FindOrNull; + +// Creates a stats variable. 
+class CreateFertileStatsVariableOp : public OpKernel { + public: + explicit CreateFertileStatsVariableOp(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + const Tensor* stats_config_t; + OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()), + errors::InvalidArgument("Stats config must be a scalar.")); + auto* result = new FertileStatsResource(param_proto_); + FertileStats stats; + if (!ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()())) { + result->Unref(); + OP_REQUIRES(context, false, + errors::InvalidArgument("Unable to parse stats config.")); + } + + result->ExtractFromProto(stats); + result->MaybeInitialize(); + + // Only create one, if one does not exist already. Report status for all + // other exceptions. + auto status = CreateResource(context, HandleFromInput(context, 0), result); + if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { + OP_REQUIRES(context, false, status); + } + } + + private: + TensorForestParams param_proto_; +}; + +// Op for serializing a model. 
+class FertileStatsSerializeOp : public OpKernel { + public: + explicit FertileStatsSerializeOp(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + FertileStatsResource* fertile_stats_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &fertile_stats_resource)); + mutex_lock l(*fertile_stats_resource->get_mutex()); + core::ScopedUnref unref_me(fertile_stats_resource); + Tensor* output_config_t = nullptr; + OP_REQUIRES_OK( + context, context->allocate_output(0, TensorShape(), &output_config_t)); + + FertileStats stats; + fertile_stats_resource->PackToProto(&stats); + output_config_t->scalar<string>()() = stats.SerializeAsString(); + } + + private: + TensorForestParams param_proto_; +}; + +// Op for deserializing a stats variable from a checkpoint. +class FertileStatsDeserializeOp : public OpKernel { + public: + explicit FertileStatsDeserializeOp(OpKernelConstruction* context) + : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(&param_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + FertileStatsResource* fertile_stats_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &fertile_stats_resource)); + mutex_lock l(*fertile_stats_resource->get_mutex()); + core::ScopedUnref unref_me(fertile_stats_resource); + + const Tensor* stats_config_t; + OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t)); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()), + errors::InvalidArgument("Stats config must be a scalar.")); + // Deallocate all the previous objects on the resource. 
+ fertile_stats_resource->Reset(); + FertileStats stats; + OP_REQUIRES(context, + ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()()), + errors::InvalidArgument("Unable to parse stats config.")); + + fertile_stats_resource->ExtractFromProto(stats); + fertile_stats_resource->MaybeInitialize(); + } + + private: + TensorForestParams param_proto_; +}; + +void TraverseTree(const DecisionTreeResource* tree_resource, + const std::unique_ptr<TensorDataSet>& data, int32 start, + int32 end, std::vector<int32>* leaf_ids, + std::vector<int32>* leaf_depths) { + for (int i = start; i < end; ++i) { + int32 depth; + const int32 leaf_id = tree_resource->TraverseTree(data, i, &depth); + (*leaf_ids)[i] = leaf_id; + (*leaf_depths)[i] = depth; + } +} + +// Try to update a leaf's stats by acquiring its lock. If it can't be +// acquired, put it in a waiting queue to come back to later and try the next +// one. Once all leaf_ids have been visited, cycle through the waiting ids +// until they're gone. +void UpdateStats(FertileStatsResource* fertile_stats_resource, + const std::unique_ptr<TensorDataSet>& data, + const Tensor& input_labels, const Tensor& input_weights, + int num_targets, const std::vector<int32>& leaf_ids, + const std::vector<int32>& leaf_depths, + std::unordered_map<int32, std::unique_ptr<mutex>>* locks, + mutex* set_lock, int32 start, int32 end, + std::unordered_set<int32>* ready_to_split) { + const auto labels = input_labels.unaligned_flat<float>(); + const auto weights = input_weights.unaligned_flat<float>(); + // Stores leaf_id, leaf_depth, example_id for examples that are waiting + // on another to finish. 
+ std::queue<std::tuple<int32, int32, int32>> waiting; + + int32 i = start; + TensorInputTarget target(&labels, &weights, input_labels, num_targets); + while (i < end || !waiting.empty()) { + int32 leaf_id; + int32 leaf_depth; + int32 example_id; + bool was_waiting = false; + if (i >= end) { + std::tie(leaf_id, leaf_depth, example_id) = waiting.front(); + waiting.pop(); + was_waiting = true; + } else { + leaf_id = leaf_ids[i]; + leaf_depth = leaf_depths[i]; + example_id = i; + ++i; + } + const std::unique_ptr<mutex>& leaf_lock = (*locks)[leaf_id]; + if (was_waiting) { + leaf_lock->lock(); + } else { + if (!leaf_lock->try_lock()) { + waiting.emplace(leaf_id, leaf_depth, example_id); + continue; + } + } + + bool is_finished; + fertile_stats_resource->AddExampleToStatsAndInitialize( + data, &target, {example_id}, leaf_id, leaf_depth, + &is_finished); + leaf_lock->unlock(); + if (is_finished) { + set_lock->lock(); + ready_to_split->insert(leaf_id); + set_lock->unlock(); + } + } +} + +// Update leaves from start through end in the leaf_examples iterator. 
+void UpdateStatsCollated( + FertileStatsResource* fertile_stats_resource, + DecisionTreeResource* tree_resource, + const std::unique_ptr<TensorDataSet>& data, const Tensor& input_labels, + const Tensor& input_weights, int num_targets, + const std::unordered_map<int32, std::vector<int>>& leaf_examples, + const std::vector<int32>& leaf_depths, mutex* set_lock, int32 start, + int32 end, std::unordered_set<int32>* ready_to_split) { + const auto labels = input_labels.unaligned_flat<float>(); + const auto weights = input_weights.unaligned_flat<float>(); + + TensorInputTarget target(&labels, &weights, input_labels, num_targets); + auto it = leaf_examples.begin(); + std::advance(it, start); + auto end_it = leaf_examples.begin(); + std::advance(end_it, end); + while (it != end_it) { + int32 leaf_id = it->first; + bool is_finished; + fertile_stats_resource->AddExampleToStatsAndInitialize( + data, &target, it->second, leaf_id, leaf_depths[it->second[0]], + &is_finished); + if (is_finished) { + set_lock->lock(); + ready_to_split->insert(leaf_id); + set_lock->unlock(); + } + ++it; + } +} + +// Op for traversing the tree with each example, accumulating statistics, and +// outputting node ids that are ready to split. 
+class ProcessInputOp : public OpKernel { + public: + explicit ProcessInputOp(OpKernelConstruction* context) : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(¶m_proto_, serialized_params); + + OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); + + string serialized_proto; + OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); + input_spec_.ParseFromString(serialized_proto); + + data_set_ = std::unique_ptr<TensorDataSet>( + new TensorDataSet(input_spec_, random_seed_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_data = context->input(2); + const Tensor& sparse_input_indices = context->input(3); + const Tensor& sparse_input_values = context->input(4); + const Tensor& input_labels = context->input(6); + const Tensor& input_weights = context->input(7); + + data_set_->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values); + + FertileStatsResource* fertile_stats_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), + &fertile_stats_resource)); + DecisionTreeResource* tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &tree_resource)); + mutex_lock l1(*fertile_stats_resource->get_mutex()); + mutex_lock l2(*tree_resource->get_mutex()); + + core::ScopedUnref unref_stats(fertile_stats_resource); + core::ScopedUnref unref_tree(tree_resource); + + const int32 num_data = data_set_->NumItems(); + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + int num_threads = worker_threads->num_threads; + + // First find the leaf ids for each example. + std::vector<int32> leaf_ids(num_data); + + // The depth of the leaf for example i. 
+ std::vector<int32> leaf_depths(num_data); + + const int64 costPerTraverse = 500; + auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data]( + int64 start, int64 end) { + CHECK(start <= end); + CHECK(end <= num_data); + TraverseTree(tree_resource, data_set_, static_cast<int32>(start), + static_cast<int32>(end), &leaf_ids, &leaf_depths); + }; + Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, + traverse); + + // Create one mutex per leaf. We need to protect access to leaf pointers, + // so instead of grouping examples by leaf, we spread examples out among + // threads to provide uniform work for each of them and protect access + // with mutexes. + std::unordered_map<int, std::unique_ptr<mutex>> locks; + std::unordered_map<int32, std::vector<int>> leaf_examples; + if (param_proto_.collate_examples()) { + for (int i = 0; i < num_data; ++i) { + leaf_examples[leaf_ids[i]].push_back(i); + } + } else { + for (const int32 id : leaf_ids) { + if (FindOrNull(locks, id) == nullptr) { + // TODO(gilberth): Consider using a memory pool for these. + locks[id] = std::unique_ptr<mutex>(new mutex); + } + } + } + + const int32 num_leaves = leaf_examples.size(); + const int32 label_dim = + input_labels.shape().dims() <= 1 + ? 0 + : static_cast<int>(input_labels.shape().dim_size(1)); + const int32 num_targets = + param_proto_.is_regression() ? (std::max(1, label_dim)) : 1; + + // Ids of leaves that can split. + std::unordered_set<int32> ready_to_split; + mutex set_lock; + + // TODO(gilberth): This is a rough approximation based on measurements + // from a digits run on local desktop. Heuristics might be necessary + // if it really matters that much. 
+ const int64 costPerUpdate = 1000; + auto update = [this, &input_labels, &input_weights, &leaf_ids, &leaf_depths, + &num_targets, fertile_stats_resource, &locks, &set_lock, + &ready_to_split, num_data](int64 start, int64 end) { + CHECK(start <= end); + CHECK(end <= num_data); + UpdateStats(fertile_stats_resource, data_set_, input_labels, + input_weights, num_targets, leaf_ids, leaf_depths, &locks, + &set_lock, static_cast<int32>(start), static_cast<int32>(end), + &ready_to_split); + }; + + auto update_collated = [this, &input_labels, &input_weights, &leaf_ids, + &num_targets, &leaf_depths, fertile_stats_resource, + tree_resource, &leaf_examples, &set_lock, + &ready_to_split, + num_leaves](int64 start, int64 end) { + CHECK(start <= end); + CHECK(end <= num_leaves); + UpdateStatsCollated( + fertile_stats_resource, tree_resource, data_set_, input_labels, + input_weights, num_targets, leaf_examples, leaf_depths, &set_lock, + static_cast<int32>(start), static_cast<int32>(end), &ready_to_split); + }; + + if (param_proto_.collate_examples()) { + Shard(num_threads, worker_threads->workers, num_leaves, costPerUpdate, + update_collated); + } else { + Shard(num_threads, worker_threads->workers, num_data, costPerUpdate, + update); + } + + Tensor* output_finished_t = nullptr; + TensorShape output_shape; + output_shape.AddDim(ready_to_split.size()); + OP_REQUIRES_OK( + context, context->allocate_output(0, output_shape, &output_finished_t)); + auto output = output_finished_t->unaligned_flat<int32>(); + std::copy(ready_to_split.begin(), ready_to_split.end(), output.data()); + } + + private: + int32 random_seed_; + tensorforest::TensorForestDataSpec input_spec_; + std::unique_ptr<TensorDataSet> data_set_; + TensorForestParams param_proto_; +}; + + +// Op for growing finished nodes. 
+class GrowTreeOp : public OpKernel { + public: + explicit GrowTreeOp(OpKernelConstruction* context) : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(¶m_proto_, serialized_params); + } + + void Compute(OpKernelContext* context) override { + FertileStatsResource* fertile_stats_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), + &fertile_stats_resource)); + DecisionTreeResource* tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &tree_resource)); + mutex_lock l1(*fertile_stats_resource->get_mutex()); + mutex_lock l2(*tree_resource->get_mutex()); + + core::ScopedUnref unref_stats(fertile_stats_resource); + core::ScopedUnref unref_tree(tree_resource); + + const Tensor& finished_nodes = context->input(2); + + const auto finished = finished_nodes.unaligned_flat<int32>(); + + const int32 num_nodes = + static_cast<int32>(finished_nodes.shape().dim_size(0)); + + // TODO(gilberth): distribute this work over a number of threads. + for (int i = 0; + i < num_nodes && + tree_resource->decision_tree().decision_tree().nodes_size() < + param_proto_.max_nodes(); + ++i) { + const int32 node = finished(i); + std::unique_ptr<SplitCandidate> best(new SplitCandidate); + int32 parent_depth; + bool found = + fertile_stats_resource->BestSplit(node, best.get(), &parent_depth); + if (found) { + std::vector<int32> new_children; + tree_resource->SplitNode(node, best.get(), &new_children); + fertile_stats_resource->Allocate(parent_depth, new_children); + fertile_stats_resource->set_leaf_stat(best->left_stats(), + new_children[0]); + fertile_stats_resource->set_leaf_stat(best->right_stats(), + new_children[1]); + // We are done with best, so it is now safe to clear node. 
+ fertile_stats_resource->Clear(node); + CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false); + } else { // reset + fertile_stats_resource->ResetSplitStats(node, parent_depth); + } + } + } + + private: + tensorforest::TensorForestDataSpec input_spec_; + TensorForestParams param_proto_; +}; + +void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression, + bool drop_final_class, + const std::unique_ptr<LeafModelOperator>& leaf_op, + decision_trees::Leaf* leaf) { + leaf_op->ExportModel(leaf_stats, leaf); + + // TODO(thomaswc): Move the rest of this into ExportModel. + + // regression models are already stored in leaf in normalized form. + if (is_regression) { + return; + } + + float sum = leaf_stats.weight_sum(); + if (sum <= 0.0) { + LOG(WARNING) << "Leaf with sum " << sum + << " has stats " << leaf->ShortDebugString(); + return; + } + + if (leaf->has_vector()) { + for (int i = 0; i < leaf->vector().value_size(); i++) { + auto *v = leaf->mutable_vector()->mutable_value(i); + v->set_float_value(v->float_value() / sum); + } + if (drop_final_class) { + leaf->mutable_vector()->mutable_value()->RemoveLast(); + } + return; + } + + if (leaf->has_sparse_vector()) { + for (auto& it : *leaf->mutable_sparse_vector()->mutable_sparse_value()) { + it.second.set_float_value(it.second.float_value() / sum); + } + return; + } + + LOG(FATAL) << "Unknown leaf type in " << leaf->DebugString(); +} + +// Op for finalizing a tree at the end of training. 
+class FinalizeTreeOp : public OpKernel { + public: + explicit FinalizeTreeOp(OpKernelConstruction* context) : OpKernel(context) { + string serialized_params; + OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params)); + ParseProtoUnlimited(¶m_proto_, serialized_params); + + model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_); + } + + void Compute(OpKernelContext* context) override { + DecisionTreeResource* tree_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &tree_resource)); + FertileStatsResource* fertile_stats_resource; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), + &fertile_stats_resource)); + + mutex_lock l1(*fertile_stats_resource->get_mutex()); + mutex_lock l2(*tree_resource->get_mutex()); + + core::ScopedUnref unref_me(tree_resource); + core::ScopedUnref unref_stats(fertile_stats_resource); + + // TODO(thomaswc): Add threads + int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size(); + for (int i = 0; i < num_nodes; i++) { + auto* node = tree_resource->mutable_decision_tree() + ->mutable_decision_tree()->mutable_nodes(i); + if (node->has_leaf()) { + const auto& leaf_stats = fertile_stats_resource->leaf_stat(i); + FinalizeLeaf(leaf_stats, param_proto_.is_regression(), + param_proto_.drop_final_class(), model_op_, + node->mutable_leaf()); + } + } + } + + private: + std::unique_ptr<LeafModelOperator> model_op_; + TensorForestParams param_proto_; +}; + +REGISTER_RESOURCE_HANDLE_KERNEL(FertileStatsResource); + +REGISTER_KERNEL_BUILDER(Name("FertileStatsIsInitializedOp").Device(DEVICE_CPU), + IsResourceInitialized<FertileStatsResource>); + +REGISTER_KERNEL_BUILDER(Name("CreateFertileStatsVariable").Device(DEVICE_CPU), + CreateFertileStatsVariableOp); + +REGISTER_KERNEL_BUILDER(Name("FertileStatsSerialize").Device(DEVICE_CPU), + FertileStatsSerializeOp); + 
+REGISTER_KERNEL_BUILDER(Name("FertileStatsDeserialize").Device(DEVICE_CPU), + FertileStatsDeserializeOp); + +REGISTER_KERNEL_BUILDER(Name("ProcessInputV4").Device(DEVICE_CPU), + ProcessInputOp); + +REGISTER_KERNEL_BUILDER(Name("GrowTreeV4").Device(DEVICE_CPU), + GrowTreeOp); + +REGISTER_KERNEL_BUILDER(Name("FinalizeTree").Device(DEVICE_CPU), + FinalizeTreeOp); + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD index 0542508a8e..a9d8093d13 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD +++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD @@ -40,9 +40,7 @@ cc_library( ":split_collection_operators", "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", ], ) @@ -111,7 +109,7 @@ cc_library( "//tensorflow/contrib/tensor_forest:tree_utils", "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", - "//tensorflow/core:lib", + "//tensorflow/core:framework_headers_lib", ], ) @@ -153,7 +151,7 @@ cc_library( ":input_data", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", - "//tensorflow/core:lib", + "//tensorflow/core:framework_headers_lib", ], ) @@ -175,12 +173,32 @@ cc_library( srcs = ["split_collection_operators.cc"], hdrs = ["split_collection_operators.h"], deps = [ + ":grow_stats", + ":input_data", + ":input_target", + ":leaf_model_operators", + ":params", + ":stat_utils", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", + "//tensorflow/contrib/tensor_forest:tree_utils", 
+ "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + ], +) + +cc_library( + name = "graph_collection_operator", + srcs = ["graph_collection_operator.cc"], + hdrs = ["graph_collection_operator.h"], + deps = [ ":candidate_graph_runner", ":grow_stats", ":input_data", ":input_target", ":leaf_model_operators", ":params", + ":split_collection_operators", ":stat_utils", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc new file mode 100644 index 0000000000..2c925b5dd7 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.cc @@ -0,0 +1,142 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h" + +#include <cfloat> + +#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" + +namespace tensorflow { +namespace tensorforest { + +REGISTER_SPLIT_COLLECTION(GRAPH_RUNNER_COLLECTION, + GraphRunnerSplitCollectionOperator); + +std::unique_ptr<GrowStats> GraphRunnerSplitCollectionOperator::CreateGrowStats( + int32 node_id, int32 depth) const { + return std::unique_ptr<GrowStats>(new SimpleStats(params_, depth)); +} + +int64 GraphRunnerSplitCollectionOperator::UniqueId(int32 node_id, + int32 split_id) const { + return node_id * num_splits_to_consider_ + split_id; +} + +bool GraphRunnerSplitCollectionOperator::BestSplit(int32 node_id, + SplitCandidate* best, + int32* depth) const { + float min_score = FLT_MAX; + int best_index = -1; + auto* slot = stats_.at(node_id).get(); + *depth = slot->depth(); + for (int i = 0; i < slot->num_splits(); ++i) { + // TODO(gilberth): Support uselessness. + auto& runner = runners_[UniqueId(node_id, i)]; + const float split_score = runner->SplitScore(); + if (split_score < min_score) { + min_score = split_score; + best_index = i; + } + } + + // This could happen if all the splits are useless. + if (best_index < 0) { + return false; + } + + // Fill in split info and left/right stats to initialize models with. + *best = SplitCandidate(); + auto& runner = runners_[UniqueId(node_id, best_index)]; + runner->GetLeftStats(best->mutable_left_stats()); + runner->GetRightStats(best->mutable_right_stats()); + runner->GetSplit(best->mutable_split()); + return true; +} + +void GraphRunnerSplitCollectionOperator::AddExample( + const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, + const std::vector<int>& examples, int32 node_id) const { + // Build input Tensors. 
+ int size = examples.size(); + Tensor examples_t(tensorflow::DT_INT32, TensorShape({size})); + auto ex_data = examples_t.flat<int32>(); + std::copy(examples.begin(), examples.end(), ex_data.data()); + + const TensorInputTarget* tensor_target = + dynamic_cast<const TensorInputTarget*>(target); + CHECK_NOTNULL(tensor_target); + + const Tensor& data_t = input_data->original_tensor(); + const Tensor& target_t = tensor_target->original_tensor(); + + // Add to candidates. + auto* slot = stats_.at(node_id).get(); + for (int i = 0; i < slot->num_splits(); ++i) { + auto& runner = runners_[UniqueId(node_id, i)]; + runner->AddExample(data_t, target_t, examples_t); + } + + // Update simple weight sums so we know when we're done. + for (int example : examples) { + slot->AddExample(input_data, target, example); + } +} + +void GraphRunnerSplitCollectionOperator:: + CreateAndInitializeCandidateWithExample( + const std::unique_ptr<TensorDataSet>& input_data, int example, + int32 node_id) const { + auto* slot = stats_.at(node_id).get(); + int cand_num = slot->num_splits(); + const int64 unique_id = UniqueId(node_id, cand_num); + + decision_trees::BinaryNode split; + + decision_trees::InequalityTest* test = + split.mutable_inequality_left_child_test(); + auto* oblique = test->mutable_oblique(); + for (int i = 0; i < features_per_node_; ++i) { + float bias; + int type; + // This is really just a way to select a list of random features. + // Also a way to warn the user that categoricals don't make sense here. + input_data->RandomSample(example, oblique->add_features(), &bias, &type); + + if (type == kDataFloat) { + test->set_type(decision_trees::InequalityTest::LESS_OR_EQUAL); + + // The comparison bias is assumed to be zero. 
+ test->mutable_threshold()->set_float_value(0); + } else { + LOG(ERROR) << "Categorical features not supported with this system."; + return; + } + } + + slot->AddSplit(split); + + runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split)); + runners_[unique_id]->Init(); +} + +void GraphRunnerSplitCollectionOperator::ClearSlot(int32 node_id) { + SplitCollectionOperator::ClearSlot(node_id); + for (int i = 0; i < num_splits_to_consider_; ++i) { + runners_.erase(UniqueId(node_id, i)); + } +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h new file mode 100644 index 0000000000..9b18e3e969 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/graph_collection_operator.h @@ -0,0 +1,81 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ + +#include <vector> +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" + +namespace tensorflow { +namespace tensorforest { + +// Holds split candidates that are trained by running any TF graph. +class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator { + public: + explicit GraphRunnerSplitCollectionOperator(const TensorForestParams& params) + : SplitCollectionOperator(params) { + if (params.num_splits_to_consider().ParamType_case() == + DepthDependentParam::PARAMTYPE_NOT_SET) { + LOG(FATAL) << "GRAPH_RUNNER_COLLECTION must specify a constant value for " + << " num_splits_to_consider"; + } else { + num_splits_to_consider_ = + params.num_splits_to_consider().constant_value(); + } + } + + std::unique_ptr<GrowStats> CreateGrowStats(int32 node_id, + int32 depth) const override; + + // Updates the slot's candidates with the new example. + // Assumes slot has been initialized. 
+ void AddExample(const std::unique_ptr<TensorDataSet>& input_data, + const InputTarget* target, const std::vector<int>& examples, + int32 node_id) const override; + + // Create a new candidate and initialize it with the given example. + void CreateAndInitializeCandidateWithExample( + const std::unique_ptr<TensorDataSet>& input_data, int example, + int32 node_id) const override; + + bool BestSplit(int32 node_id, SplitCandidate* best, + int32* depth) const override; + + void ClearSlot(int32 node_id) override; + + protected: + int64 UniqueId(int32 node_id, int32 split_id) const; + + mutable std::unordered_map<int64, std::unique_ptr<CandidateGraphRunner>> + runners_; + int features_per_node_; + string graph_dir_; + // Must have a constant value because of how we make unique ids right now. + int32 num_splits_to_consider_; +}; + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GRAPH_COLLECTION_OPERATOR_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc index ddf4be8799..c207c0859d 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -23,22 +23,20 @@ namespace tensorflow { namespace tensorforest { +std::unordered_map<int, CollectionCreator*> + SplitCollectionOperatorFactory::factories_; // NOLINT +REGISTER_SPLIT_COLLECTION(COLLECTION_BASIC, SplitCollectionOperator); + std::unique_ptr<SplitCollectionOperator> SplitCollectionOperatorFactory::CreateSplitCollectionOperator( const TensorForestParams& params) { - switch (params.collection_type()) { - case COLLECTION_BASIC: - return std::unique_ptr<SplitCollectionOperator>( - new SplitCollectionOperator(params)); - - case GRAPH_RUNNER_COLLECTION: - return std::unique_ptr<SplitCollectionOperator>( - new 
GraphRunnerSplitCollectionOperator(params)); - - default: - LOG(ERROR) << "Unknown split collection operator: " - << params.collection_type(); - return nullptr; + auto it = factories_.find(params.collection_type()); + if (it == factories_.end()) { + LOG(ERROR) << "Unknown split collection operator: " + << params.collection_type(); + return nullptr; + } else { + return it->second->Create(params); } } @@ -137,121 +135,5 @@ bool SplitCollectionOperator::BestSplit(int32 node_id, return slot->BestSplit(best); } -// -------------------------------- GraphRunner ------------------ // - -std::unique_ptr<GrowStats> GraphRunnerSplitCollectionOperator::CreateGrowStats( - int32 node_id, int32 depth) const { - return std::unique_ptr<GrowStats>(new SimpleStats(params_, depth)); -} - -int64 GraphRunnerSplitCollectionOperator::UniqueId(int32 node_id, - int32 split_id) const { - return node_id * num_splits_to_consider_ + split_id; -} - -bool GraphRunnerSplitCollectionOperator::BestSplit(int32 node_id, - SplitCandidate* best, - int32* depth) const { - float min_score = FLT_MAX; - int best_index = -1; - auto* slot = stats_.at(node_id).get(); - *depth = slot->depth(); - for (int i = 0; i < slot->num_splits(); ++i) { - // TODO(gilberth): Support uselessness. - auto& runner = runners_[UniqueId(node_id, i)]; - const float split_score = runner->SplitScore(); - if (split_score < min_score) { - min_score = split_score; - best_index = i; - } - } - - // This could happen if all the splits are useless. - if (best_index < 0) { - return false; - } - - // Fill in split info and left/right stats to initialize models with. 
- *best = SplitCandidate(); - auto& runner = runners_[UniqueId(node_id, best_index)]; - runner->GetLeftStats(best->mutable_left_stats()); - runner->GetRightStats(best->mutable_right_stats()); - runner->GetSplit(best->mutable_split()); - return true; -} - -void GraphRunnerSplitCollectionOperator::AddExample( - const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, - const std::vector<int>& examples, int32 node_id) const { - // Build input Tensors. - int size = examples.size(); - Tensor examples_t(tensorflow::DT_INT32, TensorShape({size})); - auto ex_data = examples_t.flat<int32>(); - std::copy(examples.begin(), examples.end(), ex_data.data()); - - const TensorInputTarget* tensor_target = - dynamic_cast<const TensorInputTarget*>(target); - CHECK_NOTNULL(tensor_target); - - const Tensor& data_t = input_data->original_tensor(); - const Tensor& target_t = tensor_target->original_tensor(); - - // Add to candidates. - auto* slot = stats_.at(node_id).get(); - for (int i = 0; i < slot->num_splits(); ++i) { - auto& runner = runners_[UniqueId(node_id, i)]; - runner->AddExample(data_t, target_t, examples_t); - } - - // Update simple weight sums so we know when we're done. - for (int example : examples) { - slot->AddExample(input_data, target, example); - } -} - -void GraphRunnerSplitCollectionOperator:: - CreateAndInitializeCandidateWithExample( - const std::unique_ptr<TensorDataSet>& input_data, int example, - int32 node_id) const { - auto* slot = stats_.at(node_id).get(); - int cand_num = slot->num_splits(); - const int64 unique_id = UniqueId(node_id, cand_num); - - decision_trees::BinaryNode split; - - decision_trees::InequalityTest* test = - split.mutable_inequality_left_child_test(); - auto* oblique = test->mutable_oblique(); - for (int i = 0; i < features_per_node_; ++i) { - float bias; - int type; - // This is really just a way to select a list of random features. - // Also a way to warn the user that categoricals don't make sense here. 
- input_data->RandomSample(example, oblique->add_features(), &bias, &type); - - if (type == kDataFloat) { - test->set_type(decision_trees::InequalityTest::LESS_OR_EQUAL); - - // The comparison bias is assumed to be zero. - test->mutable_threshold()->set_float_value(0); - } else { - LOG(ERROR) << "Categorical features not supported with this system."; - return; - } - } - - slot->AddSplit(split); - - runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split)); - runners_[unique_id]->Init(); -} - -void GraphRunnerSplitCollectionOperator::ClearSlot(int32 node_id) { - SplitCollectionOperator::ClearSlot(node_id); - for (int i = 0; i < num_splits_to_consider_; ++i) { - runners_.erase(UniqueId(node_id, i)); - } -} - } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h index d0ea33612a..81d820a6b2 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -17,7 +17,6 @@ #include <vector> #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" -#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" @@ -86,60 +85,38 @@ class SplitCollectionOperator { std::unordered_map<int32, std::unique_ptr<GrowStats>> stats_; }; - -class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator { +class CollectionCreator { public: - explicit GraphRunnerSplitCollectionOperator(const TensorForestParams& params) - : SplitCollectionOperator(params) { - if (params.num_splits_to_consider().ParamType_case() == - DepthDependentParam::PARAMTYPE_NOT_SET) { - LOG(FATAL) 
<< "GRAPH_RUNNER_COLLECTION must specify a constant value for " - << " num_splits_to_consider"; - } else { - num_splits_to_consider_ = - params.num_splits_to_consider().constant_value(); - } - } - - std::unique_ptr<GrowStats> CreateGrowStats(int32 node_id, - int32 depth) const override; - - // Updates the slot's candidates with the new example. - // Assumes slot has been initialized. - void AddExample(const std::unique_ptr<TensorDataSet>& input_data, - const InputTarget* target, const std::vector<int>& examples, - int32 node_id) const override; - - // Create a new candidate and initialize it with the given example. - void CreateAndInitializeCandidateWithExample( - const std::unique_ptr<TensorDataSet>& input_data, int example, - int32 node_id) const override; - - bool BestSplit(int32 node_id, SplitCandidate* best, - int32* depth) const override; - - void ClearSlot(int32 node_id) override; - - protected: - int64 UniqueId(int32 node_id, int32 split_id) const; - - mutable std::unordered_map<int64, std::unique_ptr<CandidateGraphRunner>> - runners_; - int features_per_node_; - string graph_dir_; - // Must have a constant value because of how we make unique ids right now. - int32 num_splits_to_consider_; + virtual std::unique_ptr<SplitCollectionOperator> Create( + const TensorForestParams& params) = 0; + virtual ~CollectionCreator() {} }; -// Creates a type of SplitCollectionOperator depending on the type passed, -// which is SplitCollectionType in fertile_stats.proto. -// Can create a SplitCollectionOperator itself, known as "basic". 
class SplitCollectionOperatorFactory { public: static std::unique_ptr<SplitCollectionOperator> CreateSplitCollectionOperator( const TensorForestParams& params); + + static std::unordered_map<int, CollectionCreator*> factories_; +}; + +template <typename T> +class AnyCollectionCreator : public CollectionCreator { + public: + AnyCollectionCreator(SplitCollectionType type) { + SplitCollectionOperatorFactory::factories_[type] = this; + } + virtual std::unique_ptr<SplitCollectionOperator> Create( + const TensorForestParams& params) { + return std::unique_ptr<SplitCollectionOperator>(new T(params)); + } }; +#define REGISTER_SPLIT_COLLECTION(name, cls) \ + namespace { \ + AnyCollectionCreator<cls> creator(name); \ + } + } // namespace tensorforest } // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc new file mode 100644 index 0000000000..1c3c2153a6 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc @@ -0,0 +1,135 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/shape_inference.h" + + +namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; + +namespace tensorforest { + +REGISTER_RESOURCE_HANDLE_OP(DecisionTreeResource); + +REGISTER_OP("TreeIsInitializedOp") + .Input("tree_handle: resource") + .Output("is_initialized: bool") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Checks whether a tree has been initialized. +)doc"); + +REGISTER_OP("CreateTreeVariable") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("tree_config: string") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +Creates a tree model and returns a handle to it. + +params: A serialized TensorForestParams proto. +tree_handle: handle to the tree resource to be created. +tree_config: Serialized proto of the tree. +)doc"); + +REGISTER_OP("TreeSerialize") + .Input("tree_handle: resource") + .Output("tree_config: string") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Serializes the tree to a proto. + +tree_handle: The handle to the tree. +tree_config: Serialized proto of the tree. +)doc"); + +REGISTER_OP("TreeDeserialize") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("tree_config: string") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +Deserializes a serialized tree config and replaces current tree. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree . +tree_config: Serialized proto of the . 
+)doc"); + +REGISTER_OP("TreeSize") + .Input("tree_handle: resource") + .Output("tree_size: int32") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Outputs the size of the tree, including leaves. + +tree_handle: The handle to the tree. +tree_size: Size scalar. +)doc"); + +REGISTER_OP("TreePredictionsV4") + .Attr("input_spec: string") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("input_data: float") + .Input("sparse_input_indices: int64") + .Input("sparse_input_values: float") + .Input("sparse_input_shape: int64") + .Output("predictions: float") + .SetShapeFn([](InferenceContext* c) { + DimensionHandle num_points = c->UnknownDim(); + + if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0) { + num_points = c->Dim(c->input(1), 0); + } + + c->set_output(0, c->Matrix(num_points, c->UnknownDim())); + return Status::OK(); + }) + .Doc(R"doc( +Outputs the predictions for the given input data. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` + gives the j-th feature of the i-th input. +sparse_input_indices: The indices tensor from the SparseTensor input. +sparse_input_values: The values tensor from the SparseTensor input. +sparse_input_shape: The shape tensor from the SparseTensor input. +predictions: `predictions[i][j]` is the probability that input i is class j. +)doc"); + +REGISTER_OP("FeatureUsageCounts") + .Attr("params: string") + .Input("tree_handle: resource") + .Output("feature_counts: int32") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); + }) + .Doc(R"doc( +Outputs the number of times each feature was used in a split. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +feature_counts: `feature_counts[i]` is the number of times feature i was used + in a split. 
+)doc"); + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc new file mode 100644 index 0000000000..48e91e3466 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc @@ -0,0 +1,146 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +using shape_inference::InferenceContext; + +namespace tensorforest { + + +REGISTER_RESOURCE_HANDLE_OP(FertileStatsResource); + +REGISTER_OP("FertileStatsIsInitializedOp") + .Input("stats_handle: resource") + .Output("is_initialized: bool") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Checks whether a stats has been initialized. +)doc"); + +REGISTER_OP("CreateFertileStatsVariable") + .Attr("params: string") + .Input("stats_handle: resource") + .Input("stats_config: string") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +Creates a stats model and returns a handle to it. + +params: A serialized TensorForestParams proto. +stats_handle: handle to the stats resource to be created. 
+stats_config: Serialized proto of the stats. +)doc"); + +REGISTER_OP("FertileStatsSerialize") + .Attr("params: string") + .Input("stats_handle: resource") + .Output("stats_config: string") + .SetShapeFn(tensorflow::shape_inference::ScalarShape) + .Doc(R"doc( +Serializes the stats to a proto. + +params: A serialized TensorForestParams proto. +stats_handle: The handle to the stats. +stats_config: Serialized proto of the stats. +)doc"); + +REGISTER_OP("FertileStatsDeserialize") + .Attr("params: string") + .Input("stats_handle: resource") + .Input("stats_config: string") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +Deserializes a serialized stats config and replaces current stats. + +params: A serialized TensorForestParams proto. +stats_handle: The handle to the stats. +stats_config: Serialized proto of the stats. +)doc"); + +REGISTER_OP("GrowTreeV4") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("stats_handle: resource") + .Input("finshed_nodes: int32") + .SetShapeFn(tensorflow::shape_inference::NoOutputs) + .Doc(R"doc( +Grows the tree for finished nodes and allocates waiting nodes. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +stats_handle: The handle to the stats. +finshed_nodes: A 1-d Tensor of finished node ids from ProcessInput. +)doc"); + +REGISTER_OP("ProcessInputV4") + .Attr("random_seed: int") + .Attr("input_spec: string") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("stats_handle: resource") + .Input("input_data: float") + .Input("sparse_input_indices: int64") + .Input("sparse_input_values: float") + .Input("sparse_input_shape: int64") + .Input("input_labels: float") + .Input("input_weights: float") + .Output("finished_nodes: int32") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); + }) + .Doc(R"doc( +Add labels to stats after traversing the tree for each example. 
+ +Outputs node ids that are finished. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +stats_handle: The handle to the stats. +input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` + gives the j-th feature of the i-th input. +sparse_input_indices: The indices tensor from the SparseTensor input. +sparse_input_values: The values tensor from the SparseTensor input. +sparse_input_shape: The shape tensor from the SparseTensor input. +input_labels: The training batch's labels as a 1 or 2-d tensor. + 'input_labels[i][j]' gives the j-th label/target for the i-th input. +input_weights: The training batch's eample weights as a 1-d tensor. + 'input_weights[i]' gives the weight for the i-th input. +finished_nodes: A 1-d tensor of node ids that have finished and are ready to + grow. +)doc"); + +REGISTER_OP("FinalizeTree") + .Attr("params: string") + .Input("tree_handle: resource") + .Input("stats_handle: resource") + .SetShapeFn([](InferenceContext* c) { + return Status::OK(); + }) + .Doc(R"doc( +Puts the Leaf models inside the tree into their final form. + +If drop_final_class is true, the per-class probability prediction of the +last class is not stored in the leaf models. + +params: A serialized TensorForestParams proto. +tree_handle: The handle to the tree. +stats_handle: The handle to the stats. 
+)doc"); +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/python/__init__.py b/tensorflow/contrib/tensor_forest/python/__init__.py index 0d41d3500d..b2ca89877a 100644 --- a/tensorflow/contrib/tensor_forest/python/__init__.py +++ b/tensorflow/contrib/tensor_forest/python/__init__.py @@ -21,4 +21,6 @@ from __future__ import print_function from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.contrib.tensor_forest.python.ops import data_ops +from tensorflow.contrib.tensor_forest.python.ops import model_ops +from tensorflow.contrib.tensor_forest.python.ops import stats_ops from tensorflow.contrib.tensor_forest.python.ops import tensor_forest_ops diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py new file mode 100644 index 0000000000..4c7218305b --- /dev/null +++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Model ops python wrappers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops +from tensorflow.contrib.tensor_forest.python.ops import stats_ops + +# pylint: disable=unused-import +from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts +from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4 +from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size +# pylint: enable=unused-import + +from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops +from tensorflow.python.ops import resources +from tensorflow.python.platform import resource_loader +from tensorflow.python.training import saver + + +_model_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("_model_ops.so")) + + +ops.NotDifferentiable("TreeVariable") +ops.NotDifferentiable("TreeSerialize") +ops.NotDifferentiable("TreeDeserialize") +ops.NotDifferentiable("TreeSize") +ops.NotDifferentiable("TreePredictionsV4") +ops.NotDifferentiable("FeatureUsageCounts") + + +class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for TreeVariable.""" + + def __init__(self, params, tree_handle, stats_handle, create_op, name): + """Creates a TreeVariableSavable object. + + Args: + params: A TensorForestParams object. + tree_handle: handle to the tree variable. + stats_handle: handle to the stats variable. + create_op: the op to initialize the variable. + name: the name to save the tree variable under. 
+ """ + self.params = params + deps = [] + if stats_handle is not None: + deps.append(stats_ops.finalize_tree( + tree_handle, stats_handle, + params=params.serialized_params_proto)) + with ops.control_dependencies(deps): + tensor = gen_model_ops.tree_serialize(tree_handle) + # slice_spec is useful for saving a slice from a variable. + # It's not meaningful the tree variable. So we just pass an empty value. + slice_spec = "" + specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name),] + super(TreeVariableSavable, + self).__init__(tree_handle, specs, name) + self._tree_handle = tree_handle + self._create_op = create_op + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree variable. + """ + with ops.control_dependencies([self._create_op]): + return gen_model_ops.tree_deserialize( + self._tree_handle, + restored_tensors[0], + params=self.params.serialized_params_proto) + + +def tree_variable(params, tree_config, stats_handle, name, container=None): + r"""Creates a tree model and returns a handle to it. + + Args: + params: A TensorForestParams object. + tree_config: A `Tensor` of type `string`. Serialized proto of the tree. + stats_handle: Resource handle to the stats object. + name: A name for the variable. + container: An optional `string`. Defaults to `""`. + + Returns: + A `Tensor` of type mutable `string`. The handle to the tree. 
+ """ + with ops.name_scope(name, "TreeVariable") as name: + resource_handle = gen_model_ops.decision_tree_resource_handle_op( + container, name, name=name) + + create_op = gen_model_ops.create_tree_variable( + resource_handle, + tree_config, + params=params.serialized_params_proto) + is_initialized_op = gen_model_ops.tree_is_initialized_op(resource_handle) + # Adds the variable to the savable list. + saveable = TreeVariableSavable(params, resource_handle, stats_handle, + create_op, + "tree_checkpoint_{0}".format(name)) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + resources.register_resource(resource_handle, create_op, is_initialized_op) + return resource_handle diff --git a/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py new file mode 100644 index 0000000000..be9f2e12b7 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/python/ops/stats_ops.py @@ -0,0 +1,114 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Stats ops python wrappers.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tensor_forest.python.ops import gen_stats_ops +# pylint: disable=unused-import +from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import finalize_tree +from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import grow_tree_v4 +from tensorflow.contrib.tensor_forest.python.ops.gen_stats_ops import process_input_v4 +# pylint: enable=unused-import + +from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops +from tensorflow.python.ops import resources +from tensorflow.python.platform import resource_loader +from tensorflow.python.training import saver + + +_stats_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("_stats_ops.so")) + + +ops.NotDifferentiable("FertileStatsVariable") +ops.NotDifferentiable("FertileStatsSerialize") +ops.NotDifferentiable("FertileStatsDeserialize") +ops.NotDifferentiable("GrowTreeV4") +ops.NotDifferentiable("ProcessInputV4") +ops.NotDifferentiable("FinalizeTree") + + +class FertileStatsVariableSavable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject implementation for FertileStatsVariable.""" + + def __init__(self, params, stats_handle, create_op, name): + """Creates a FertileStatsVariableSavable object. + + Args: + params: A TensorForestParams object. + stats_handle: handle to the tree variable. + create_op: the op to initialize the variable. + name: the name to save the tree variable under. + """ + self.params = params + tensor = gen_stats_ops.fertile_stats_serialize( + stats_handle, params=params.serialized_params_proto) + # slice_spec is useful for saving a slice from a variable. + # It's not meaningful the tree variable. So we just pass an empty value. 
+ slice_spec = "" + specs = [saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name),] + super(FertileStatsVariableSavable, + self).__init__(stats_handle, specs, name) + self._stats_handle = stats_handle + self._create_op = create_op + + def restore(self, restored_tensors, unused_restored_shapes): + """Restores the associated tree from 'restored_tensors'. + + Args: + restored_tensors: the tensors that were loaded from a checkpoint. + unused_restored_shapes: the shapes this object should conform to after + restore. Not meaningful for trees. + + Returns: + The operation that restores the state of the tree variable. + """ + with ops.control_dependencies([self._create_op]): + return gen_stats_ops.fertile_stats_deserialize( + self._stats_handle, restored_tensors[0], + params=self.params.serialized_params_proto) + + +def fertile_stats_variable(params, stats_config, name, + container=None): + r"""Creates a stats object and returns a handle to it. + + Args: + params: A TensorForestParams object. + stats_config: A `Tensor` of type `string`. Serialized proto of the stats. + name: A name for the variable. + container: An optional `string`. Defaults to `""`. + + Returns: + A `Tensor` of type mutable `string`. The handle to the stats. + """ + with ops.name_scope(name, "FertileStatsVariable") as name: + resource_handle = gen_stats_ops.fertile_stats_resource_handle_op( + container, name, name=name) + + create_op = gen_stats_ops.create_fertile_stats_variable( + resource_handle, stats_config, + params=params.serialized_params_proto) + is_initialized_op = gen_stats_ops.fertile_stats_is_initialized_op( + resource_handle) + # Adds the variable to the savable list. + saveable = FertileStatsVariableSavable(params, resource_handle, create_op, + "stats_checkpoint_{0}".format(name)) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + resources.register_resource(resource_handle, create_op, is_initialized_op) + return resource_handle |