From b70103502b41df370906e8988b6593e55caf69cf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 30 Jun 2016 14:40:02 -0800 Subject: Improvements to tensor_forest, including support for sparse and categorical inputs. Add tf.learn.Estimator for random forests. Change: 126352221 --- tensorflow/contrib/tensor_forest/BUILD | 51 +++ tensorflow/contrib/tensor_forest/__init__.py | 2 + .../contrib/tensor_forest/client/__init__.py | 21 + .../contrib/tensor_forest/client/eval_metrics.py | 69 +++ .../core/ops/count_extremely_random_stats_op.cc | 251 +++++++--- .../tensor_forest/core/ops/finished_nodes_op.cc | 132 +++++- .../tensor_forest/core/ops/sample_inputs_op.cc | 114 ++++- .../tensor_forest/core/ops/tree_predictions_op.cc | 96 +++- .../contrib/tensor_forest/core/ops/tree_utils.cc | 266 +++++++++-- .../contrib/tensor_forest/core/ops/tree_utils.h | 93 +++- .../core/ops/update_fertile_slots_op.cc | 101 ++-- tensorflow/contrib/tensor_forest/data/__init__.py | 21 + tensorflow/contrib/tensor_forest/data/data_ops.py | 109 +++++ .../tensor_forest/data/string_to_float_op.cc | 111 +++++ .../contrib/tensor_forest/python/__init__.py | 1 + .../contrib/tensor_forest/python/constants.py | 26 ++ .../python/kernel_tests/best_splits_op_test.py | 18 +- .../count_extremely_random_stats_op_test.py | 101 +++- .../python/kernel_tests/finished_nodes_op_test.py | 56 ++- .../python/kernel_tests/sample_inputs_op_test.py | 34 +- .../kernel_tests/tree_predictions_op_test.py | 70 ++- .../kernel_tests/update_fertile_slots_op_test.py | 29 +- .../tensor_forest/python/ops/inference_ops.py | 24 +- .../tensor_forest/python/ops/training_ops.py | 28 +- .../contrib/tensor_forest/python/tensor_forest.py | 510 ++++++++++++--------- .../tensor_forest/python/tensor_forest_test.py | 41 ++ 26 files changed, 1854 insertions(+), 521 deletions(-) create mode 100644 tensorflow/contrib/tensor_forest/client/__init__.py create mode 100644 tensorflow/contrib/tensor_forest/client/eval_metrics.py create mode 100644 tensorflow/contrib/tensor_forest/data/__init__.py create mode 100644 tensorflow/contrib/tensor_forest/data/data_ops.py create mode 100644 tensorflow/contrib/tensor_forest/data/string_to_float_op.cc create mode 100644 tensorflow/contrib/tensor_forest/python/constants.py (limited to 'tensorflow/contrib/tensor_forest') diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 792243790c..8c9dc74222 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -18,6 +18,54 @@ filegroup( ), ) +py_library( + name = "constants", + srcs = [ + "python/constants.py", + ], + srcs_version = "PY2AND3", +) + +tf_custom_op_library( + name = "data/_data_ops.so", + srcs = [ + "data/string_to_float_op.cc", + ], + deps = [ + ":tree_utils", + ], +) + +py_library( + name = "data_ops_lib", + srcs = [ + "data/data_ops.py", + ], + data = [ + "data/_data_ops.so", + ], + srcs_version = "PY2AND3", + deps = [ + ":constants", + ], +) + +py_library( + name = "eval_metrics", + srcs = ["client/eval_metrics.py"], + srcs_version = "PY2AND3", +) + +py_library( + name = "client_lib", + srcs_version = "PY2AND3", + deps = [ + ":data_ops_lib", + ":eval_metrics", + ":tensor_forest_py", + ], +) + cc_library( name = "tree_utils", srcs = ["core/ops/tree_utils.cc"], @@ -86,6 +134,7 @@ py_test( srcs = ["python/kernel_tests/count_extremely_random_stats_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":constants", ":ops_lib", "//tensorflow:tensorflow_py", 
"//tensorflow/python:framework_test_lib", @@ -151,6 +200,7 @@ py_test( srcs = ["python/kernel_tests/tree_predictions_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":constants", ":ops_lib", "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", @@ -176,6 +226,7 @@ py_library( srcs = ["python/tensor_forest.py"], srcs_version = "PY2AND3", deps = [ + ":constants", ":ops_lib", ], ) diff --git a/tensorflow/contrib/tensor_forest/__init__.py b/tensorflow/contrib/tensor_forest/__init__.py index 7cf05299c4..7d97e01df0 100644 --- a/tensorflow/contrib/tensor_forest/__init__.py +++ b/tensorflow/contrib/tensor_forest/__init__.py @@ -18,4 +18,6 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensor_forest.client import * +from tensorflow.contrib.tensor_forest.data import * from tensorflow.contrib.tensor_forest.python import * diff --git a/tensorflow/contrib/tensor_forest/client/__init__.py b/tensorflow/contrib/tensor_forest/client/__init__.py new file mode 100644 index 0000000000..753f406cbc --- /dev/null +++ b/tensorflow/contrib/tensor_forest/client/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Random forest implementation in tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensor_forest.client import eval_metrics diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py new file mode 100644 index 0000000000..f41794a886 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -0,0 +1,69 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""A collection of functions to be used as evaluation metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import losses +from tensorflow.contrib.metrics.python.ops import metric_ops + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def _accuracy(probabilities, targets): + predictions = math_ops.argmax(probabilities, 1) + # undo one-hot + labels = math_ops.argmax(targets, 1) + return metric_ops.streaming_accuracy(predictions, labels) + + +def _r2(probabilities, targets): + if targets.get_shape().ndims == 1: + targets = array_ops.expand_dims(targets, -1) + y_mean = math_ops.reduce_mean(targets, 0) + squares_total = math_ops.reduce_sum(math_ops.square(targets - y_mean), 0) + squares_residuals = math_ops.reduce_sum(math_ops.square( + targets - probabilities), 0) + score = 1 - math_ops.reduce_sum(squares_residuals / squares_total) + return metric_ops.streaming_mean(score) + + +def _sigmoid_entropy(probabilities, targets): + return metric_ops.streaming_mean(losses.sigmoid_cross_entropy( + probabilities, targets)) + + +def _softmax_entropy(probabilities, targets): + return metric_ops.streaming_mean(losses.softmax_cross_entropy( + probabilities, targets)) + + +def _predictions(probabilities, unused_targets): + return math_ops.argmax(probabilities, 1) + + +_EVAL_METRICS = {'sigmoid_entropy': _sigmoid_entropy, + 'softmax_entropy': _softmax_entropy, + 'accuracy': _accuracy, + 'r2': _r2, + 'predictions': _predictions} + + +def get_metric(metric_name): + """Given a metric name, return the corresponding metric function.""" + return _EVAL_METRICS[metric_name] diff --git a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc index 0ccf75bcc6..0413f1e20a 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/count_extremely_random_stats_op.cc @@ -43,7 +43,7 @@ using tensorforest::LEAF_NODE; using tensorforest::FREE_NODE; using tensorforest::CheckTensorBounds; -using tensorforest::DecideNode; +using tensorforest::DataColumnTypes; using tensorforest::Initialize; using tensorforest::IsAllInitialized; @@ -61,61 +61,77 @@ struct InputDataResult { bool splits_initialized; }; -void Evaluate(const Tensor& input_data, const Tensor& input_labels, - const Tensor& tree_tensor, const Tensor& tree_thresholds, - const Tensor& node_to_accumulator, - const Tensor& candidate_split_features, - const Tensor& candidate_split_thresholds, - InputDataResult* results, int32 start, int32 end) { - const auto tree = tree_tensor.tensor(); - const auto thresholds = tree_thresholds.unaligned_flat(); - const auto node_map = node_to_accumulator.unaligned_flat(); - const auto split_features = candidate_split_features.tensor(); - const auto split_thresholds = candidate_split_thresholds.tensor(); + +struct EvaluateParams { + std::function decide_function; + Tensor input_spec; + Tensor input_labels; + Tensor tree_tensor; + Tensor tree_thresholds; + Tensor node_to_accumulator; + Tensor candidate_split_features; + Tensor candidate_split_thresholds; + InputDataResult* results; +}; + +void Evaluate(const EvaluateParams& params, int32 start, int32 end) { + const auto tree = params.tree_tensor.tensor(); + const auto thresholds = 
      params.tree_thresholds.unaligned_flat<float>();
+  const auto node_map = params.node_to_accumulator.unaligned_flat<int32>();
+  const auto split_features =
+      params.candidate_split_features.tensor<int32, 2>();
+  const auto split_thresholds =
+      params.candidate_split_thresholds.tensor<float, 2>();
+  const auto spec = params.input_spec.unaligned_flat<int32>();
   const int32 num_splits = static_cast<int32>(
-      candidate_split_features.shape().dim_size(1));
-  const int32 num_nodes = static_cast<int32>(tree_tensor.shape().dim_size(0));
+      params.candidate_split_features.shape().dim_size(1));
+  const int32 num_nodes = static_cast<int32>(
+      params.tree_tensor.shape().dim_size(0));
   const int32 num_accumulators = static_cast<int32>(
-      candidate_split_features.shape().dim_size(0));
+      params.candidate_split_features.shape().dim_size(0));
 
   for (int32 i = start; i < end; ++i) {
-    const Tensor point = input_data.Slice(i, i + 1);
     int node_index = 0;
-    results[i].splits_initialized = false;
+    params.results[i].splits_initialized = false;
     while (true) {
-      results[i].node_indices.push_back(node_index);
+      params.results[i].node_indices.push_back(node_index);
       CHECK_LT(node_index, num_nodes);
       int32 left_child = internal::SubtleMustCopy(
           tree(node_index, CHILDREN_INDEX));
       if (left_child == LEAF_NODE) {
         const int32 accumulator = internal::SubtleMustCopy(
             node_map(node_index));
-        results[i].leaf_accumulator = accumulator;
+        params.results[i].leaf_accumulator = accumulator;
         // If the leaf is not fertile or is not yet initialized, we don't
         // count it in the candidate/total split per-class-weights because
         // it won't have any candidate splits yet.
         if (accumulator >= 0 &&
-            IsAllInitialized(candidate_split_features.Slice(
+            IsAllInitialized(params.candidate_split_features.Slice(
                 accumulator, accumulator + 1))) {
           CHECK_LT(accumulator, num_accumulators);
-          results[i].splits_initialized = true;
+          params.results[i].splits_initialized = true;
           for (int split = 0; split < num_splits; split++) {
-            if (!DecideNode(point, split_features(accumulator, split),
-                            split_thresholds(accumulator, split))) {
-              results[i].split_adds.push_back(split);
+            const int32 feature = split_features(accumulator, split);
+            if (!params.decide_function(
+                    i, feature, split_thresholds(accumulator, split),
+                    static_cast<DataColumnTypes>(spec(feature)))) {
+              params.results[i].split_adds.push_back(split);
             }
           }
         }
         break;
       } else if (left_child == FREE_NODE) {
         LOG(ERROR) << "Reached a free node, not good.";
-        results[i].node_indices.push_back(FREE_NODE);
+        params.results[i].node_indices.push_back(FREE_NODE);
         break;
       }
+      const int32 feature = tree(node_index, FEATURE_INDEX);
       node_index =
-          left_child + DecideNode(point, tree(node_index, FEATURE_INDEX),
-                                  thresholds(node_index));
+          left_child + params.decide_function(
+              i, feature, thresholds(node_index),
+              static_cast<DataColumnTypes>(spec(feature)));
     }
   }
 }
@@ -124,16 +140,18 @@ REGISTER_OP("CountExtremelyRandomStats")
     .Attr("num_classes: int")
     .Attr("regression: bool = false")
     .Input("input_data: float")
+    .Input("sparse_input_indices: int64")
+    .Input("sparse_input_values: float")
+    .Input("sparse_input_shape: int64")
+    .Input("input_spec: int32")
     .Input("input_labels: float")
-
     .Input("tree: int32")
     .Input("tree_thresholds: float")
-
     .Input("node_to_accumulator: int32")
-
     .Input("candidate_split_features: int32")
     .Input("candidate_split_thresholds: float")
-
+    .Input("birth_epochs: int32")
+    .Input("current_epoch: int32")
     .Output("pcw_node_sums_delta: float")
     .Output("pcw_node_squares_delta: float")
     .Output("pcw_splits_indices: int32")
@@ -142,7 +160,6 @@ REGISTER_OP("CountExtremelyRandomStats")
.Output("pcw_totals_indices: int32") .Output("pcw_totals_sums_delta: float") .Output("pcw_totals_squares_delta: float") - .Output("leaves: int32") .Doc(R"doc( Calculates incremental statistics for a batch of training data. @@ -156,7 +173,7 @@ For `regression` = false (classification), `pcw_node_sums_delta[i]` is incremented for every node i that it passes through, and the leaf it ends up in is recorded in `leaves[i]`. Then, if the leaf is fertile and initialized, the statistics for its corresponding accumulator slot -are updated in `pcw_candidate_splits_delta` and `pcw_total_splits_delta`. +are updated in `pcw_candidate_sums_delta` and `pcw_totals_sums_delta`. For `regression` = true, outputs contain the sum of the input_labels for the appropriate nodes. In adddition, the *_squares outputs are filled @@ -171,6 +188,11 @@ The attr `num_classes` is needed to appropriately size the outputs. input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` gives the j-th feature of the i-th input. +sparse_input_indices: The indices tensor from the SparseTensor input. +sparse_input_values: The values tensor from the SparseTensor input. +sparse_input_shape: The shape tensor from the SparseTensor input. +input_spec: A 1-D tensor containing the type of each column in input_data, + (e.g. continuous float, categorical). input_labels: The training batch's labels; `input_labels[i]` is the class of the i-th input. tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child @@ -185,6 +207,10 @@ candidate_split_features: `candidate_split_features[a][s]` is the index of the feature being considered by split s of accumulator slot a. candidate_split_thresholds: `candidate_split_thresholds[a][s]` is the threshold value being considered by split s of accumulator slot a. +birth_epochs: `birth_epoch[i]` is the epoch node i was born in. Only + nodes satisfying `current_epoch - birth_epoch <= 1` accumulate statistics. +current_epoch:= A 1-d int32 tensor with shape (1). current_epoch[0] contains + the current epoch. pcw_node_sums_delta: `pcw_node_sums_delta[i][c]` is the number of training examples in this training batch with class c that passed through node i for classification. For regression, it is the sum of the input_labels that @@ -236,17 +262,57 @@ class CountExtremelyRandomStats : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_data = context->input(0); - const Tensor& input_labels = context->input(1); - const Tensor& tree_tensor = context->input(2); - const Tensor& tree_thresholds = context->input(3); - const Tensor& node_to_accumulator = context->input(4); - const Tensor& candidate_split_features = context->input(5); - const Tensor& candidate_split_thresholds = context->input(6); + const Tensor& sparse_input_indices = context->input(1); + const Tensor& sparse_input_values = context->input(2); + const Tensor& sparse_input_shape = context->input(3); + const Tensor& input_spec = context->input(4); + const Tensor& input_labels = context->input(5); + const Tensor& tree_tensor = context->input(6); + const Tensor& tree_thresholds = context->input(7); + const Tensor& node_to_accumulator = context->input(8); + const Tensor& candidate_split_features = context->input(9); + const Tensor& candidate_split_thresholds = context->input(10); + const Tensor& birth_epochs = context->input(11); + const Tensor& current_epoch = context->input(12); + + bool sparse_input = (sparse_input_indices.shape().dims() == 2); // Check inputs. 
- OP_REQUIRES(context, input_data.shape().dims() == 2, + if (sparse_input) { + OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1, + errors::InvalidArgument( + "sparse_input_shape should be one-dimensional")); + OP_REQUIRES(context, + sparse_input_shape.shape().dim_size(0) == 2, + errors::InvalidArgument( + "The sparse input data should be two-dimensional")); + OP_REQUIRES(context, sparse_input_values.shape().dims() == 1, + errors::InvalidArgument( + "sparse_input_values should be one-dimensional")); + OP_REQUIRES(context, sparse_input_indices.shape().dims() == 2, + errors::InvalidArgument( + "The sparse input data should be two-dimensional")); + OP_REQUIRES(context, + sparse_input_indices.shape().dim_size(0) == + sparse_input_values.shape().dim_size(0), + errors::InvalidArgument( + "sparse_input_indices and sparse_input_values should " + "agree on the number of non-zero values")); + } else { + OP_REQUIRES(context, input_data.shape().dims() == 2, + errors::InvalidArgument( + "input_data should be two-dimensional")); + OP_REQUIRES( + context, + input_data.shape().dim_size(0) == input_labels.shape().dim_size(0), + errors::InvalidArgument( + "Number of inputs should be the same in " + "input_data and input_labels.")); + } + + OP_REQUIRES(context, input_labels.shape().dims() >= 1, errors::InvalidArgument( - "input_data should be two-dimensional")); + "input_labels should be at least one-dimensional")); OP_REQUIRES(context, tree_tensor.shape().dims() == 2, errors::InvalidArgument( "tree should be two-dimensional")); @@ -262,58 +328,93 @@ class CountExtremelyRandomStats : public OpKernel { OP_REQUIRES(context, candidate_split_thresholds.shape().dims() == 2, errors::InvalidArgument( "candidate_split_thresholds should be two-dimensional")); - - OP_REQUIRES( - context, - input_data.shape().dim_size(0) == input_labels.shape().dim_size(0), - errors::InvalidArgument( - "Number of inputs should be the same in " - "input_data and input_labels.")); + OP_REQUIRES(context, birth_epochs.shape().dims() == 1, + errors::InvalidArgument( + "birth_epochs should be one-dimensional")); + OP_REQUIRES(context, current_epoch.shape().dims() == 1, + errors::InvalidArgument( + "current_epoch should be one-dimensional")); OP_REQUIRES( context, tree_tensor.shape().dim_size(0) == tree_thresholds.shape().dim_size(0) && tree_tensor.shape().dim_size(0) == - node_to_accumulator.shape().dim_size(0), + node_to_accumulator.shape().dim_size(0) && + tree_tensor.shape().dim_size(0) == + birth_epochs.shape().dim_size(0), errors::InvalidArgument( "Number of nodes should be the same in " - "tree, tree_thresholds, and node_to_accumulator")); + "tree, tree_thresholds, node_to_accumulator, and birth_epoch.")); OP_REQUIRES( context, candidate_split_features.shape() == candidate_split_thresholds.shape(), errors::InvalidArgument( "candidate_split_features and candidate_split_thresholds should be " "the same shape.")); + OP_REQUIRES( + context, + current_epoch.shape().dim_size(0) == 1, + errors::InvalidArgument( + "The current_epoch should be a tensor of shape (1).")); // Check tensor bounds. 
if (!CheckTensorBounds(context, input_data)) return; + if (!CheckTensorBounds(context, sparse_input_indices)) return; + if (!CheckTensorBounds(context, sparse_input_values)) return; + if (!CheckTensorBounds(context, sparse_input_shape)) return; if (!CheckTensorBounds(context, input_labels)) return; if (!CheckTensorBounds(context, tree_tensor)) return; if (!CheckTensorBounds(context, tree_thresholds)) return; if (!CheckTensorBounds(context, node_to_accumulator)) return; if (!CheckTensorBounds(context, candidate_split_features)) return; if (!CheckTensorBounds(context, candidate_split_thresholds)) return; + if (!CheckTensorBounds(context, birth_epochs)) return; + if (!CheckTensorBounds(context, current_epoch)) return; // Evaluate input data in parallel. - const int32 num_data = static_cast(input_data.shape().dim_size(0)); + const int32 epoch = current_epoch.unaligned_flat()(0); + int32 num_data; + std::function decide_function; + if (sparse_input) { + num_data = sparse_input_shape.unaligned_flat()(0); + decide_function = [&sparse_input_indices, &sparse_input_values]( + int32 i, int32 feature, float bias, DataColumnTypes type) { + const auto sparse_indices = sparse_input_indices.matrix(); + const auto sparse_values = sparse_input_values.vec(); + return tensorforest::DecideSparseNode( + sparse_indices, sparse_values, i, feature, bias, type); + }; + } else { + num_data = static_cast(input_data.shape().dim_size(0)); + decide_function = [&input_data]( + int32 i, int32 feature, float bias, DataColumnTypes type) { + const auto input_matrix = input_data.matrix(); + return tensorforest::DecideDenseNode( + input_matrix, i, feature, bias, type); + }; + } std::unique_ptr results(new InputDataResult[num_data]); auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; + EvaluateParams params; + params.decide_function = decide_function; + params.input_spec = input_spec; + params.input_labels = input_labels; + params.tree_tensor = tree_tensor; + params.tree_thresholds = tree_thresholds; + params.node_to_accumulator = node_to_accumulator; + params.candidate_split_features = candidate_split_features; + params.candidate_split_thresholds = candidate_split_thresholds; + params.results = results.get(); if (num_threads <= 1) { - Evaluate(input_data, input_labels, tree_tensor, tree_thresholds, - node_to_accumulator, candidate_split_features, - candidate_split_thresholds, results.get(), 0, num_data); + Evaluate(params, 0, num_data); } else { - auto work = [&input_data, &input_labels, &tree_tensor, &tree_thresholds, - &node_to_accumulator, &candidate_split_features, - &candidate_split_thresholds, &num_data, - &results](int64 start, int64 end) { + auto work = [¶ms, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - Evaluate(input_data, input_labels, tree_tensor, tree_thresholds, - node_to_accumulator, candidate_split_features, - candidate_split_thresholds, results.get(), + Evaluate(params, static_cast(start), static_cast(end)); }; Shard(num_threads, worker_threads->workers, num_data, 100, work); @@ -321,11 +422,13 @@ class CountExtremelyRandomStats : public OpKernel { const int32 num_nodes = static_cast(tree_tensor.shape().dim_size(0)); if (regression_) { - ProcessResultsRegression(context, input_labels, std::move(results), - num_nodes); + ProcessResultsRegression( + context, input_labels, birth_epochs, epoch, std::move(results), + num_nodes); } else { - ProcessResultsClassification(context, input_labels, std::move(results), 
- num_nodes); + ProcessResultsClassification( + context, input_labels, birth_epochs, epoch, std::move(results), + num_nodes); } } @@ -333,10 +436,13 @@ class CountExtremelyRandomStats : public OpKernel { void ProcessResultsClassification( OpKernelContext* context, const Tensor &input_labels, + const Tensor &birth_epochs, + int32 epoch, std::unique_ptr results, int32 num_nodes) { const int32 num_data = static_cast(input_labels.shape().dim_size(0)); const auto labels = input_labels.unaligned_flat(); + const auto start_epochs = birth_epochs.unaligned_flat(); // Unused outputs for classification. Still have to specify them or // tensorflow complains. @@ -381,10 +487,16 @@ class CountExtremelyRandomStats : public OpKernel { CHECK_LT(column, num_classes_); const int32 accumulator = results[i].leaf_accumulator; for (const int32 node : results[i].node_indices) { + if (epoch > start_epochs(node) + 1) { + continue; + } ++out_node_sums(node, column); ++out_node_sums(node, 0); } out_leaves(i) = results[i].node_indices.back(); + if (epoch > start_epochs(out_leaves(i)) + 1) { + continue; + } if (accumulator >= 0 && results[i].splits_initialized) { ++total_delta[make_pair(accumulator, column)]; ++total_delta[make_pair(accumulator, 0)]; @@ -457,6 +569,8 @@ class CountExtremelyRandomStats : public OpKernel { void ProcessResultsRegression( OpKernelContext* context, const Tensor &input_labels, + const Tensor &birth_epochs, + const int32 epoch, std::unique_ptr results, int32 num_nodes) { const int32 num_data = static_cast(input_labels.shape().dim_size(0)); @@ -465,6 +579,7 @@ class CountExtremelyRandomStats : public OpKernel { num_outputs = static_cast(input_labels.shape().dim_size(1)); } const auto labels = input_labels.unaligned_flat(); + const auto start_epochs = birth_epochs.unaligned_flat(); // node pcw delta Tensor* output_node_pcw_sums_delta = nullptr; @@ -503,6 +618,9 @@ class CountExtremelyRandomStats : public OpKernel { for (int32 i = 0; i < num_data; ++i) { const int32 accumulator = results[i].leaf_accumulator; for (const int32 node : results[i].node_indices) { + if (epoch > start_epochs(node) + 1) { + continue; + } for (int32 j = 0; j < num_outputs; ++j) { const float output = labels(i * num_outputs + j); out_node_sums(node, j + 1) += output; @@ -512,6 +630,9 @@ class CountExtremelyRandomStats : public OpKernel { } } out_leaves(i) = results[i].node_indices.back(); + if (epoch > start_epochs(out_leaves(i)) + 1) { + continue; + } if (accumulator >= 0 && results[i].splits_initialized) { total_delta[accumulator].insert(i); for (const int32 split : results[i].split_adds) { diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc index e1369f9d8c..d179f5b84e 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc @@ -24,42 +24,84 @@ namespace tensorflow { using tensorforest::CheckTensorBounds; using tensorforest::Sum; +using tensorforest::BestSplitDominatesClassification; +using tensorforest::BestSplitDominatesRegression; REGISTER_OP("FinishedNodes") + .Attr("regression: bool = false") .Attr("num_split_after_samples: int") + .Attr("min_split_samples: int") + .Attr("dominate_fraction: float = 0.95") .Input("leaves: int32") .Input("node_to_accumulator: int32") + .Input("split_sums: float") + .Input("split_squares: float") .Input("accumulator_sums: float") - + .Input("accumulator_squares: float") + .Input("birth_epochs: int32") + 
.Input("current_epoch: int32") .Output("finished: int32") + .Output("stale: int32") .Doc(R"doc( Determines which of the given leaf nodes are done accumulating. leaves:= A 1-d int32 tensor. Lists the nodes that are currently leaves. node_to_accumulator: If the i-th node is fertile, `node_to_accumulator[i]` is it's accumulator slot. Otherwise, `node_to_accumulator[i]` is -1. -accumulator_sums: For classification, `accumulator_sums[a][c]` records how many - training examples have class c and have ended up in the fertile node +split_sums:= a 3-d tensor where `split_sums[a][s]` summarizes the + training labels for examples that fall into the fertile node associated with + accumulator slot s and have then taken the *left* branch of candidate split + s. For a classification problem, `split_sums[a][s][c]` is the count of such + examples with class c and for regression problems, `split_sums[a][s]` is the + sum of the regression labels for such examples. +split_squares: Same as split_sums, but it contains the sum of the + squares of the regression labels. Only used for regression. For + classification problems, pass a dummy tensor into this. +accumulator_sums: For classification, `accumulator_sums[a][c]` records how + many training examples have class c and have ended up in the fertile node associated with accumulator slot a. It has the total sum in entry 0 for convenience. For regression, it is the same except it contains the sum of the input labels that have been seen, and entry 0 contains the number of training examples that have been seen. -finished:= A 1-d int32 tensor. Contains the nodes that have total split - counts greater or equal to the num_split_after_samples attribute. +accumulator_squares: Same as accumulator_sums, but it contains the sum of the + squares of the regression labels. Only used for regression. For + classification problems, pass a dummy tensor into this. +birth_epochs:= A 1-d int32 tensor. `birth_epochs[i]` contains the epoch + the i-th node was created in. +current_epoch:= A 1-d int32 tensor with shape (1). `current_epoch[0]` + stores the current epoch number. +finished:= A 1-d int32 tensor containing the indices of the finished nodes. + Nodes are finished if they have received at least num_split_after_samples + samples, or if they have received min_split_samples and the best scoring + split is sufficiently greater than the next best split. +stale:= A 1-d int32 tensor containing the fertile nodes that were created two + or more epochs ago. 
+ )doc"); class FinishedNodes : public OpKernel { public: explicit FinishedNodes(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr( + "regression", ®ression_)); OP_REQUIRES_OK(context, context->GetAttr( "num_split_after_samples", &num_split_after_samples_)); + OP_REQUIRES_OK(context, context->GetAttr( + "min_split_samples", &min_split_samples_)); + OP_REQUIRES_OK(context, context->GetAttr( + "dominate_fraction", &dominate_fraction_)); } void Compute(OpKernelContext* context) override { const Tensor& leaf_tensor = context->input(0); const Tensor& node_to_accumulator = context->input(1); - const Tensor& accumulator_sums = context->input(2); + const Tensor& split_sums = context->input(2); + const Tensor& split_squares = context->input(3); + const Tensor& accumulator_sums = context->input(4); + const Tensor& accumulator_squares = context->input(5); + const Tensor& birth_epochs = context->input(6); + const Tensor& current_epoch = context->input(7); OP_REQUIRES(context, leaf_tensor.shape().dims() == 1, errors::InvalidArgument( @@ -67,25 +109,45 @@ class FinishedNodes : public OpKernel { OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1, errors::InvalidArgument( "node_to_accumulator should be one-dimensional")); + OP_REQUIRES(context, split_sums.shape().dims() == 3, + errors::InvalidArgument( + "split_sums should be three-dimensional")); OP_REQUIRES(context, accumulator_sums.shape().dims() == 2, errors::InvalidArgument( "accumulator_sums should be two-dimensional")); + OP_REQUIRES(context, birth_epochs.shape().dims() == 1, + errors::InvalidArgument( + "birth_epochs should be one-dimensional")); + OP_REQUIRES( + context, + birth_epochs.shape().dim_size(0) == + node_to_accumulator.shape().dim_size(0), + errors::InvalidArgument( + "birth_epochs and node_to_accumulator should be the same size.")); // Check tensor bounds. if (!CheckTensorBounds(context, leaf_tensor)) return; if (!CheckTensorBounds(context, node_to_accumulator)) return; + if (!CheckTensorBounds(context, split_sums)) return; + if (!CheckTensorBounds(context, split_squares)) return; if (!CheckTensorBounds(context, accumulator_sums)) return; + if (!CheckTensorBounds(context, accumulator_squares)) return; + if (!CheckTensorBounds(context, birth_epochs)) return; + if (!CheckTensorBounds(context, current_epoch)) return; const auto leaves = leaf_tensor.unaligned_flat(); const auto node_map = node_to_accumulator.unaligned_flat(); const auto sums = accumulator_sums.tensor(); + const auto start_epochs = birth_epochs.unaligned_flat(); + const int32 epoch = current_epoch.unaligned_flat()(0); const int32 num_leaves = static_cast( leaf_tensor.shape().dim_size(0)); const int32 num_accumulators = static_cast( accumulator_sums.shape().dim_size(0)); - std::vector finished; + std::vector finished_leaves; + std::vector stale; for (int32 i = 0; i < num_leaves; i++) { const int32 leaf = internal::SubtleMustCopy(leaves(i)); OP_REQUIRES(context, FastBoundsCheck(leaf, node_map.size()), @@ -97,30 +159,74 @@ class FinishedNodes : public OpKernel { OP_REQUIRES(context, FastBoundsCheck(accumulator, num_accumulators), errors::InvalidArgument("accumulator not in valid range.")) - // The first column holds the number of samples seen. // For classification, this should be the sum of the other columns. 
- if (sums(accumulator, 0) >= num_split_after_samples_) { - finished.push_back(leaf); + int32 count = sums(accumulator, 0); + + if (epoch > start_epochs(leaf) + 1) { + if (count >= min_split_samples_) { + finished_leaves.push_back(leaf); + } else { + stale.push_back(leaf); + } + continue; + } + + if (count >= num_split_after_samples_) { + finished_leaves.push_back(leaf); + continue; + } + + if (count < min_split_samples_) { + continue; + } + + bool finished = false; + if (regression_) { + finished = BestSplitDominatesRegression( + accumulator_sums, accumulator_squares, + split_sums, split_squares, accumulator); + } else { + finished = BestSplitDominatesClassification( + accumulator_sums, split_sums, accumulator, dominate_fraction_); + } + + if (finished) { + finished_leaves.push_back(leaf); } } // Copy to output. Tensor* output_finished = nullptr; TensorShape finished_shape; - finished_shape.AddDim(finished.size()); + finished_shape.AddDim(finished_leaves.size()); OP_REQUIRES_OK(context, context->allocate_output(0, finished_shape, &output_finished)); auto out_finished = output_finished->unaligned_flat(); - for (int32 i = 0; i < finished.size(); i++) { - out_finished(i) = finished[i]; + for (int32 i = 0; i < finished_leaves.size(); i++) { + out_finished(i) = finished_leaves[i]; + } + + Tensor* output_stale = nullptr; + TensorShape stale_shape; + stale_shape.AddDim(stale.size()); + OP_REQUIRES_OK(context, + context->allocate_output(1, stale_shape, + &output_stale)); + auto out_stale = output_stale->unaligned_flat(); + + for (int32 i = 0; i < stale.size(); i++) { + out_stale(i) = stale[i]; } } private: + bool regression_; int32 num_split_after_samples_; + int32 min_split_samples_; + float dominate_fraction_; }; REGISTER_KERNEL_BUILDER(Name("FinishedNodes").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc index 182b1257b6..8b15f8a0b5 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/sample_inputs_op.cc @@ -35,6 +35,9 @@ REGISTER_OP("SampleInputs") .Attr("split_initializations_per_input: int") .Attr("split_sampling_random_seed: int") .Input("input_data: float") + .Input("sparse_input_indices: int64") + .Input("sparse_input_values: float") + .Input("sparse_input_shape: int64") .Input("node_to_accumulator: int32") .Input("leaves: int32") .Input("candidate_split_features: int32") @@ -60,6 +63,9 @@ a single training example can initialize, and the attribute input_data: The features for the current batch of training data. `input_data[i][j]` is the j-th feature of the i-th input. +sparse_input_indices: The indices tensor from the SparseTensor input. +sparse_input_values: The values tensor from the SparseTensor input. +sparse_input_shape: The shape tensor from the SparseTensor input. node_to_accumulator: For a fertile node i, node_to_accumulator[i] is the associated accumulator slot. For non-fertile nodes, it is -1. 
leaves: `leaves[i]` is the leaf that the i-th input landed in, as @@ -82,6 +88,7 @@ new_split_threshold_rows: The new values for the candidate_split_thresholds `tf.scatter_update(candidate_split_thresholds, accumulators_to_update, new_split_feature_thresholds)` + )doc"); class SampleInputs : public OpKernel { @@ -106,16 +113,74 @@ class SampleInputs : public OpKernel { new random::SimplePhilox(single_rand_.get())); } + template + void GetRandomFeatureDense(const T& inputs, int32 num_features, + int32 input_index, int32* index, float* val) { + *index = rng_->Uniform(num_features); + *val = inputs(input_index, *index); + } + + template + void GetRandomFeatureSparse(const T1& sparse_indices, const T2& sparse_values, + int32 input_index, int32* index, float* val) { + int32 low = 0; + int32 high = sparse_values.dimension(0); + while (low < high) { + int32 vi = low + rng_->Uniform(high - low); + int64 i = internal::SubtleMustCopy(sparse_indices(vi, 0)); + if (i == input_index) { + *index = internal::SubtleMustCopy(sparse_indices(vi, 1)); + *val = sparse_values(vi); + return; + } + if (i < input_index) { + low = vi + 1; + } else { + high = vi; + } + } + LOG(FATAL) << "Could not find any values for input " << input_index + << " inside sparse_input_indices"; + } + void Compute(OpKernelContext* context) override { const Tensor& input_data = context->input(0); - const Tensor& node_to_accumulator = context->input(1); - const Tensor& leaves = context->input(2); - const Tensor& split_features = context->input(3); - const Tensor& split_thresholds = context->input(4); + const Tensor& sparse_input_indices = context->input(1); + const Tensor& sparse_input_values = context->input(2); + const Tensor& sparse_input_shape = context->input(3); + const Tensor& node_to_accumulator = context->input(4); + const Tensor& leaves = context->input(5); + const Tensor& split_features = context->input(6); + const Tensor& split_thresholds = context->input(7); + + bool sparse_input = (sparse_input_indices.shape().dims() == 2); + + if (sparse_input) { + OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1, + errors::InvalidArgument( + "sparse_input_shape should be one-dimensional")); + OP_REQUIRES(context, + sparse_input_shape.shape().dim_size(0) == 2, + errors::InvalidArgument( + "The sparse input data should be two-dimensional")); + OP_REQUIRES(context, sparse_input_values.shape().dims() == 1, + errors::InvalidArgument( + "sparse_input_values should be one-dimensional")); + OP_REQUIRES(context, sparse_input_indices.shape().dims() == 2, + errors::InvalidArgument( + "The sparse input data should be two-dimensional")); + OP_REQUIRES(context, + sparse_input_indices.shape().dim_size(0) == + sparse_input_values.shape().dim_size(0), + errors::InvalidArgument( + "sparse_input_indices and sparse_input_values should " + "agree on the number of non-zero values")); + } else { + OP_REQUIRES(context, input_data.shape().dims() == 2, + errors::InvalidArgument( + "input_data should be two-dimensional")); + } - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1, errors::InvalidArgument( "node_to_accumulator should be one-dimensional")); @@ -137,12 +202,36 @@ class SampleInputs : public OpKernel { // Check tensor bounds. 
if (!CheckTensorBounds(context, input_data)) return; + if (!CheckTensorBounds(context, sparse_input_indices)) return; + if (!CheckTensorBounds(context, sparse_input_values)) return; + if (!CheckTensorBounds(context, sparse_input_shape)) return; if (!CheckTensorBounds(context, node_to_accumulator)) return; if (!CheckTensorBounds(context, leaves)) return; if (!CheckTensorBounds(context, split_features)) return; if (!CheckTensorBounds(context, split_thresholds)) return; - const auto inputs = input_data.tensor(); + int32 num_features; + std::function get_random_feature; + // TODO(thomaswc): Figure out a way to avoid calling .vec, etc. over and + // over again + if (sparse_input) { + num_features = sparse_input_shape.unaligned_flat()(1); + get_random_feature = [&sparse_input_indices, &sparse_input_values, this]( + int32 input_index, int32* index, float* val) { + const auto sparse_indices = sparse_input_indices.matrix(); + const auto sparse_values = sparse_input_values.vec(); + GetRandomFeatureSparse(sparse_indices, sparse_values, input_index, + index, val); + }; + } else { + num_features = static_cast(input_data.shape().dim_size(1)); + get_random_feature = [&input_data, num_features, this]( + int32 input_index, int32* index, float* val) { + const auto inputs = input_data.tensor(); + GetRandomFeatureDense(inputs, num_features, input_index, index, val); + }; + } + const auto leaves_vec = leaves.unaligned_flat(); const auto node_map = node_to_accumulator.unaligned_flat(); const auto features = split_features.tensor(); @@ -151,8 +240,6 @@ class SampleInputs : public OpKernel { const int32 num_data = static_cast(leaves.shape().dim_size(0)); const int32 num_splits = static_cast( split_features.shape().dim_size(1)); - const int32 num_features = static_cast( - input_data.shape().dim_size(1)); const int32 num_accumulators = static_cast( split_features.shape().dim_size(0)); @@ -234,10 +321,11 @@ class SampleInputs : public OpKernel { for (int split = 0; split < num_splits && num_inits > 0; split++) { if (new_split_feature_rows_flat(output_slot, split) < 0) { VLOG(1) << "Over-writing @ " << output_slot << "," << split; - const int32 index = rng_->Uniform(num_features); + int32 index; + float val; + get_random_feature(i, &index, &val); new_split_feature_rows_flat(output_slot, split) = index; - new_split_threshold_rows_flat(output_slot, split) = - inputs(i, index); + new_split_threshold_rows_flat(output_slot, split) = val; --num_inits; } } diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc index 7db52ec3ca..1f77212d20 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_predictions_op.cc @@ -30,12 +30,17 @@ using tensorforest::LEAF_NODE; using tensorforest::FREE_NODE; using tensorforest::CheckTensorBounds; -using tensorforest::DecideNode; +using tensorforest::DataColumnTypes; using tensorforest::Sum; REGISTER_OP("TreePredictions") .Attr("valid_leaf_threshold: float") .Input("input_data: float") + .Input("sparse_input_indices: int64") + .Input("sparse_input_values: float") + .Input("sparse_input_shape: int64") + .Input("input_spec: int32") + .Input("tree: int32") .Input("tree_thresholds: float") .Input("node_per_class_weights: float") @@ -46,6 +51,11 @@ REGISTER_OP("TreePredictions") input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` gives the j-th feature of the i-th input. 
+ sparse_input_indices: The indices tensor from the SparseTensor input. + sparse_input_values: The values tensor from the SparseTensor input. + sparse_input_shape: The shape tensor from the SparseTensor input. + input_spec: A 1-D tensor containing the type of each column in input_data, + (e.g. continuous float, categorical). tree:= A 2-d int32 tensor. `tree[i][0]` gives the index of the left child of the i-th node, `tree[i][0] + 1` gives the index of the right child of the i-th node, and `tree[i][1]` gives the index of the feature used to @@ -70,10 +80,42 @@ class TreePredictions : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input_data = context->input(0); - - const Tensor& tree_tensor = context->input(1); - const Tensor& tree_thresholds = context->input(2); - const Tensor& node_per_class_weights = context->input(3); + const Tensor& sparse_input_indices = context->input(1); + const Tensor& sparse_input_values = context->input(2); + const Tensor& sparse_input_shape = context->input(3); + const Tensor& input_spec = context->input(4); + const Tensor& tree_tensor = context->input(5); + const Tensor& tree_thresholds = context->input(6); + const Tensor& node_per_class_weights = context->input(7); + + bool sparse_input = (sparse_input_indices.shape().dims() == 2); + + if (sparse_input) { + OP_REQUIRES(context, sparse_input_values.shape().dims() == 1, + errors::InvalidArgument( + "sparse_input_values should be one-dimensional")); + OP_REQUIRES(context, sparse_input_shape.shape().dims() == 1, + errors::InvalidArgument( + "sparse_input_shape should be one-dimensional")); + OP_REQUIRES(context, + sparse_input_indices.shape().dim_size(0) == + sparse_input_values.shape().dim_size(0), + errors::InvalidArgument( + "sparse_input_indices and sparse_input_values should " + "agree on the number of non-zero values")); + OP_REQUIRES(context, + sparse_input_indices.shape().dim_size(1) == + sparse_input_shape.shape().dim_size(0), + errors::InvalidArgument( + "sparse_input_indices and sparse_input_shape should " + "agree on the dimensionality of data points")); + } else { + if (input_data.shape().dim_size(0) > 0) { + OP_REQUIRES(context, input_data.shape().dims() == 2, + errors::InvalidArgument( + "input_data should be two-dimensional")); + } + } OP_REQUIRES(context, tree_tensor.shape().dims() == 2, errors::InvalidArgument( @@ -85,11 +127,6 @@ class TreePredictions : public OpKernel { errors::InvalidArgument( "node_pcw should be two-dimensional")); - if (input_data.shape().dim_size(0) > 0) { - OP_REQUIRES(context, input_data.shape().dims() == 2, - errors::InvalidArgument( - "input_data should be two-dimensional")); - } OP_REQUIRES( context, tree_tensor.shape().dim_size(0) == @@ -102,16 +139,43 @@ class TreePredictions : public OpKernel { // Check tensor bounds. 
if (!CheckTensorBounds(context, input_data)) return; + if (!CheckTensorBounds(context, sparse_input_indices)) return; + if (!CheckTensorBounds(context, sparse_input_values)) return; + if (!CheckTensorBounds(context, sparse_input_shape)) return; if (!CheckTensorBounds(context, tree_tensor)) return; if (!CheckTensorBounds(context, tree_thresholds)) return; if (!CheckTensorBounds(context, node_per_class_weights)) return; const int32 num_classes = static_cast( node_per_class_weights.shape().dim_size(1)); - const int32 num_data = static_cast( - input_data.shape().dim_size(0)); const int32 num_nodes = static_cast( tree_tensor.shape().dim_size(0)); + int32 num_data; + std::function decide_function; + + if (sparse_input) { + num_data = sparse_input_shape.unaligned_flat()(0); + decide_function = [&sparse_input_indices, &sparse_input_values]( + int32 i, int32 feature, float bias, DataColumnTypes type) { + const auto sparse_indices = sparse_input_indices.matrix(); + const auto sparse_values = sparse_input_values.vec(); + return tensorforest::DecideSparseNode( + sparse_indices, sparse_values, i, feature, bias, type); + }; + } else { + num_data = static_cast(input_data.shape().dim_size(0)); + int32 num_features = 0; + if (num_data > 0) { + num_features = input_data.NumElements() / num_data; + } + decide_function = [&input_data]( + int32 i, int32 feature, float bias, DataColumnTypes type) { + const auto input_matrix = input_data.matrix(); + return tensorforest::DecideDenseNode( + input_matrix, i, feature, bias, type); + }; + } Tensor* output_predictions = nullptr; TensorShape output_shape; @@ -124,10 +188,10 @@ class TreePredictions : public OpKernel { const auto node_pcw = node_per_class_weights.tensor(); const auto tree = tree_tensor.tensor(); + const auto spec = input_spec.unaligned_flat(); const auto thresholds = tree_thresholds.unaligned_flat(); for (int i = 0; i < num_data; i++) { - const Tensor point = input_data.Slice(i, i+1); int node_index = 0; int parent = -1; while (true) { @@ -162,9 +226,11 @@ class TreePredictions : public OpKernel { return; } parent = node_index; + const int32 feature = tree(node_index, FEATURE_INDEX); node_index = left_child + - DecideNode(point, tree(node_index, FEATURE_INDEX), - thresholds(node_index)); + decide_function( + i, feature, thresholds(node_index), + static_cast(spec(feature))); } } diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc index f3fc416055..398990780c 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.cc @@ -13,56 +13,127 @@ // limitations under the License. 
// ============================================================================= #include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h" +#include +#include "tensorflow/core/platform/logging.h" namespace tensorflow { namespace tensorforest { using tensorflow::Tensor; -int32 BestFeatureClassification( +void GetTwoBest(int max, std::function score_fn, + float *best_score, int *best_index, + float *second_best_score) { + *best_index = -1; + *best_score = FLT_MAX; + *second_best_score = FLT_MAX; + for (int i = 0; i < max; i++) { + float score = score_fn(i); + if (score < *best_score) { + *second_best_score = *best_score; + *best_score = score; + *best_index = i; + } else if (score < *second_best_score) { + *second_best_score = score; + } + } +} + +float ClassificationSplitScore( + const Eigen::Tensor& splits, + const Eigen::Tensor& rights, + int32 num_classes, int i) { + Eigen::array offsets; + offsets[0] = i * num_classes + 1; + Eigen::array extents; + extents[0] = num_classes - 1; + return WeightedGiniImpurity(splits.slice(offsets, extents)) + + WeightedGiniImpurity(rights.slice(offsets, extents)); +} + +void GetTwoBestClassification( const Tensor& total_counts, const Tensor& split_counts, - int32 accumulator) { - int32 best_feature_index = -1; - // We choose the split with the lowest score. - float best_score = kint64max; + int32 accumulator, + float *best_score, int *best_index, + float *second_best_score) { const int32 num_splits = static_cast(split_counts.shape().dim_size(1)); const int32 num_classes = static_cast( split_counts.shape().dim_size(2)); + // Ideally, Eigen::Tensor::chip would be best to use here but it results // in seg faults, so we have to go with flat views of these tensors. However, // it is still pretty efficient because we put off evaluation until the // score is actually returned. const auto tc = total_counts.Slice( accumulator, accumulator + 1).unaligned_flat(); - const auto splits = split_counts.Slice( + + // TODO(gilberth): See if we can delay evaluation here by templating the + // arguments to ClassificationSplitScore. 
+ const Eigen::Tensor splits = split_counts.Slice( accumulator, accumulator + 1).unaligned_flat(); Eigen::array bcast; bcast[0] = num_splits; - const auto rights = tc.broadcast(bcast) - splits; - - for (int i = 0; i < num_splits; i++) { - Eigen::array offsets; - offsets[0] = i * num_classes; - Eigen::array extents; - extents[0] = num_classes; - float score = WeightedGiniImpurity(splits.slice(offsets, extents)) + - WeightedGiniImpurity(rights.slice(offsets, extents)); - - if (score < best_score) { - best_score = score; - best_feature_index = i; - } - } + const Eigen::Tensor rights = + tc.broadcast(bcast) - splits; + + std::function score_fn = std::bind( + ClassificationSplitScore, splits, rights, num_classes, + std::placeholders::_1); + + GetTwoBest( + num_splits, score_fn, + best_score, best_index, second_best_score); +} + +int32 BestFeatureClassification( + const Tensor& total_counts, const Tensor& split_counts, + int32 accumulator) { + float best_score; + float second_best_score; + int best_feature_index; + GetTwoBestClassification( + total_counts, split_counts, accumulator, + &best_score, &best_feature_index, &second_best_score); return best_feature_index; } -int32 BestFeatureRegression( +float RegressionSplitScore( + const Eigen::Tensor& splits_count_accessor, + const Eigen::Tensor& totals_count_accessor, + const Eigen::Tensor& splits_sum, + const Eigen::Tensor& splits_square, + const Eigen::Tensor& right_sums, + const Eigen::Tensor& right_squares, + int32 accumulator, + int32 num_regression_dims, int i) { + Eigen::array offsets = {i * num_regression_dims + 1}; + Eigen::array extents = {num_regression_dims - 1}; + float left_count = splits_count_accessor(accumulator, i, 0); + float right_count = totals_count_accessor(accumulator, 0) - left_count; + + float score = 0; + + // Guard against divide-by-zero. + if (left_count > 0) { + score += WeightedVariance( + splits_sum.slice(offsets, extents), + splits_square.slice(offsets, extents), left_count); + } + + if (right_count > 0) { + score += WeightedVariance(right_sums.slice(offsets, extents), + right_squares.slice(offsets, extents), + right_count); + } + return score; +} + +void GetTwoBestRegression( const Tensor& total_sums, const Tensor& total_squares, const Tensor& split_sums, const Tensor& split_squares, - int32 accumulator) { - int32 best_feature_index = -1; - // We choose the split with the lowest score. 
- float best_score = kint64max; + int32 accumulator, + float *best_score, int *best_index, + float *second_best_score) { const int32 num_splits = static_cast(split_sums.shape().dim_size(1)); const int32 num_regression_dims = static_cast( split_sums.shape().dim_size(2)); @@ -90,43 +161,138 @@ int32 BestFeatureRegression( const auto right_sums = tc_sum.broadcast(bcast) - splits_sum; const auto right_squares = tc_square.broadcast(bcast) - splits_square; - for (int i = 0; i < num_splits; i++) { - Eigen::array offsets; - offsets[0] = i * num_regression_dims; - Eigen::array extents; - extents[0] = num_regression_dims; - float left_count = splits_count_accessor(accumulator, i, 0); - float right_count = totals_count_accessor(accumulator, 0) - left_count; + GetTwoBest( + num_splits, + std::bind(RegressionSplitScore, + splits_count_accessor, totals_count_accessor, + splits_sum, splits_square, right_sums, right_squares, + accumulator, num_regression_dims, std::placeholders::_1), + best_score, best_index, second_best_score); +} - float score = 0; +int32 BestFeatureRegression( + const Tensor& total_sums, const Tensor& total_squares, + const Tensor& split_sums, const Tensor& split_squares, + int32 accumulator) { + float best_score; + float second_best_score; + int best_feature_index; + GetTwoBestRegression( + total_sums, total_squares, split_sums, split_squares, accumulator, + &best_score, &best_feature_index, &second_best_score); + return best_feature_index; +} - // Guard against divide-by-zero. - if (left_count > 0) { - score += WeightedVariance( - splits_sum.slice(offsets, extents), - splits_square.slice(offsets, extents), left_count); - } - if (right_count > 0) { - score += WeightedVariance(right_sums.slice(offsets, extents), - right_squares.slice(offsets, extents), - right_count); +bool BestSplitDominatesRegression( + const Tensor& total_sums, const Tensor& total_squares, + const Tensor& split_sums, const Tensor& split_squares, + int32 accumulator) { + // TODO(thomaswc): Implement this, probably as part of v3. + return false; +} + +bool BestSplitDominatesClassification( + const Tensor& total_counts, + const Tensor& split_counts, int32 accumulator, + float dominate_fraction) { + float best_score; + float second_best_score; + int best_feature_index; + GetTwoBestClassification( + total_counts, split_counts, accumulator, + &best_score, &best_feature_index, &second_best_score); + + // Total counts are stored in the first column. + const int32 num_classes = split_counts.shape().dim_size(2) - 1; + + // total_class_counts(c) is the # of class c examples seen by this + // accumulator. + auto total_class_counts = total_counts.Slice( + accumulator, accumulator + 1).unaligned_flat(); + + const Eigen::Tensor splits = split_counts.Slice( + accumulator, accumulator + 1).unaligned_flat(); + + // For some reason, Eigen is fine with offsets being an array in + // ClassificationSplitScore, but it demands an array here. + const Eigen::array offsets = + {num_classes * best_feature_index}; + const Eigen::array extents = {num_classes}; + + const Eigen::Tensor left_counts = + splits.slice(offsets, extents); + // I can find no other way using Eigen to copy a const Tensor into a + // non-const Tensor. 
+  Eigen::Tensor<float, 1> left_counts_copy(num_classes + 1);
+  for (int i = 0; i <= num_classes; i++) {
+    left_counts_copy(i) = left_counts(i);
+  }
+
+  Eigen::Tensor<float, 1> right_counts_copy =
+      total_class_counts - left_counts_copy;
+
+  // "Reverse-jackknife" estimate of how often the chosen best split is
+  // truly better than the second best split. We use the reverse jackknife
+  // (in which counts are incremented) rather than the normal jackknife
+  // (in which counts are decremented) because the latter badly underestimates
+  // the score variance of perfect splits.
+  float better_count = 0.0;
+  float worse_count = 0.0;
+  for (int i = 1; i <= num_classes; i++) {
+    left_counts_copy(i) += 1.0;
+    float weight = left_counts_copy(i);
+    float v = WeightedGiniImpurity(left_counts_copy) +
+        WeightedGiniImpurity(right_counts_copy);
+    left_counts_copy(i) -= 1.0;
+    if (v < second_best_score) {
+      better_count += weight;
+    } else {
+      worse_count += weight;
+    }
+
+    right_counts_copy(i) += 1.0;
+    weight = right_counts_copy(i);
+    v = WeightedGiniImpurity(left_counts) +
+        WeightedGiniImpurity(right_counts_copy);
+    right_counts_copy(i) -= 1.0;
+    if (v < second_best_score) {
+      better_count += weight;
+    } else {
+      worse_count += weight;
+    }
+  }
-  return best_feature_index;
+
+  VLOG(1) << "Better count = " << better_count;
+  VLOG(1) << "Worse count = " << worse_count;
+  return better_count > dominate_fraction * (better_count + worse_count);
 }
 
-bool DecideNode(const Tensor& point, int32 feature, float bias) {
+
+bool DecideNode(const Tensor& point, int32 feature, float bias,
+                DataColumnTypes type) {
   const auto p = point.unaligned_flat<float>();
   CHECK_LT(feature, p.size());
-  return p(feature) > bias;
+  return Decide(p(feature), bias, type);
+}
+
+
+bool Decide(float value, float bias, DataColumnTypes type) {
+  switch (type) {
+    case kDataFloat:
+      return value > bias;
+
+    case kDataCategorical:
+      // We arbitrarily define categorical equality as going left.
+      return value != bias;
+
+    default:
+      LOG(ERROR) << "Got unknown column type: " << type;
+      return false;
+  }
 }
 
+
 bool IsAllInitialized(const Tensor& features) {
   const auto feature_vec = features.unaligned_flat<int32>();
   return feature_vec(feature_vec.size() - 1) >= 0;
diff --git a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
index 19b02e379e..067f0768d3 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
+++ b/tensorflow/contrib/tensor_forest/core/ops/tree_utils.h
@@ -19,6 +19,7 @@
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -26,6 +27,7 @@
 namespace tensorflow {
 namespace tensorforest {
 
+// TODO(gilberth): Put these in protos so they can be shared by C++ and python.
 // Indexes in the tree representation's 2nd dimension for children and features.
 const int32 CHILDREN_INDEX = 0;
 const int32 FEATURE_INDEX = 1;
@@ -34,6 +36,14 @@ const int32 FEATURE_INDEX = 1;
 const int32 LEAF_NODE = -1;
 const int32 FREE_NODE = -2;
 
+// Used to indicate column types, e.g. categorical vs. float
+enum DataColumnTypes {
+  kDataFloat = 0,
+  kDataCategorical = 1
+};
+
+
+
 // Calculates the sum of a tensor.
template T Sum(Tensor counts) { @@ -80,6 +90,20 @@ int32 BestFeatureRegression(const Tensor& total_sums, const Tensor& split_sums, const Tensor& split_squares, int32 accumulator); +// Returns true if the best split's variance is sufficiently smaller than +// that of the next best split. +bool BestSplitDominatesRegression( + const Tensor& total_sums, const Tensor& total_squares, + const Tensor& split_sums, const Tensor& split_squares, + int32 accumulator); + +// Returns true if the best split's Gini impurity is sufficiently smaller than +// that of the next best split. +bool BestSplitDominatesClassification( + const Tensor& total_counts, + const Tensor& split_counts, int32 accumulator, + float dominate_fraction); + // Initializes everything in the given tensor to the given value. template void Initialize(Tensor counts, T val = 0) { @@ -90,7 +114,74 @@ void Initialize(Tensor counts, T val = 0) { // Returns true if the point falls to the right (i.e., the selected feature // of the input point is greater than the bias threshold), and false if it // falls to the left. -bool DecideNode(const Tensor& point, int32 feature, float bias); +// Even though our input data is forced into float Tensors, it could have +// originally been something else (e.g. categorical string data) which +// we treat differently. +bool DecideNode(const Tensor& point, int32 feature, float bias, + DataColumnTypes type = kDataFloat); + +// Returns input_data(i, feature) > bias. +template +bool DecideDenseNode(const T& input_data, + int32 i, int32 feature, float bias, + DataColumnTypes type = kDataFloat) { + CHECK_LT(i, input_data.dimensions()[0]); + CHECK_LT(feature, input_data.dimensions()[1]); + return Decide(input_data(i, feature), bias, type); +} + +// If T is a sparse float matrix represented by sparse_input_indices and +// sparse_input_values, FindSparseValue returns T(i,j), or 0.0 if (i,j) +// isn't present in sparse_input_indices. sparse_input_indices is assumed +// to be sorted. +template +float FindSparseValue( + const T1& sparse_input_indices, + const T2& sparse_input_values, + int32 i, int32 j) { + int32 low = 0; + int32 high = sparse_input_values.dimension(0); + while (low < high) { + int32 mid = (low + high) / 2; + int64 midi = internal::SubtleMustCopy(sparse_input_indices(mid, 0)); + int64 midj = internal::SubtleMustCopy(sparse_input_indices(mid, 1)); + if (midi == i) { + if (midj == j) { + return sparse_input_values(mid); + } + if (midj < j) { + low = mid + 1; + } else { + high = mid; + } + continue; + } + if (midi < i) { + low = mid + 1; + } else { + high = mid; + } + } + return 0.0; +} + +// Returns t(i, feature) > bias, where t is the sparse tensor represented by +// sparse_input_indices and sparse_input_values. +template +bool DecideSparseNode( + const T1& sparse_input_indices, + const T2& sparse_input_values, + int32 i, int32 feature, float bias, + DataColumnTypes type = kDataFloat) { + return Decide( + FindSparseValue(sparse_input_indices, sparse_input_values, i, feature), + bias, type); +} + +// Returns left/right decision between the input value and the threshold bias. +// For floating point types, the decision is value > bias, but for +// categorical data, it is value != bias. +bool Decide(float value, float bias, DataColumnTypes type = kDataFloat); // Returns true if all the splits are initialized. Since they get initialized // in order, we can simply infer this from the last split. 
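Two helpers declared in this header are compact enough to restate: Decide is the single left/right predicate (threshold comparison for floats, inequality for hashed categorical values), and FindSparseValue binary-searches the lexicographically sorted COO indices, with absent entries reading as 0.0. A pure-Python sketch of both contracts (names hypothetical):

    # Mirrors the Decide and FindSparseValue contracts from tree_utils.
    DATA_FLOAT, DATA_CATEGORICAL = 0, 1   # see DataColumnTypes above

    def decide(value, bias, col_type=DATA_FLOAT):
        if col_type == DATA_CATEGORICAL:
            return value != bias      # equal categorical values go left
        return value > bias           # float columns split on a threshold

    def find_sparse_value(indices, values, i, j):
        # indices: sorted list of (row, col) pairs; values: matching entries.
        low, high = 0, len(values)
        while low < high:
            mid = (low + high) // 2
            if indices[mid] == (i, j):
                return values[mid]
            if indices[mid] < (i, j):  # tuple comparison == row-major order
                low = mid + 1
            else:
                high = mid
        return 0.0                     # missing entries are implicit zeros

DecideSparseNode is then just decide(find_sparse_value(indices, values, i, feature), bias, type).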
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc index 026262e47f..33638ca7e6 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc @@ -36,7 +36,7 @@ using tensorforest::Initialize; using tensorforest::WeightedGiniImpurity; REGISTER_OP("UpdateFertileSlots") - .Attr("max_depth: int") + .Attr("max_depth: int") .Attr("regression: bool = False") .Input("finished: int32") .Input("non_fertile_leaves: int32") @@ -45,11 +45,10 @@ REGISTER_OP("UpdateFertileSlots") .Input("tree_depths: int32") .Input("accumulator_sums: float") .Input("node_to_accumulator: int32") + .Input("stale_leaves: int32") .Output("node_map_updates: int32") .Output("accumulators_cleared: int32") .Output("accumulators_allocated: int32") - .Output("new_nonfertile_leaves: int32") - .Output("new_nonfertile_leaves_scores: float") .Doc(R"doc( Updates accumulator slots to reflect finished or newly fertile nodes. @@ -77,6 +76,8 @@ accumulator_sums: For classification, `accumulator_sums[a][c]` records how of training examples that have been seen. node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by fertile node i, or -1 if node i isn't fertile. +stale_leaves:= A 1-d int32 tensor containing the indices of all leaves that + have stopped accumulating statistics because they are too old. node_map_updates:= A 2-d int32 tensor describing the changes that need to be applied to the node_to_accumulator map. Intended to be used with `tf.scatter_update(node_to_accumulator, @@ -86,10 +87,7 @@ accumulators_cleared:= A 1-d int32 tensor containing the indices of all the accumulator slots that need to be cleared. accumulators_allocated:= A 1-d int32 tensor containing the indices of all the accumulator slots that need to be allocated. -new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the - leaves that are now non-fertile. -new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the - splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`. + )doc"); class UpdateFertileSlots : public OpKernel { @@ -112,6 +110,7 @@ class UpdateFertileSlots : public OpKernel { const Tensor& accumulator_sums = context->input(5); const Tensor& node_to_accumulator = context->input(6); + const Tensor& stale_leaves = context->input(7); OP_REQUIRES(context, finished.shape().dims() == 1, errors::InvalidArgument( @@ -134,6 +133,9 @@ class UpdateFertileSlots : public OpKernel { OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1, errors::InvalidArgument( "node_to_accumulator should be one-dimensional")); + OP_REQUIRES(context, stale_leaves.shape().dims() == 1, + errors::InvalidArgument( + "stale_leaves should be one-dimensional")); OP_REQUIRES( context, @@ -151,6 +153,7 @@ class UpdateFertileSlots : public OpKernel { if (!CheckTensorBounds(context, tree_depths)) return; if (!CheckTensorBounds(context, accumulator_sums)) return; if (!CheckTensorBounds(context, node_to_accumulator)) return; + if (!CheckTensorBounds(context, stale_leaves)) return; // Read finished accumulators into a set for quick lookup. 
const auto node_map = node_to_accumulator.unaligned_flat(); @@ -164,6 +167,16 @@ class UpdateFertileSlots : public OpKernel { errors::InvalidArgument("finished node is outside the valid range")); finished_accumulators.insert(node_map(node)); } + // Stale accumulators are also finished for the purposes of clearing + // and re-allocating. + const auto stale_vec = stale_leaves.unaligned_flat(); + for (int32 i = 0; i < stale_vec.size(); ++i) { + const int32 node = internal::SubtleMustCopy(stale_vec(i)); + OP_REQUIRES( + context, FastBoundsCheck(node, node_map.size()), + errors::InvalidArgument("stale node is outside the valid range")); + finished_accumulators.insert(node_map(node)); + } // Construct leaf heap to sort leaves to allocate accumulators to. const int32 num_nodes = static_cast(tree_depths.shape().dim_size(0)); @@ -210,11 +223,10 @@ class UpdateFertileSlots : public OpKernel { } // Construct and fill outputs. - SetNodeMapUpdates(accumulators_to_node, finished, context); + SetNodeMapUpdates(accumulators_to_node, finished, stale_leaves, context); SetAccumulatorsCleared(finished_accumulators, accumulators_to_node, context); SetAccumulatorsAllocated(accumulators_to_node, context); - SetNewNonFertileLeaves(values.get(), i, context); } private: @@ -228,18 +240,20 @@ class UpdateFertileSlots : public OpKernel { typedef TopN, OrderBySecondGreater> LeafHeapType; typedef std::vector> HeapValuesType; - // Creates an update tensor for node to accumulator map. Sets finished nodes - // to -1 (no accumulator assigned) and newly allocated nodes to their - // accumulator. + // Creates an update tensor for node to accumulator map. Sets finished and + // stale nodes to -1 (no accumulator assigned) and newly allocated nodes to + // their accumulator. void SetNodeMapUpdates( const std::unordered_map& accumulators_to_node, - const Tensor& finished, OpKernelContext* context) { + const Tensor& finished, const Tensor& stale, OpKernelContext* context) { // Node map updates. Tensor* output_node_map = nullptr; TensorShape node_map_shape; node_map_shape.AddDim(2); - node_map_shape.AddDim(accumulators_to_node.size() + - static_cast(finished.shape().dim_size(0))); + node_map_shape.AddDim( + accumulators_to_node.size() + + static_cast(stale.shape().dim_size(0) + + finished.shape().dim_size(0))); OP_REQUIRES_OK(context, context->allocate_output(0, node_map_shape, &output_node_map)); @@ -254,6 +268,13 @@ class UpdateFertileSlots : public OpKernel { out_node(1, output_slot) = -1; ++output_slot; } + // Set stale nodes to -1. + const auto stale_vec = stale.unaligned_flat(); + for (int32 i = 0; i < stale_vec.size(); ++i) { + out_node(0, output_slot) = stale_vec(i); + out_node(1, output_slot) = -1; + ++output_slot; + } // Set newly allocated nodes to their allocator. for (const auto& node_alloc_pair : accumulators_to_node) { @@ -315,56 +336,6 @@ class UpdateFertileSlots : public OpKernel { } } - // Creates output tensors for non-fertile leaves and non-fertile leaf scores. - // Start indicates the index in values where the leaves that weren't - // allocated this round begin, and should thus be placed in the new - // nonfertile_leaves tensors. - void SetNewNonFertileLeaves(HeapValuesType* values, int32 start, - OpKernelContext* context) { - // Node map updates. - int32 num_values = static_cast(values->size()) - start; - - // Unfortunately, a zero-sized Variable results in an uninitialized - // error, probably because they check for zero size instead of - // a real inititalization condition. 
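SetNodeMapUpdates above packs all pending map changes into one 2 x N tensor: row 0 holds node ids, row 1 their new accumulator, with -1 retiring the slots of finished and (newly) stale leaves. A pure-Python sketch of that layout, assuming accumulators_to_node maps accumulator id to node id as its name suggests:

    def node_map_updates(finished, stale, accumulators_to_node):
        # updates[0][k] is a node id; updates[1][k] its new accumulator
        # slot, or -1 when the node no longer has one.
        nodes, accumulators = [], []
        for node in list(finished) + list(stale):
            nodes.append(node)         # finished and stale leaves lose slots
            accumulators.append(-1)
        for acc, node in sorted(accumulators_to_node.items()):
            nodes.append(node)         # newly fertile leaves gain slots
            accumulators.append(acc)
        return [nodes, accumulators]

Per the op doc, the result is meant to be applied with tf.scatter_update(node_to_accumulator, updates[0], updates[1]).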
- bool fill_with_garbage = false; - if (num_values == 0) { - num_values = 1; - fill_with_garbage = true; - } - Tensor* output_nonfertile_leaves = nullptr; - TensorShape nonfertile_leaves_shape; - nonfertile_leaves_shape.AddDim(num_values); - OP_REQUIRES_OK(context, - context->allocate_output(3, nonfertile_leaves_shape, - &output_nonfertile_leaves)); - - auto out_nonfertile_leaves = - output_nonfertile_leaves->unaligned_flat(); - - Tensor* output_nonfertile_leaves_scores = nullptr; - TensorShape nonfertile_leaves_scores_shape; - nonfertile_leaves_scores_shape.AddDim(num_values); - OP_REQUIRES_OK(context, - context->allocate_output(4, nonfertile_leaves_scores_shape, - &output_nonfertile_leaves_scores)); - - auto out_nonfertile_leaves_scores = - output_nonfertile_leaves_scores->unaligned_flat(); - - if (fill_with_garbage) { - out_nonfertile_leaves(0) = -1; - out_nonfertile_leaves_scores(0) = 0.0; - return; - } - - for (int32 i = start; i < values->size(); ++i) { - const std::pair& node = (*values)[i]; - out_nonfertile_leaves(i -start) = node.first; - out_nonfertile_leaves_scores(i - start) = node.second; - } - } - void ConstructLeafHeap(const Tensor& non_fertile_leaves, const Tensor& non_fertile_leaf_scores, const Tensor& tree_depths, int32 end_of_tree, diff --git a/tensorflow/contrib/tensor_forest/data/__init__.py b/tensorflow/contrib/tensor_forest/data/__init__.py new file mode 100644 index 0000000000..3d04705878 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/data/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Random forest implementation in tensorflow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.tensor_forest.data import data_ops diff --git a/tensorflow/contrib/tensor_forest/data/data_ops.py b/tensorflow/contrib/tensor_forest/data/data_ops.py new file mode 100644 index 0000000000..ca229f4ce9 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/data/data_ops.py @@ -0,0 +1,109 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Ops for preprocessing data.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.contrib.tensor_forest.python import constants + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import load_library +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + +DATA_OPS_FILE = '_data_ops.so' + +_data_ops = None +_ops_lock = threading.Lock() + + +ops.NoGradient('StringToFloat') + + +@ops.RegisterShape('StringToFloat') +def StringToFloatShape(op): + """Shape function for StringToFloat Op.""" + return [op.inputs[0].get_shape()] + + +# Workaround for the fact that importing tensorflow imports contrib +# (even if a user isn't using this or any other contrib op), but +# there's not yet any guarantee that the shared object exists. +# In which case, "import tensorflow" will always crash, even for users that +# never use contrib. +def Load(): + """Load the data ops library and return the loaded module.""" + with _ops_lock: + global _data_ops + if not _data_ops: + ops_path = resource_loader.get_path_to_datafile(DATA_OPS_FILE) + logging.info('data path: %s', ops_path) + _data_ops = load_library.load_op_library(ops_path) + + assert _data_ops, 'Could not load _data_ops.so' + return _data_ops + + +def ParseDataTensorOrDict(data): + """Return a tensor to use for input data. + + The incoming features can be a dict where keys are the string names of the + columns, which we turn into a single 2-D tensor. + + Args: + data: `Tensor` or `dict` of `Tensor` objects. + + Returns: + A 2-D tensor for input to tensor_forest and a 1-D tensor of the + type of each column (e.g. continuous float, categorical). + """ + convert_ops = Load() + if isinstance(data, dict): + data_spec = [constants.DATA_CATEGORICAL if data[k].dtype == dtypes.string + else constants.DATA_FLOAT + for k in sorted(data.keys())] + return array_ops.concat(1, [ + convert_ops.string_to_float(data[k]) + if data[k].dtype == dtypes.string else data[k] + for k in sorted(data.keys())]), data_spec + else: + return data, [constants.DATA_FLOAT] * data.get_shape().as_list()[1] + + +def ParseLabelTensorOrDict(labels): + """Return a tensor to use for input labels to tensor_forest. + + The incoming targets can be a dict where keys are the string names of the + columns, which we turn into a single 1-D tensor for classification or + 2-D tensor for regression. + + Args: + labels: `Tensor` or `dict` of `Tensor` objects. + + Returns: + A 2-D tensor for labels/outputs. + """ + if isinstance(labels, dict): + return math_ops.to_float(array_ops.concat( + 1, [labels[k] for k in sorted(labels.keys())])) + else: + return math_ops.to_float(labels) diff --git a/tensorflow/contrib/tensor_forest/data/string_to_float_op.cc b/tensorflow/contrib/tensor_forest/data/string_to_float_op.cc new file mode 100644 index 0000000000..3908855063 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/data/string_to_float_op.cc @@ -0,0 +1,111 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
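ParseDataTensorOrDict above is the seam where mixed-type inputs become forest-ready: dict inputs are concatenated in sorted key order, and string columns are routed through StringToFloat (defined next) and marked categorical in the returned spec. A hedged usage sketch; the feature names and shapes are hypothetical, and the compiled _data_ops.so must be available on the data path:

    import tensorflow as tf
    from tensorflow.contrib.tensor_forest.data import data_ops

    features = {
        'age': tf.placeholder(tf.float32, [None, 1]),   # stays DATA_FLOAT
        'city': tf.placeholder(tf.string, [None, 1]),   # hashed, DATA_CATEGORICAL
    }
    data, spec = data_ops.ParseDataTensorOrDict(features)
    # Keys sort as ('age', 'city'), so spec == [DATA_FLOAT, DATA_CATEGORICAL]
    # and data is a single float32 [batch, 2] tensor.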
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// Converts strings of arbitrary length to float values by +// hashing and cramming bits. +#include + +#include "tensorflow/contrib/tensor_forest/core/ops/tree_utils.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" + +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +using tensorforest::CheckTensorBounds; + + +float Convert(const string& in) { + const std::size_t intval = std::hash()(in); + return static_cast(intval); +} + + +void Evaluate(const Tensor& input_data, Tensor output_data, + int32 start, int32 end) { + auto out_data = output_data.tensor(); + const auto in_data = input_data.tensor(); + + for (int32 i = start; i < end; ++i) { + for (int32 j = 0; j < output_data.dim_size(1); ++j) { + out_data(i, j) = Convert(in_data(i, j)); + } + } +} + + +REGISTER_OP("StringToFloat") + .Input("input_data: string") + .Output("output_data: float") + + .Doc(R"doc( + Converts byte arrays represented by strings to 32-bit + floating point numbers. The output numbers themselves are meaningless, and + should only be used in == comparisons. + + input_data: A batch of string features as a 2-d tensor; `input_data[i][j]` + gives the j-th feature of the i-th input. + output_data: A tensor of the same shape as input_data but the values are + float32. + +)doc"); + +class StringToFloat : public OpKernel { + public: + explicit StringToFloat(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor& input_data = context->input(0); + + // Check inputs. + OP_REQUIRES(context, input_data.shape().dims() == 2, + errors::InvalidArgument( + "input_data should be two-dimensional")); + + // Check tensor bounds. + if (!CheckTensorBounds(context, input_data)) return; + + Tensor* output_data = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_data.shape(), + &output_data)); + + // Evaluate input data in parallel. 
+ const int32 num_data = static_cast(input_data.shape().dim_size(0)); + auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); + int num_threads = worker_threads->num_threads; + if (num_threads <= 1) { + Evaluate(input_data, *output_data, 0, num_data); + } else { + auto work = [&input_data, output_data, num_data](int64 start, int64 end) { + CHECK(start <= end); + CHECK(end <= num_data); + Evaluate(input_data, *output_data, + static_cast(start), static_cast(end)); + }; + Shard(num_threads, worker_threads->workers, num_data, 100, work); + } + } +}; + + +REGISTER_KERNEL_BUILDER(Name("StringToFloat").Device(DEVICE_CPU), + StringToFloat); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/python/__init__.py b/tensorflow/contrib/tensor_forest/python/__init__.py index 0f692bbe97..a9dd599c97 100644 --- a/tensorflow/contrib/tensor_forest/python/__init__.py +++ b/tensorflow/contrib/tensor_forest/python/__init__.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.contrib.tensor_forest.python.ops import inference_ops from tensorflow.contrib.tensor_forest.python.ops import training_ops diff --git a/tensorflow/contrib/tensor_forest/python/constants.py b/tensorflow/contrib/tensor_forest/python/constants.py new file mode 100644 index 0000000000..029c782461 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/python/constants.py @@ -0,0 +1,26 @@ +# pylint: disable=g-bad-file-header +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constants used by tensorforest. Some of these map to values in C++ ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# If tree[i][0] equals this value, then i is a leaf node. +LEAF_NODE = -1 + +# Data column types for indicating categorical or other non-float values. 
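StringToFloat's outputs exist only to be compared with == and !=, which is exactly how Decide treats columns tagged with the categorical constant defined in this file. A pure-Python stand-in for the kernel's Convert (std::hash of the string, reinterpreted as a float) makes that contract explicit:

    def string_to_float(s):
        # Stand-in for the C++ std::hash; any stable string hash works.
        # The magnitude and ordering of the result carry no meaning.
        return float(hash(s) & 0xFFFFFFFF)

    assert string_to_float('paris') == string_to_float('paris')
    # Distinct strings collide only with hash-collision probability:
    assert string_to_float('paris') != string_to_float('tokyo')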
+DATA_FLOAT = 0 +DATA_CATEGORICAL = 1 diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py index c5b5981adb..3641ab0ee0 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/best_splits_op_test.py @@ -30,14 +30,16 @@ class BestSplitsClassificationTests(test_util.TensorFlowTestCase): def setUp(self): self.finished = [3, 5] self.node_map = [-1, -1, -1, 0, -1, 3, -1, -1, -1] - self.candidate_counts = [[[50., 60., 40., 3.], [70., 30., 70., 30.]], - [[0., 0., 0., 0.], [0., 0., 0., 0.]], - [[0., 0., 0., 0.], [0., 0., 0., 0.]], - [[10., 10., 10., 10.], [10., 5., 5., 10.]]] - self.total_counts = [[100., 100., 100., 100.], - [0., 0., 0., 0.], - [0., 0., 0., 0.], - [100., 100., 100., 100.]] + self.candidate_counts = [[[153., 50., 60., 40., 3.], + [200., 70., 30., 70., 30.]], + [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], + [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], + [[40., 10., 10., 10., 10.], + [30., 10., 5., 5., 10.]]] + self.total_counts = [[400., 100., 100., 100., 100.], + [0., 0., 0., 0., 0.], + [0., 0., 0., 0., 0.], + [400., 100., 100., 100., 100.]] self.squares = [] self.ops = training_ops.Load() diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py index eb61573f24..a50eb22795 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import tensorflow as tf +from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python.ops import training_ops from tensorflow.python.framework import test_util @@ -37,16 +38,20 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase): self.split_features = [[1], [-1]] self.split_thresholds = [[1.], [0.]] self.ops = training_ops.Load() + self.epochs = [0, 1, 1] + self.current_epoch = [1] + self.data_spec = [constants.DATA_FLOAT] * 2 def testSimple(self): with self.test_session(): (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _, pcw_totals_indices, pcw_totals_sums, _, leaves) = ( self.ops.count_extremely_random_stats( - self.input_data, self.input_labels, self.tree, - self.tree_thresholds, self.node_map, - self.split_features, self.split_thresholds, num_classes=5, - regression=False)) + self.input_data, [], [], [], self.data_spec, self.input_labels, + self.tree, self.tree_thresholds, self.node_map, + self.split_features, self.split_thresholds, self.epochs, + self.current_epoch, + num_classes=5, regression=False)) self.assertAllEqual( [[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.], [2., 0., 0., 1., 1.]], @@ -57,15 +62,68 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase): self.assertAllEqual([1., 2., 1.], pcw_totals_sums.eval()) self.assertAllEqual([1, 1, 2, 2], leaves.eval()) + def testSparseInput(self): + sparse_shape = [4, 10] + sparse_indices = [[0, 0], [0, 4], [0, 9], + [1, 0], [1, 7], + [2, 0], + [3, 1], [3, 4]] + sparse_values = [3.0, -1.0, 0.5, + 1.5, 6.0, + -2.0, + -0.5, 2.0] + with self.test_session(): + (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _, + pcw_totals_indices, 
pcw_totals_sums, _, leaves) = ( + self.ops.count_extremely_random_stats( + [], sparse_indices, sparse_values, sparse_shape, self.data_spec, + self.input_labels, self.tree, + self.tree_thresholds, self.node_map, + self.split_features, self.split_thresholds, self.epochs, + self.current_epoch, + num_classes=5, regression=False)) + + self.assertAllEqual( + [[4., 1., 1., 1., 1.], + [2., 0., 0., 1., 1.], + [2., 1., 1., 0., 0.]], + pcw_node_sums.eval()) + self.assertAllEqual([[0, 0, 4], [0, 0, 0], [0, 0, 3]], + pcw_splits_indices.eval()) + self.assertAllEqual([1., 2., 1.], pcw_splits_sums.eval()) + self.assertAllEqual([[0, 4], [0, 0], [0, 3]], pcw_totals_indices.eval()) + self.assertAllEqual([1., 2., 1.], pcw_totals_sums.eval()) + self.assertAllEqual([2, 2, 1, 1], leaves.eval()) + + def testFutureEpoch(self): + current_epoch = [3] + with self.test_session(): + (pcw_node_sums, _, _, pcw_splits_sums, _, + _, pcw_totals_sums, _, leaves) = ( + self.ops.count_extremely_random_stats( + self.input_data, [], [], [], self.data_spec, self.input_labels, + self.tree, self.tree_thresholds, self.node_map, + self.split_features, self.split_thresholds, self.epochs, + current_epoch, num_classes=5, regression=False)) + + self.assertAllEqual( + [[0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.], [0., 0., 0., 0., 0.]], + pcw_node_sums.eval()) + self.assertAllEqual([], pcw_splits_sums.eval()) + self.assertAllEqual([], pcw_totals_sums.eval()) + self.assertAllEqual([1, 1, 2, 2], leaves.eval()) + def testThreaded(self): with self.test_session( config=tf.ConfigProto(intra_op_parallelism_threads=2)): (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _, pcw_totals_indices, pcw_totals_sums, _, leaves) = ( self.ops.count_extremely_random_stats( - self.input_data, self.input_labels, self.tree, - self.tree_thresholds, self.node_map, self.split_features, - self.split_thresholds, num_classes=5, regression=False)) + self.input_data, [], [], [], self.data_spec, self.input_labels, + self.tree, self.tree_thresholds, self.node_map, + self.split_features, + self.split_thresholds, self.epochs, self.current_epoch, + num_classes=5, regression=False)) self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.], [2., 0., 0., 1., 1.]], @@ -81,10 +139,10 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase): (pcw_node_sums, _, pcw_splits_indices, pcw_splits_sums, _, pcw_totals_indices, pcw_totals_sums, _, leaves) = ( self.ops.count_extremely_random_stats( - self.input_data, self.input_labels, self.tree, - self.tree_thresholds, [-1] * 3, - self.split_features, self.split_thresholds, num_classes=5, - regression=False)) + self.input_data, [], [], [], self.data_spec, self.input_labels, + self.tree, self.tree_thresholds, [-1] * 3, + self.split_features, self.split_thresholds, self.epochs, + self.current_epoch, num_classes=5, regression=False)) self.assertAllEqual([[4., 1., 1., 1., 1.], [2., 1., 1., 0., 0.], [2., 0., 0., 1., 1.]], @@ -101,13 +159,13 @@ class CountExtremelyRandomStatsClassificationTest(test_util.TensorFlowTestCase): with self.test_session(): with self.assertRaisesOpError( 'Number of nodes should be the same in ' - 'tree, tree_thresholds, and node_to_accumulator'): + 'tree, tree_thresholds, node_to_accumulator, and birth_epoch.'): pcw_node, _, _, _, _, _, _, _, _ = ( self.ops.count_extremely_random_stats( - self.input_data, self.input_labels, self.tree, - self.tree_thresholds, self.node_map, - self.split_features, self.split_thresholds, num_classes=5, - regression=False)) + self.input_data, [], [], 
[], self.data_spec, self.input_labels, + self.tree, self.tree_thresholds, self.node_map, + self.split_features, self.split_thresholds, self.epochs, + self.current_epoch, num_classes=5, regression=False)) self.assertAllEqual([], pcw_node.eval()) @@ -124,6 +182,9 @@ class CountExtremelyRandomStatsRegressionTest(test_util.TensorFlowTestCase): self.split_features = [[1], [-1]] self.split_thresholds = [[1.], [0.]] self.ops = training_ops.Load() + self.epochs = [0, 1, 1] + self.current_epoch = [1] + self.data_spec = [constants.DATA_FLOAT] * 2 def testSimple(self): with self.test_session(): @@ -131,10 +192,10 @@ class CountExtremelyRandomStatsRegressionTest(test_util.TensorFlowTestCase): pcw_splits_squares, pcw_totals_indices, pcw_totals_sums, pcw_totals_squares, leaves) = ( self.ops.count_extremely_random_stats( - self.input_data, self.input_labels, self.tree, - self.tree_thresholds, self.node_map, - self.split_features, self.split_thresholds, num_classes=2, - regression=True)) + self.input_data, [], [], [], self.data_spec, self.input_labels, + self.tree, self.tree_thresholds, self.node_map, + self.split_features, self.split_thresholds, self.epochs, + self.current_epoch, num_classes=2, regression=True)) self.assertAllEqual( [[4., 14.], [2., 9.], [2., 5.]], pcw_node_sums.eval()) diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py index 24fbe2c11d..222ef2b2eb 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/finished_nodes_op_test.py @@ -30,35 +30,71 @@ class FinishedNodesTest(test_util.TensorFlowTestCase): def setUp(self): self.leaves = [1, 3, 4] self.node_map = [-1, -1, -1, 0, 1, -1] - self.pcw_total_splits = [[6, 3, 3], [11, 4, 7], [0, 0, 0], [0, 0, 0], + self.split_sums = [ + # Accumulator 1 + [[3, 0, 3], [2, 1, 1], [3, 1, 2]], + # Accumulator 2 + [[6, 3, 3], [6, 2, 4], [5, 0, 5]], + # Accumulator 3 + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + # Accumulator 4 + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + # Accumulator 5 + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ] + self.split_squares = [] + self.accumulator_sums = [[6, 3, 3], [11, 4, 7], [0, 0, 0], [0, 0, 0], [0, 0, 0]] + self.accumulator_squares = [] self.ops = training_ops.Load() + self.birth_epochs = [0, 0, 0, 1, 1, 1] + self.current_epoch = [1] def testSimple(self): with self.test_session(): - finished = self.ops.finished_nodes(self.leaves, self.node_map, - self.pcw_total_splits, - num_split_after_samples=10) + finished, stale = self.ops.finished_nodes( + self.leaves, self.node_map, self.split_sums, + self.split_squares, self.accumulator_sums, self.accumulator_squares, + self.birth_epochs, self.current_epoch, + regression=False, num_split_after_samples=10, min_split_samples=10) self.assertAllEqual([4], finished.eval()) + self.assertAllEqual([], stale.eval()) def testNoAccumulators(self): with self.test_session(): - finished = self.ops.finished_nodes(self.leaves, [-1] * 6, - self.pcw_total_splits, - num_split_after_samples=10) + finished, stale = self.ops.finished_nodes( + self.leaves, [-1] * 6, self.split_sums, + self.split_squares, self.accumulator_sums, self.accumulator_squares, + self.birth_epochs, self.current_epoch, + regression=False, num_split_after_samples=10, min_split_samples=10) self.assertAllEqual([], finished.eval()) + self.assertAllEqual([], stale.eval()) def testBadInput(self): with self.test_session(): with 
self.assertRaisesOpError( 'leaf_tensor should be one-dimensional'): - finished = self.ops.finished_nodes([self.leaves], self.node_map, - self.pcw_total_splits, - num_split_after_samples=10) + finished, stale = self.ops.finished_nodes( + [self.leaves], self.node_map, self.split_sums, + self.split_squares, self.accumulator_sums, self.accumulator_squares, + self.birth_epochs, self.current_epoch, + regression=False, num_split_after_samples=10, min_split_samples=10) self.assertAllEqual([], finished.eval()) + self.assertAllEqual([], stale.eval()) + + def testEarlyDominates(self): + with self.test_session(): + finished, stale = self.ops.finished_nodes( + self.leaves, self.node_map, self.split_sums, + self.split_squares, self.accumulator_sums, self.accumulator_squares, + self.birth_epochs, self.current_epoch, + regression=False, num_split_after_samples=10, min_split_samples=5) + + self.assertAllEqual([4], finished.eval()) + self.assertAllEqual([], stale.eval()) if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py index 0bbd94a2a4..9830651a5d 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/sample_inputs_op_test.py @@ -41,7 +41,8 @@ class SampleInputsTest(test_util.TensorFlowTestCase): tf.initialize_all_variables().run() indices, feature_updates, threshold_updates = ( self.ops.sample_inputs( - self.input_data, self.node_map, self.leaves, self.split_features, + self.input_data, [], [], [], + self.node_map, self.leaves, self.split_features, self.split_thresholds, split_initializations_per_input=1, split_sampling_random_seed=3)) self.assertAllEqual([1, 0], indices.eval()) @@ -50,12 +51,38 @@ class SampleInputsTest(test_util.TensorFlowTestCase): self.assertAllEqual([[5., -2., 50.], [-1., -10., 0.]], threshold_updates.eval()) + def testSparse(self): + sparse_shape = [4, 10] + sparse_indices = [[0, 0], [0, 4], [0, 9], + [1, 0], [1, 7], + [2, 0], + [3, 1], [3, 4]] + sparse_values = [3.0, -1.0, 0.5, + 1.5, 6.0, + -2.0, + -0.5, 2.0] + + with self.test_session(): + tf.initialize_all_variables().run() + indices, feature_updates, threshold_updates = ( + self.ops.sample_inputs( + [], sparse_indices, sparse_values, sparse_shape, + self.node_map, self.leaves, self.split_features, + self.split_thresholds, split_initializations_per_input=1, + split_sampling_random_seed=3)) + self.assertAllEqual([1, 0], indices.eval()) + self.assertAllEqual([[1, 0, 0], [4, 7, -1]], + feature_updates.eval()) + self.assertAllEqual([[5., -2., -2.], [-1., 6., 0.]], + threshold_updates.eval()) + def testNoAccumulators(self): with self.test_session(): tf.initialize_all_variables().run() indices, feature_updates, threshold_updates = ( self.ops.sample_inputs( - self.input_data, [-1] * 3, self.leaves, self.split_features, + self.input_data, [], [], [], + [-1] * 3, self.leaves, self.split_features, self.split_thresholds, split_initializations_per_input=1, split_sampling_random_seed=3)) self.assertAllEqual([], indices.eval()) @@ -69,7 +96,8 @@ class SampleInputsTest(test_util.TensorFlowTestCase): with self.assertRaisesOpError( 'split_features and split_thresholds should be the same shape.'): indices, _, _ = self.ops.sample_inputs( - self.input_data, self.node_map, self.leaves, self.split_features, + self.input_data, [], [], [], + self.node_map, self.leaves, self.split_features, 
self.split_thresholds, split_initializations_per_input=1, split_sampling_random_seed=3) self.assertAllEqual([], indices.eval()) diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py index e61085657a..aaead5610f 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/tree_predictions_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import tensorflow # pylint: disable=unused-import +from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python.ops import inference_ops from tensorflow.python.framework import test_util @@ -29,6 +30,7 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): def setUp(self): self.ops = inference_ops.Load() + self.data_spec = [constants.DATA_FLOAT] * 2 def testSimple(self): input_data = [[-1., 0.], [-1., 2.], # node 1 @@ -41,13 +43,65 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): with self.test_session(): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=1) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=1) self.assertAllClose([[0.1, 0.1, 0.8], [0.1, 0.1, 0.8], [0.5, 0.25, 0.25], [0.5, 0.25, 0.25]], predictions.eval()) + def testSparseInput(self): + sparse_shape = [3, 10] + sparse_indices = [[0, 0], [0, 4], [0, 9], + [1, 0], [1, 7], + [2, 0]] + sparse_values = [3.0, -1.0, 0.5, + 1.5, 6.0, + -2.0] + sparse_data_spec = [constants.DATA_FLOAT] * 10 + + tree = [[1, 0], [-1, 0], [-1, 0]] + tree_thresholds = [0., 0., 0.] + node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8], + [1.0, 0.5, 0.25, 0.25]] + + with self.test_session(): + predictions = self.ops.tree_predictions( + [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec, + tree, tree_thresholds, node_pcw, + valid_leaf_threshold=1) + + self.assertAllClose([[0.5, 0.25, 0.25], + [0.5, 0.25, 0.25], + [0.1, 0.1, 0.8]], + predictions.eval()) + + def testSparseInputDefaultIsZero(self): + sparse_shape = [3, 10] + sparse_indices = [[0, 0], [0, 4], [0, 9], + [1, 0], [1, 7], + [2, 0]] + sparse_values = [3.0, -1.0, 0.5, + 1.5, 6.0, + -2.0] + sparse_data_spec = [constants.DATA_FLOAT] * 10 + + tree = [[1, 7], [-1, 0], [-1, 0]] + tree_thresholds = [3.0, 0., 0.] + node_pcw = [[1.0, 0.3, 0.4, 0.3], [1.0, 0.1, 0.1, 0.8], + [1.0, 0.5, 0.25, 0.25]] + + with self.test_session(): + predictions = self.ops.tree_predictions( + [], sparse_indices, sparse_values, sparse_shape, sparse_data_spec, + tree, tree_thresholds, node_pcw, + valid_leaf_threshold=1) + + self.assertAllClose([[0.1, 0.1, 0.8], + [0.5, 0.25, 0.25], + [0.1, 0.1, 0.8]], + predictions.eval()) + def testBackoffToParent(self): input_data = [[-1., 0.], [-1., 2.], # node 1 [1., 0.], [1., -2.]] # node 2 @@ -59,8 +113,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): with self.test_session(): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=10) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=10) # Node 2 has enough data, but Node 1 needs to combine with the parent # counts. 
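testSparseInputDefaultIsZero above pins down the sparse traversal contract: a feature absent from sparse_indices reads as 0.0, so rows without feature 7 fail the > 3.0 root test and go left. A pure-Python traversal sketch, assuming the usual layout in which tree[n][0] is the left child (-1 for leaves), the right child is left + 1, and a true decision goes right:

    def traverse(tree, thresholds, lookup, row):
        # lookup(row, feature) returns the input value, or 0.0 when absent
        # (e.g. find_sparse_value from the tree_utils sketch earlier).
        node = 0
        while tree[node][0] != -1:               # -1 marks a leaf
            feature = tree[node][1]
            go_right = lookup(row, feature) > thresholds[node]
            node = tree[node][0] + (1 if go_right else 0)
        return node

With the test's tree [[1, 7], [-1, 0], [-1, 0]] and root threshold 3.0, rows 0 and 2 land in node 1 and row 1 in node 2, matching the expected predictions.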
@@ -78,8 +132,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): with self.test_session(): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=10) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=10) self.assertEquals((0, 3), predictions.eval().shape) @@ -97,8 +151,8 @@ class TreePredictionsTest(test_util.TensorFlowTestCase): 'Number of nodes should be the same in tree, tree_thresholds ' 'and node_pcw.'): predictions = self.ops.tree_predictions( - input_data, tree, tree_thresholds, node_pcw, - valid_leaf_threshold=10) + input_data, [], [], [], self.data_spec, tree, tree_thresholds, + node_pcw, valid_leaf_threshold=10) self.assertEquals((0, 3), predictions.eval().shape) diff --git a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py index f370903b3c..c9af01c50b 100644 --- a/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py +++ b/tensorflow/contrib/tensor_forest/python/kernel_tests/update_fertile_slots_op_test.py @@ -40,48 +40,43 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase): self.node_map = [-1, -1, 0, -1, -1, -1, -1] self.total_counts = [[80., 40., 40.]] self.ops = training_ops.Load() + self.stale_leaves = [] def testSimple(self): with self.test_session(): - (node_map_updates, accumulators_cleared, accumulators_allocated, - new_nfl, new_nfl_scores) = self.ops.update_fertile_slots( + (node_map_updates, accumulators_cleared, + accumulators_allocated) = self.ops.update_fertile_slots( self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores, self.end_of_tree, self.depths, - self.total_counts, self.node_map, max_depth=4) + self.total_counts, self.node_map, self.stale_leaves, max_depth=4) self.assertAllEqual([[2, 4], [-1, 0]], node_map_updates.eval()) self.assertAllEqual([], accumulators_cleared.eval()) self.assertAllEqual([0], accumulators_allocated.eval()) - self.assertAllEqual([3, 5, 6], new_nfl.eval()) - self.assertAllEqual([10., 1., 1.], new_nfl_scores.eval()) def testReachedMaxDepth(self): with self.test_session(): - (node_map_updates, accumulators_cleared, accumulators_allocated, - new_nfl, new_nfl_scores) = self.ops.update_fertile_slots( + (node_map_updates, accumulators_cleared, + accumulators_allocated) = self.ops.update_fertile_slots( self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores, self.end_of_tree, self.depths, - self.total_counts, self.node_map, max_depth=3) + self.total_counts, self.node_map, self.stale_leaves, max_depth=3) self.assertAllEqual([[2], [-1]], node_map_updates.eval()) self.assertAllEqual([0], accumulators_cleared.eval()) self.assertAllEqual([], accumulators_allocated.eval()) - self.assertAllEqual([-1], new_nfl.eval()) - self.assertAllEqual([0.0], new_nfl_scores.eval()) def testNoFinished(self): with self.test_session(): - (node_map_updates, accumulators_cleared, accumulators_allocated, - new_nfl, new_nfl_scores) = self.ops.update_fertile_slots( + (node_map_updates, accumulators_cleared, + accumulators_allocated) = self.ops.update_fertile_slots( [], self.non_fertile_leaves, self.non_fertile_leaf_scores, self.end_of_tree, self.depths, - self.total_counts, self.node_map, max_depth=4) + self.total_counts, self.node_map, self.stale_leaves, max_depth=4) self.assertAllEqual((2, 0), node_map_updates.eval().shape) self.assertAllEqual([], 
accumulators_cleared.eval()) self.assertAllEqual([], accumulators_allocated.eval()) - self.assertAllEqual([4, 3], new_nfl.eval()) - self.assertAllEqual([15., 10.], new_nfl_scores.eval()) def testBadInput(self): del self.non_fertile_leaf_scores[-1] @@ -89,10 +84,10 @@ class UpdateFertileSlotsTest(test_util.TensorFlowTestCase): with self.assertRaisesOpError( 'Number of non fertile leaves should be the same in ' 'non_fertile_leaves and non_fertile_leaf_scores.'): - (node_map_updates, _, _, _, _) = self.ops.update_fertile_slots( + (node_map_updates, _, _) = self.ops.update_fertile_slots( self.finished, self.non_fertile_leaves, self.non_fertile_leaf_scores, self.end_of_tree, self.depths, - self.total_counts, self.node_map, max_depth=4) + self.total_counts, self.node_map, self.stale_leaves, max_depth=4) self.assertAllEqual((2, 0), node_map_updates.eval().shape) diff --git a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py index 6f4e6fff40..88f8112ed4 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/inference_ops.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import threading -import tensorflow as tf - +from tensorflow.python.framework import load_library from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging + INFERENCE_OPS_FILE = '_inference_ops.so' @@ -38,7 +40,11 @@ ops.NoGradient('TreePredictions') def TreePredictions(op): """Shape function for TreePredictions Op.""" num_points = op.inputs[0].get_shape()[0].value - num_classes = op.inputs[3].get_shape()[1].value + sparse_shape = op.inputs[3].get_shape() + if sparse_shape.ndims == 2: + num_points = sparse_shape[0].value + num_classes = op.inputs[7].get_shape()[1].value + # The output of TreePredictions is # [node_pcw(evaluate_tree(x), c) for c in classes for x in input_data]. return [tensor_shape.TensorShape([num_points, num_classes - 1])] @@ -49,16 +55,14 @@ def TreePredictions(op): # there's not yet any guarantee that the shared object exists. # In which case, "import tensorflow" will always crash, even for users that # never use contrib. 
-def Load(library_base_dir=''): +def Load(): """Load the inference ops library and return the loaded module.""" with _ops_lock: global _inference_ops if not _inference_ops: - data_files_path = os.path.join(library_base_dir, - tf.resource_loader.get_data_files_path()) - tf.logging.info('data path: %s', data_files_path) - _inference_ops = tf.load_op_library(os.path.join( - data_files_path, INFERENCE_OPS_FILE)) + ops_path = resource_loader.get_path_to_datafile(INFERENCE_OPS_FILE) + logging.info('data path: %s', ops_path) + _inference_ops = load_library.load_op_library(ops_path) assert _inference_ops, 'Could not load inference_ops.so' return _inference_ops diff --git a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py index 7a108baf42..d25d5ce50b 100644 --- a/tensorflow/contrib/tensor_forest/python/ops/training_ops.py +++ b/tensorflow/contrib/tensor_forest/python/ops/training_ops.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,13 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import threading -import tensorflow as tf - +from tensorflow.python.framework import load_library from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging TRAINING_OPS_FILE = '_training_ops.so' @@ -45,7 +46,10 @@ def _CountExtremelyRandomStatsShape(op): """Shape function for CountExtremelyRandomStats Op.""" regression = op.get_attr('regression') num_points = op.inputs[0].get_shape()[0].value - num_nodes = op.inputs[2].get_shape()[0].value + sparse_shape = op.inputs[3].get_shape() + if sparse_shape.ndims == 2: + num_points = sparse_shape[0].value + num_nodes = op.inputs[6].get_shape()[0].value num_classes = op.get_attr('num_classes') # The output of TraverseTree is [leaf_node_index(x) for x in input_data]. return [tensor_shape.TensorShape([num_nodes, num_classes]), # node sums @@ -66,7 +70,7 @@ def _CountExtremelyRandomStatsShape(op): @ops.RegisterShape('SampleInputs') def _SampleInputsShape(op): """Shape function for SampleInputs Op.""" - num_splits = op.inputs[3].get_shape()[1].value + num_splits = op.inputs[6].get_shape()[1].value return [[None], [None, num_splits], [None, num_splits]] @@ -85,7 +89,7 @@ def _GrowTreeShape(unused_op): @ops.RegisterShape('FinishedNodes') def _FinishedNodesShape(unused_op): """Shape function for FinishedNodes Op.""" - return [[None]] + return [[None], [None]] @ops.RegisterShape('ScatterAddNdim') @@ -97,7 +101,7 @@ def _ScatterAddNdimShape(unused_op): @ops.RegisterShape('UpdateFertileSlots') def _UpdateFertileSlotsShape(unused_op): """Shape function for UpdateFertileSlots Op.""" - return [[None, 2], [None], [None], [None], [None]] + return [[None, 2], [None], [None]] # Workaround for the fact that importing tensorflow imports contrib @@ -105,16 +109,14 @@ def _UpdateFertileSlotsShape(unused_op): # there's not yet any guarantee that the shared object exists. # In which case, "import tensorflow" will always crash, even for users that # never use contrib. 
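Both op modules (and data_ops above) now share one lazy-loading idiom: a module-level handle guarded by a lock, resolved via resource_loader next to the Python file, so that merely importing tensorflow never dlopens a contrib shared object. The pattern, distilled (library name hypothetical):

    import threading
    from tensorflow.python.framework import load_library
    from tensorflow.python.platform import resource_loader

    _MY_OPS_FILE = '_my_ops.so'    # hypothetical .so shipped as a data file
    _my_ops = None
    _ops_lock = threading.Lock()

    def Load():
        """Load the op library at most once, thread-safely."""
        global _my_ops
        with _ops_lock:
            if not _my_ops:
                ops_path = resource_loader.get_path_to_datafile(_MY_OPS_FILE)
                _my_ops = load_library.load_op_library(ops_path)
        return _my_ops

This also lets the modules import load_library and resource_loader directly instead of pulling in all of tensorflow.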
-def Load(library_base_dir=''): +def Load(): """Load training ops library and return the loaded module.""" with _ops_lock: global _training_ops if not _training_ops: - data_files_path = os.path.join(library_base_dir, - tf.resource_loader.get_data_files_path()) - tf.logging.info('data path: %s', data_files_path) - _training_ops = tf.load_op_library(os.path.join( - data_files_path, TRAINING_OPS_FILE)) + ops_path = resource_loader.get_path_to_datafile(TRAINING_OPS_FILE) + logging.info('data path: %s', ops_path) + _training_ops = load_library.load_op_library(ops_path) assert _training_ops, 'Could not load _training_ops.so' return _training_ops diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index f48efaa5db..791954c51f 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -1,3 +1,4 @@ +# pylint: disable=g-bad-file-header # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -20,14 +21,22 @@ from __future__ import print_function import math import random -import tensorflow as tf - +from tensorflow.contrib.tensor_forest.python import constants from tensorflow.contrib.tensor_forest.python.ops import inference_ops from tensorflow.contrib.tensor_forest.python.ops import training_ops - -# If tree[i][0] equals this value, then i is a leaf node. -LEAF_NODE = -1 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import tf_logging as logging # A convenience class for holding random forest hyperparameters. 
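The new min_split_samples hyperparameter below works together with the dominance checks added earlier: after min_split_samples a fertile node may finish as soon as its best split dominates, while split_after_samples remains the hard deadline. A hedged construction sketch (values are illustrative only, and it assumes the fill() step this class uses to derive its remaining fields):

    from tensorflow.contrib.tensor_forest.python import tensor_forest

    params = tensor_forest.ForestHParams(
        num_classes=3, num_features=10,   # required; checked via getattr
        num_trees=10, max_nodes=1000,
        split_after_samples=250,          # hard deadline for picking a split
        min_split_samples=5,              # earliest early-finish opportunity
    ).fill()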
@@ -49,6 +58,7 @@ class ForestHParams(object): max_depth=0, num_splits_to_consider=0, feature_bagging_fraction=1.0, max_fertile_nodes=0, split_after_samples=250, + min_split_samples=5, valid_leaf_threshold=1, **kwargs): self.num_trees = num_trees self.max_nodes = max_nodes @@ -58,6 +68,7 @@ class ForestHParams(object): self.num_splits_to_consider = num_splits_to_consider self.max_fertile_nodes = max_fertile_nodes self.split_after_samples = split_after_samples + self.min_split_samples = min_split_samples self.valid_leaf_threshold = valid_leaf_threshold for name, value in kwargs.items(): @@ -72,11 +83,6 @@ class ForestHParams(object): _ = getattr(self, 'num_classes') _ = getattr(self, 'num_features') - self.training_library_base_dir = getattr( - self, 'training_library_base_dir', '') - self.inference_library_base_dir = getattr( - self, 'inference_library_base_dir', '') - self.bagged_num_features = int(self.feature_bagging_fraction * self.num_features) @@ -147,92 +153,86 @@ class TreeTrainingVariables(object): """ def __init__(self, params, tree_num, training): - self.tree = tf.get_variable( - name=self.get_tree_name('tree', tree_num), dtype=tf.int32, - initializer=tf.constant( - [[-1, -1]] + [[-2, -1]] * (params.max_nodes - 1))) - self.tree_thresholds = tf.get_variable( + self.tree = variable_scope.get_variable( + name=self.get_tree_name('tree', tree_num), dtype=dtypes.int32, + shape=[params.max_nodes, 2], + initializer=init_ops.constant_initializer(-2)) + self.tree_thresholds = variable_scope.get_variable( name=self.get_tree_name('tree_thresholds', tree_num), shape=[params.max_nodes], - initializer=tf.constant_initializer(-1.0)) - self.tree_depths = tf.get_variable( + initializer=init_ops.constant_initializer(-1.0)) + self.tree_depths = variable_scope.get_variable( name=self.get_tree_name('tree_depths', tree_num), shape=[params.max_nodes], - dtype=tf.int32, - initializer=tf.constant_initializer(1)) - self.end_of_tree = tf.get_variable( + dtype=dtypes.int32, + initializer=init_ops.constant_initializer(1)) + self.end_of_tree = variable_scope.get_variable( name=self.get_tree_name('end_of_tree', tree_num), - dtype=tf.int32, - initializer=tf.constant([1])) + dtype=dtypes.int32, + initializer=constant_op.constant([1])) + self.start_epoch = tf_variables.Variable( + [0] * (params.max_nodes), name='start_epoch') if training: - self.non_fertile_leaves = tf.get_variable( - name=self.get_tree_name('non_fertile_leaves', tree_num), - dtype=tf.int32, - initializer=tf.constant([0])) - self.non_fertile_leaf_scores = tf.get_variable( - name=self.get_tree_name('non_fertile_leaf_scores', tree_num), - initializer=tf.constant([1.0])) - - self.node_to_accumulator_map = tf.get_variable( + self.node_to_accumulator_map = variable_scope.get_variable( name=self.get_tree_name('node_to_accumulator_map', tree_num), shape=[params.max_nodes], - dtype=tf.int32, - initializer=tf.constant_initializer(-1)) + dtype=dtypes.int32, + initializer=init_ops.constant_initializer(-1)) - self.candidate_split_features = tf.get_variable( + self.candidate_split_features = variable_scope.get_variable( name=self.get_tree_name('candidate_split_features', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider], - dtype=tf.int32, - initializer=tf.constant_initializer(-1)) - self.candidate_split_thresholds = tf.get_variable( + dtype=dtypes.int32, + initializer=init_ops.constant_initializer(-1)) + self.candidate_split_thresholds = variable_scope.get_variable( name=self.get_tree_name('candidate_split_thresholds', tree_num), 
shape=[params.max_fertile_nodes, params.num_splits_to_consider], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) # Statistics shared by classification and regression. - self.node_sums = tf.get_variable( + self.node_sums = variable_scope.get_variable( name=self.get_tree_name('node_sums', tree_num), shape=[params.max_nodes, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) if training: - self.candidate_split_sums = tf.get_variable( + self.candidate_split_sums = variable_scope.get_variable( name=self.get_tree_name('candidate_split_sums', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) - self.accumulator_sums = tf.get_variable( + initializer=init_ops.constant_initializer(0.0)) + self.accumulator_sums = variable_scope.get_variable( name=self.get_tree_name('accumulator_sums', tree_num), shape=[params.max_fertile_nodes, params.num_output_columns], - initializer=tf.constant_initializer(-1.0)) + initializer=init_ops.constant_initializer(-1.0)) # Regression also tracks second order stats. if params.regression: - self.node_squares = tf.get_variable( + self.node_squares = variable_scope.get_variable( name=self.get_tree_name('node_squares', tree_num), shape=[params.max_nodes, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) - self.candidate_split_squares = tf.get_variable( + self.candidate_split_squares = variable_scope.get_variable( name=self.get_tree_name('candidate_split_squares', tree_num), shape=[params.max_fertile_nodes, params.num_splits_to_consider, params.num_output_columns], - initializer=tf.constant_initializer(0.0)) + initializer=init_ops.constant_initializer(0.0)) - self.accumulator_squares = tf.get_variable( + self.accumulator_squares = variable_scope.get_variable( name=self.get_tree_name('accumulator_squares', tree_num), shape=[params.max_fertile_nodes, params.num_output_columns], - initializer=tf.constant_initializer(-1.0)) + initializer=init_ops.constant_initializer(-1.0)) else: - self.node_squares = tf.constant( + self.node_squares = constant_op.constant( 0.0, name=self.get_tree_name('node_squares', tree_num)) - self.candidate_split_squares = tf.constant( + self.candidate_split_squares = constant_op.constant( 0.0, name=self.get_tree_name('candidate_split_squares', tree_num)) - self.accumulator_squares = tf.constant( + self.accumulator_squares = constant_op.constant( 0.0, name=self.get_tree_name('accumulator_squares', tree_num)) def get_tree_name(self, name, num): @@ -273,11 +273,11 @@ class ForestTrainingVariables(object): """ def __init__(self, params, device_assigner, training=True, - tree_variable_class=TreeTrainingVariables): + tree_variables_class=TreeTrainingVariables): self.variables = [] for i in range(params.num_trees): - with tf.device(device_assigner.get_device(i)): - self.variables.append(tree_variable_class(params, i, training)) + with ops.device(device_assigner.get_device(i)): + self.variables.append(tree_variables_class(params, i, training)) def __setitem__(self, t, val): self.variables[t] = val @@ -299,7 +299,7 @@ class RandomForestDeviceAssigner(object): def get_device(self, unused_tree_num): if not self.cached: - dummy = tf.constant(0) + dummy = constant_op.constant(0) self.cached = dummy.device return self.cached @@ -308,43 +308,51 @@ class RandomForestDeviceAssigner(object): class 
RandomForestGraphs(object):
   """Builds TF graphs for random forest training and inference."""
 
-  def __init__(self, params, device_assigner=None, variables=None,
-               tree_graphs=None,
+  def __init__(self, params, device_assigner=None,
+               variables=None, tree_variables_class=TreeTrainingVariables,
+               tree_graphs=None, training=True,
                t_ops=training_ops,
                i_ops=inference_ops):
     self.params = params
     self.device_assigner = device_assigner or RandomForestDeviceAssigner()
-    tf.logging.info('Constructing forest with params = ')
-    tf.logging.info(self.params.__dict__)
+    logging.info('Constructing forest with params = ')
+    logging.info(self.params.__dict__)
     self.variables = variables or ForestTrainingVariables(
-        self.params, device_assigner=self.device_assigner)
+        self.params, device_assigner=self.device_assigner, training=training,
+        tree_variables_class=tree_variables_class)
     tree_graph_class = tree_graphs or RandomTreeGraphs
     self.trees = [
         tree_graph_class(
             self.variables[i], self.params,
-            t_ops.Load(self.params.training_library_base_dir),
-            i_ops.Load(self.params.inference_library_base_dir), i)
+            t_ops.Load(), i_ops.Load(), i)
         for i in range(self.params.num_trees)]
 
   def _bag_features(self, tree_num, input_data):
-    split_data = tf.split(1, self.params.num_features, input_data)
-    return tf.concat(1, [split_data[ind]
-                         for ind in self.params.bagged_features[tree_num]])
+    split_data = array_ops.split(1, self.params.num_features, input_data)
+    return array_ops.concat(
+        1, [split_data[ind] for ind in self.params.bagged_features[tree_num]])
 
-  def training_graph(self, input_data, input_labels):
+  def training_graph(self, input_data, input_labels, data_spec=None,
+                     epoch=None, **tree_kwargs):
     """Constructs a TF graph for training a random forest.
 
     Args:
-      input_data: A tensor or placeholder for input data.
+      input_data: A tensor or SparseTensor or placeholder for input data.
       input_labels: A tensor or placeholder for labels associated with
         input_data.
+      data_spec: A list of tf.dtype values specifying the original types of
+        each column.
+      epoch: A tensor or placeholder for the epoch the training data comes from.
+      **tree_kwargs: Keyword arguments passed to each tree's training_graph.
 
     Returns:
       The last op in the random forest training graph.
     """
+    data_spec = ([constants.DATA_FLOAT] * self.params.num_features
+                 if data_spec is None else data_spec)
     tree_graphs = []
     for i in range(self.params.num_trees):
-      with tf.device(self.device_assigner.get_device(i)):
+      with ops.device(self.device_assigner.get_device(i)):
         seed = self.params.base_random_seed
         if seed != 0:
           seed += i
@@ -354,40 +362,54 @@ class RandomForestGraphs(object):
         if self.params.bagging_fraction < 1.0:
           # TODO(thomaswc): This does sampling without replacement. Consider
           # also allowing sampling with replacement as an option.
-          batch_size = tf.slice(tf.shape(input_data), [0], [1])
-          r = tf.random_uniform(batch_size, seed=seed)
-          mask = tf.less(r, tf.ones_like(r) * self.params.bagging_fraction)
-          gather_indices = tf.squeeze(tf.where(mask), squeeze_dims=[1])
+          batch_size = array_ops.slice(array_ops.shape(input_data), [0], [1])
+          r = random_ops.random_uniform(batch_size, seed=seed)
+          mask = math_ops.less(
+              r, array_ops.ones_like(r) * self.params.bagging_fraction)
+          gather_indices = array_ops.squeeze(
+              array_ops.where(mask), squeeze_dims=[1])
           # TODO(thomaswc): Calculate out-of-bag data and labels, and store
           # them for use in calculating statistics later.
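The bagging branch above draws one uniform value per input row and keeps the rows whose draw falls below bagging_fraction, so each tree trains on its own random subset whose size is only expected, not exact, to be bagging_fraction * batch_size. A rough NumPy mirror of the mask-and-gather idea (bag_rows is an illustrative helper, not part of this patch):

    import numpy as np

    def bag_rows(input_data, bagging_fraction, seed=None):
      # One uniform draw per row; keep rows whose draw is under the fraction.
      # Each row appears at most once, i.e. sampling without replacement.
      rng = np.random.RandomState(seed)
      r = rng.uniform(size=input_data.shape[0])
      gather_indices = np.where(r < bagging_fraction)[0]
      return input_data[gather_indices]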
- tree_data = tf.gather(input_data, gather_indices) - tree_labels = tf.gather(input_labels, gather_indices) + tree_data = array_ops.gather(input_data, gather_indices) + tree_labels = array_ops.gather(input_labels, gather_indices) if self.params.bagged_features: tree_data = self._bag_features(i, tree_data) - tree_graphs.append( - self.trees[i].training_graph(tree_data, tree_labels, seed)) - return tf.group(*tree_graphs) + initialization = self.trees[i].tree_initialization() + + with ops.control_dependencies([initialization]): + tree_graphs.append( + self.trees[i].training_graph( + tree_data, tree_labels, seed, data_spec=data_spec, + epoch=([0] if epoch is None else epoch), + **tree_kwargs)) - def inference_graph(self, input_data): + return control_flow_ops.group(*tree_graphs) + + def inference_graph(self, input_data, data_spec=None): """Constructs a TF graph for evaluating a random forest. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. + data_spec: A list of tf.dtype values specifying the original types of + each column. Returns: The last op in the random forest inference graph. """ + data_spec = ([constants.DATA_FLOAT] * self.params.num_features + if data_spec is None else data_spec) probabilities = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): tree_data = input_data if self.params.bagged_features: tree_data = self._bag_features(i, input_data) - probabilities.append(self.trees[i].inference_graph(tree_data)) - with tf.device(self.device_assigner.get_device(0)): - all_predict = tf.pack(probabilities) - return tf.reduce_sum(all_predict, 0) / self.params.num_trees + probabilities.append(self.trees[i].inference_graph(tree_data, + data_spec)) + with ops.device(self.device_assigner.get_device(0)): + all_predict = array_ops.pack(probabilities) + return math_ops.reduce_sum(all_predict, 0) / self.params.num_trees def average_size(self): """Constructs a TF graph for evaluating the average size of a forest. @@ -397,9 +419,16 @@ class RandomForestGraphs(object): """ sizes = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): sizes.append(self.trees[i].size()) - return tf.reduce_mean(tf.pack(sizes)) + return math_ops.reduce_mean(array_ops.pack(sizes)) + + def training_loss(self): + return math_ops.neg(self.average_size()) + + # pylint: disable=unused-argument + def validation_loss(self, features, labels): + return math_ops.neg(self.average_size()) def average_impurity(self): """Constructs a TF graph for evaluating the leaf impurity of a forest. 
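The inference path packs the per-tree probability tensors and divides their sum by num_trees, a plain unweighted ensemble average, while training_loss and validation_loss simply negate average_size so that the reported loss falls as the forest grows. The averaging step in miniature, as a NumPy sketch with made-up probabilities:

    import numpy as np

    # Two hypothetical trees, a batch of 2 rows, 2 classes.
    tree_probs = [np.array([[0.9, 0.1], [0.2, 0.8]]),
                  np.array([[0.7, 0.3], [0.4, 0.6]])]
    forest_probs = np.sum(np.stack(tree_probs), axis=0) / len(tree_probs)
    # forest_probs == [[0.8, 0.2], [0.3, 0.7]]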
@@ -409,14 +438,14 @@ class RandomForestGraphs(object): """ impurities = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): impurities.append(self.trees[i].average_impurity()) - return tf.reduce_mean(tf.pack(impurities)) + return math_ops.reduce_mean(array_ops.pack(impurities)) def get_stats(self, session): tree_stats = [] for i in range(self.params.num_trees): - with tf.device(self.device_assigner.get_device(i)): + with ops.device(self.device_assigner.get_device(i)): tree_stats.append(self.trees[i].get_stats(session)) return ForestStats(tree_stats, self.params) @@ -431,6 +460,18 @@ class RandomTreeGraphs(object): self.params = params self.tree_num = tree_num + def tree_initialization(self): + def _init_tree(): + return state_ops.scatter_update(self.variables.tree, [0], [[-1, -1]]).op + + def _nothing(): + return control_flow_ops.no_op() + + return control_flow_ops.cond( + math_ops.equal(array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [1, 1])), -2), + _init_tree, _nothing) + def _gini(self, class_counts): """Calculate the Gini impurity. @@ -444,9 +485,9 @@ class RandomTreeGraphs(object): Returns: A 1-D tensor of the Gini impurities for each row in the input. """ - smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1]) - sums = tf.reduce_sum(smoothed, 1) - sum_squares = tf.reduce_sum(tf.square(smoothed), 1) + smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1]) + sums = math_ops.reduce_sum(smoothed, 1) + sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1) return 1.0 - sum_squares / (sums * sums) @@ -463,9 +504,9 @@ class RandomTreeGraphs(object): Returns: A 1-D tensor of the Gini impurities for each row in the input. """ - smoothed = 1.0 + tf.slice(class_counts, [0, 1], [-1, -1]) - sums = tf.reduce_sum(smoothed, 1) - sum_squares = tf.reduce_sum(tf.square(smoothed), 1) + smoothed = 1.0 + array_ops.slice(class_counts, [0, 1], [-1, -1]) + sums = math_ops.reduce_sum(smoothed, 1) + sum_squares = math_ops.reduce_sum(math_ops.square(smoothed), 1) return sums - sum_squares / sums @@ -483,40 +524,58 @@ class RandomTreeGraphs(object): Returns: A 1-D tensor of the variances for each row in the input. """ - total_count = tf.slice(sums, [0, 0], [-1, 1]) + total_count = array_ops.slice(sums, [0, 0], [-1, 1]) e_x = sums / total_count e_x2 = squares / total_count - return tf.reduce_sum(e_x2 - tf.square(e_x), 1) + return math_ops.reduce_sum(e_x2 - math_ops.square(e_x), 1) + + def training_graph(self, input_data, input_labels, random_seed, + data_spec, epoch=None): - def training_graph(self, input_data, input_labels, random_seed): """Constructs a TF graph for training a random tree. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. input_labels: A tensor or placeholder for labels associated with input_data. random_seed: The random number generator seed to use for this tree. 0 means use the current time as the seed. + data_spec: A list of tf.dtype values specifying the original types of + each column. + epoch: A tensor or placeholder for the epoch the training data comes from. Returns: The last op in the random tree training graph. 
""" + epoch = [0] if epoch is None else epoch + + sparse_indices = [] + sparse_values = [] + sparse_shape = [] + if isinstance(input_data, ops.SparseTensor): + sparse_indices = input_data.indices + sparse_values = input_data.values + sparse_shape = input_data.shape + input_data = [] + # Count extremely random stats. (node_sums, node_squares, splits_indices, splits_sums, splits_squares, totals_indices, totals_sums, totals_squares, input_leaves) = ( self.training_ops.count_extremely_random_stats( - input_data, input_labels, self.variables.tree, + input_data, sparse_indices, sparse_values, sparse_shape, + data_spec, input_labels, self.variables.tree, self.variables.tree_thresholds, self.variables.node_to_accumulator_map, self.variables.candidate_split_features, self.variables.candidate_split_thresholds, + self.variables.start_epoch, epoch, num_classes=self.params.num_output_columns, regression=self.params.regression)) node_update_ops = [] node_update_ops.append( - tf.assign_add(self.variables.node_sums, node_sums)) + state_ops.assign_add(self.variables.node_sums, node_sums)) splits_update_ops = [] splits_update_ops.append(self.training_ops.scatter_add_ndim( @@ -527,8 +586,8 @@ class RandomTreeGraphs(object): totals_sums)) if self.params.regression: - node_update_ops.append(tf.assign_add(self.variables.node_squares, - node_squares)) + node_update_ops.append(state_ops.assign_add(self.variables.node_squares, + node_squares)) splits_update_ops.append(self.training_ops.scatter_add_ndim( self.variables.candidate_split_squares, splits_indices, splits_squares)) @@ -539,63 +598,56 @@ class RandomTreeGraphs(object): # Sample inputs. update_indices, feature_updates, threshold_updates = ( self.training_ops.sample_inputs( - input_data, self.variables.node_to_accumulator_map, + input_data, sparse_indices, sparse_values, sparse_shape, + self.variables.node_to_accumulator_map, input_leaves, self.variables.candidate_split_features, self.variables.candidate_split_thresholds, split_initializations_per_input=( self.params.split_initializations_per_input), split_sampling_random_seed=random_seed)) - update_features_op = tf.scatter_update( + update_features_op = state_ops.scatter_update( self.variables.candidate_split_features, update_indices, feature_updates) - update_thresholds_op = tf.scatter_update( + update_thresholds_op = state_ops.scatter_update( self.variables.candidate_split_thresholds, update_indices, threshold_updates) # Calculate finished nodes. 
- with tf.control_dependencies(splits_update_ops): - children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]), - squeeze_dims=[1]) - is_leaf = tf.equal(LEAF_NODE, children) - leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1])) - finished = self.training_ops.finished_nodes( + with ops.control_dependencies(splits_update_ops): + children = array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1]) + is_leaf = math_ops.equal(constants.LEAF_NODE, children) + leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf), + squeeze_dims=[1])) + finished, stale = self.training_ops.finished_nodes( leaves, self.variables.node_to_accumulator_map, + self.variables.candidate_split_sums, + self.variables.candidate_split_squares, self.variables.accumulator_sums, - num_split_after_samples=self.params.split_after_samples) + self.variables.accumulator_squares, + self.variables.start_epoch, epoch, + num_split_after_samples=self.params.split_after_samples, + min_split_samples=self.params.min_split_samples) # Update leaf scores. - # TODO(gilberth): Optimize this. It currently calculates counts for - # every non-fertile leaf. - with tf.control_dependencies(node_update_ops): - def dont_update_leaf_scores(): - return self.variables.non_fertile_leaf_scores - - def update_leaf_scores_regression(): - sums = tf.gather(self.variables.node_sums, - self.variables.non_fertile_leaves) - squares = tf.gather(self.variables.node_squares, - self.variables.non_fertile_leaves) - new_scores = self._variance(sums, squares) - return tf.assign(self.variables.non_fertile_leaf_scores, new_scores) - - def update_leaf_scores_classification(): - counts = tf.gather(self.variables.node_sums, - self.variables.non_fertile_leaves) - new_scores = self._weighted_gini(counts) - return tf.assign(self.variables.non_fertile_leaf_scores, new_scores) - - # Because we can't have tf.self.variables of size 0, we have to put in a - # garbage value of -1 in there. Here we check for that so we don't - # try to index into node_per_class_weights in a tf.gather with a negative - # number. - update_nonfertile_leaves_scores_op = tf.cond( - tf.less(self.variables.non_fertile_leaves[0], 0), - dont_update_leaf_scores, - update_leaf_scores_regression if self.params.regression else - update_leaf_scores_classification) + non_fertile_leaves = array_ops.boolean_mask( + leaves, math_ops.less(array_ops.gather( + self.variables.node_to_accumulator_map, leaves), 0)) + + # TODO(gilberth): It should be possible to limit the number of non + # fertile leaves we calculate scores for, especially since we can only take + # at most array_ops.shape(finished)[0] of them. + with ops.control_dependencies(node_update_ops): + sums = array_ops.gather(self.variables.node_sums, non_fertile_leaves) + if self.params.regression: + squares = array_ops.gather(self.variables.node_squares, + non_fertile_leaves) + non_fertile_leaf_scores = self._variance(sums, squares) + else: + non_fertile_leaf_scores = self._weighted_gini(sums) # Calculate best splits. - with tf.control_dependencies(splits_update_ops): + with ops.control_dependencies(splits_update_ops): split_indices = self.training_ops.best_splits( finished, self.variables.node_to_accumulator_map, self.variables.candidate_split_sums, @@ -605,7 +657,7 @@ class RandomTreeGraphs(object): regression=self.params.regression) # Grow tree. 
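Before the tree is grown below, note what the block above just computed: finished_nodes now also reports stale accumulators, and the old cond-guarded score update is gone in favor of gathering scores for exactly the non-fertile leaves, i.e. nodes whose child entry equals constants.LEAF_NODE and whose node_to_accumulator_map entry is negative. In toy NumPy form (assuming LEAF_NODE == -1, consistent with this patch's initializers):

    import numpy as np

    LEAF_NODE = -1  # assumed value of constants.LEAF_NODE
    children = np.array([1, LEAF_NODE, LEAF_NODE, 5, LEAF_NODE])
    node_to_accumulator_map = np.array([-1, 0, -1, -1, -1])

    leaves = np.where(children == LEAF_NODE)[0]          # -> [1, 2, 4]
    non_fertile_mask = node_to_accumulator_map[leaves] < 0
    non_fertile_leaves = leaves[non_fertile_mask]        # -> [2, 4]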
- with tf.control_dependencies([update_features_op, update_thresholds_op]): + with ops.control_dependencies([update_features_op, update_thresholds_op]): (tree_update_indices, tree_children_updates, tree_threshold_updates, tree_depth_updates, new_eot) = ( self.training_ops.grow_tree( @@ -613,110 +665,138 @@ class RandomTreeGraphs(object): self.variables.node_to_accumulator_map, finished, split_indices, self.variables.candidate_split_features, self.variables.candidate_split_thresholds)) - tree_update_op = tf.scatter_update( + tree_update_op = state_ops.scatter_update( self.variables.tree, tree_update_indices, tree_children_updates) - threhsolds_update_op = tf.scatter_update( + thresholds_update_op = state_ops.scatter_update( self.variables.tree_thresholds, tree_update_indices, tree_threshold_updates) - depth_update_op = tf.scatter_update( + depth_update_op = state_ops.scatter_update( self.variables.tree_depths, tree_update_indices, tree_depth_updates) + # TODO(thomaswc): Only update the epoch on the new leaves. + new_epoch_updates = epoch * array_ops.ones_like(tree_depth_updates) + epoch_update_op = state_ops.scatter_update( + self.variables.start_epoch, tree_update_indices, + new_epoch_updates) # Update fertile slots. - with tf.control_dependencies([update_nonfertile_leaves_scores_op, - depth_update_op]): - (node_map_updates, accumulators_cleared, accumulators_allocated, - new_nonfertile_leaves, new_nonfertile_leaves_scores) = ( - self.training_ops.update_fertile_slots( - finished, self.variables.non_fertile_leaves, - self.variables.non_fertile_leaf_scores, - self.variables.end_of_tree, self.variables.tree_depths, - self.variables.accumulator_sums, - self.variables.node_to_accumulator_map, - max_depth=self.params.max_depth, - regression=self.params.regression)) + with ops.control_dependencies([depth_update_op]): + (node_map_updates, accumulators_cleared, accumulators_allocated) = ( + self.training_ops.update_fertile_slots( + finished, non_fertile_leaves, + non_fertile_leaf_scores, + self.variables.end_of_tree, self.variables.tree_depths, + self.variables.accumulator_sums, + self.variables.node_to_accumulator_map, + stale, + max_depth=self.params.max_depth, + regression=self.params.regression)) # Ensure end_of_tree doesn't get updated until UpdateFertileSlots has # used it to calculate new leaves. 
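The grow step relies on scatter_update overwriting whole rows of a variable at the given indices, and the new epoch_update_op stamps every touched row with the current epoch (the TODO notes it could touch only the new leaves). Row-overwrite semantics in a toy NumPy equivalent (sizes and values illustrative; -2 marks free rows and -1 leaves, matching this patch's conventions):

    import numpy as np

    tree = np.full((8, 2), -2)                   # all rows start as free slots
    tree_update_indices = np.array([0, 1, 2])
    tree_children_updates = np.array([[1, 0],    # row 0 becomes an internal node
                                      [-1, -1],  # row 1 becomes a fresh leaf
                                      [-1, -1]]) # row 2 becomes a fresh leaf
    tree[tree_update_indices] = tree_children_updates

    start_epoch = np.zeros(8, dtype=np.int32)
    epoch = 3                                    # illustrative current epoch
    start_epoch[tree_update_indices] = epoch     # every updated row is stamped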
-    gated_new_eot, = tf.tuple([new_eot], control_inputs=[new_nonfertile_leaves])
-    eot_update_op = tf.assign(self.variables.end_of_tree, gated_new_eot)
+    gated_new_eot, = control_flow_ops.tuple([new_eot],
+                                            control_inputs=[node_map_updates])
+    eot_update_op = state_ops.assign(self.variables.end_of_tree, gated_new_eot)
 
     updates = []
     updates.append(eot_update_op)
     updates.append(tree_update_op)
-    updates.append(threhsolds_update_op)
-    updates.append(tf.assign(
-        self.variables.non_fertile_leaves, new_nonfertile_leaves,
-        validate_shape=False))
-    updates.append(tf.assign(
-        self.variables.non_fertile_leaf_scores,
-        new_nonfertile_leaves_scores, validate_shape=False))
-
-    updates.append(tf.scatter_update(
+    updates.append(thresholds_update_op)
+    updates.append(epoch_update_op)
+
+    updates.append(state_ops.scatter_update(
         self.variables.node_to_accumulator_map,
-        tf.squeeze(tf.slice(node_map_updates, [0, 0], [1, -1]),
-                   squeeze_dims=[0]),
-        tf.squeeze(tf.slice(node_map_updates, [1, 0], [1, -1]),
-                   squeeze_dims=[0])))
+        array_ops.squeeze(array_ops.slice(node_map_updates, [0, 0], [1, -1]),
+                          squeeze_dims=[0]),
+        array_ops.squeeze(array_ops.slice(node_map_updates, [1, 0], [1, -1]),
+                          squeeze_dims=[0])))
 
-    cleared_and_allocated_accumulators = tf.concat(
+    cleared_and_allocated_accumulators = array_ops.concat(
         0, [accumulators_cleared, accumulators_allocated])
     # Calculate values to put into scatter update for candidate counts.
     # Candidate split counts are always reset back to 0 for both cleared
     # and allocated accumulators. This means some accumulators might be doubly
     # reset to 0 if they were released and not allocated, then later allocated.
-    split_values = tf.tile(
-        tf.expand_dims(tf.expand_dims(
-            tf.zeros_like(cleared_and_allocated_accumulators, dtype=tf.float32),
-            1), 2),
+    split_values = array_ops.tile(
+        array_ops.expand_dims(array_ops.expand_dims(
+            array_ops.zeros_like(cleared_and_allocated_accumulators,
+                                 dtype=dtypes.float32), 1), 2),
         [1, self.params.num_splits_to_consider,
          self.params.num_output_columns])
-    updates.append(tf.scatter_update(
+    updates.append(state_ops.scatter_update(
         self.variables.candidate_split_sums,
         cleared_and_allocated_accumulators, split_values))
     if self.params.regression:
-      updates.append(tf.scatter_update(
+      updates.append(state_ops.scatter_update(
           self.variables.candidate_split_squares,
           cleared_and_allocated_accumulators, split_values))
 
     # Calculate values to put into scatter update for total counts.
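The expand_dims/tile construction above builds one all-zero block per cleared-or-allocated accumulator, shaped so a single scatter_update wipes that accumulator's entire [num_splits_to_consider, num_output_columns] slab of candidate counts; the totals computed next follow the same pattern, with -1 marking a freed slot (matching accumulator_sums' initializer) and 0 a freshly allocated one. The shape bookkeeping in NumPy, with illustrative sizes:

    import numpy as np

    cleared_and_allocated = np.array([3, 7])     # illustrative accumulator ids
    num_splits_to_consider, num_output_columns = 4, 2
    split_values = np.tile(
        np.zeros_like(cleared_and_allocated, dtype=np.float32)[:, None, None],
        [1, num_splits_to_consider, num_output_columns])
    assert split_values.shape == (2, 4, 2)       # one zero block per accumulator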
- total_cleared = tf.tile( - tf.expand_dims( - tf.neg(tf.ones_like(accumulators_cleared, dtype=tf.float32)), 1), + total_cleared = array_ops.tile( + array_ops.expand_dims( + math_ops.neg(array_ops.ones_like(accumulators_cleared, + dtype=dtypes.float32)), 1), [1, self.params.num_output_columns]) - total_reset = tf.tile( - tf.expand_dims( - tf.zeros_like(accumulators_allocated, dtype=tf.float32), 1), + total_reset = array_ops.tile( + array_ops.expand_dims( + array_ops.zeros_like(accumulators_allocated, + dtype=dtypes.float32), 1), [1, self.params.num_output_columns]) - accumulator_updates = tf.concat(0, [total_cleared, total_reset]) - updates.append(tf.scatter_update( + accumulator_updates = array_ops.concat(0, [total_cleared, total_reset]) + updates.append(state_ops.scatter_update( self.variables.accumulator_sums, cleared_and_allocated_accumulators, accumulator_updates)) if self.params.regression: - updates.append(tf.scatter_update( + updates.append(state_ops.scatter_update( self.variables.accumulator_squares, cleared_and_allocated_accumulators, accumulator_updates)) # Calculate values to put into scatter update for candidate splits. - split_features_updates = tf.tile( - tf.expand_dims( - tf.neg(tf.ones_like(cleared_and_allocated_accumulators)), 1), + split_features_updates = array_ops.tile( + array_ops.expand_dims( + math_ops.neg(array_ops.ones_like( + cleared_and_allocated_accumulators)), 1), [1, self.params.num_splits_to_consider]) - updates.append(tf.scatter_update( + updates.append(state_ops.scatter_update( self.variables.candidate_split_features, cleared_and_allocated_accumulators, split_features_updates)) - return tf.group(*updates) + updates += self.finish_iteration() + + return control_flow_ops.group(*updates) + + def finish_iteration(self): + """Perform any operations that should be done at the end of an iteration. + + This is mostly useful for subclasses that need to reset variables after + an iteration, such as ones that are used to finish nodes. + + Returns: + A list of operations. + """ + return [] - def inference_graph(self, input_data): + def inference_graph(self, input_data, data_spec): """Constructs a TF graph for evaluating a random tree. Args: - input_data: A tensor or placeholder for input data. + input_data: A tensor or SparseTensor or placeholder for input data. + data_spec: A list of tf.dtype values specifying the original types of + each column. Returns: The last op in the random tree inference graph. """ + sparse_indices = [] + sparse_values = [] + sparse_shape = [] + if isinstance(input_data, ops.SparseTensor): + sparse_indices = input_data.indices + sparse_values = input_data.values + sparse_shape = input_data.shape + input_data = [] return self.inference_ops.tree_predictions( - input_data, self.variables.tree, self.variables.tree_thresholds, + input_data, sparse_indices, sparse_values, sparse_shape, data_spec, + self.variables.tree, + self.variables.tree_thresholds, self.variables.node_sums, valid_leaf_threshold=self.params.valid_leaf_threshold) @@ -729,13 +809,22 @@ class RandomTreeGraphs(object): Returns: The last op in the graph. 
""" - children = tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1]), - squeeze_dims=[1]) - is_leaf = tf.equal(LEAF_NODE, children) - leaves = tf.to_int32(tf.squeeze(tf.where(is_leaf), squeeze_dims=[1])) - counts = tf.gather(self.variables.node_sums, leaves) - impurity = self._weighted_gini(counts) - return tf.reduce_sum(impurity) / tf.reduce_sum(counts + 1.0) + children = array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1]) + is_leaf = math_ops.equal(constants.LEAF_NODE, children) + leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf), + squeeze_dims=[1])) + counts = array_ops.gather(self.variables.node_sums, leaves) + gini = self._weighted_gini(counts) + # Guard against step 1, when there often are no leaves yet. + def impurity(): + return gini + # Since average impurity can be used for loss, when there's no data just + # return a big number so that loss always decreases. + def big(): + return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000. + return control_flow_ops.cond(math_ops.greater( + array_ops.shape(leaves)[0], 0), impurity, big) def size(self): """Constructs a TF graph for evaluating the current number of nodes. @@ -747,7 +836,8 @@ class RandomTreeGraphs(object): def get_stats(self, session): num_nodes = self.variables.end_of_tree.eval(session=session) - 1 - num_leaves = tf.where( - tf.equal(tf.squeeze(tf.slice(self.variables.tree, [0, 0], [-1, 1])), - LEAF_NODE)).eval(session=session).shape[0] + num_leaves = array_ops.where( + math_ops.equal(array_ops.squeeze(array_ops.slice( + self.variables.tree, [0, 0], [-1, 1])), constants.LEAF_NODE) + ).eval(session=session).shape[0] return TreeStats(num_nodes, num_leaves) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py index c3e1c8520d..4e4cfcd1e8 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py @@ -105,6 +105,47 @@ class TensorForestTest(test_util.TensorFlowTestCase): graph = graph_builder.average_impurity() self.assertTrue(isinstance(graph, tf.Tensor)) + def testTrainingConstructionClassificationSparse(self): + input_data = tf.SparseTensor( + indices=[[0, 0], [0, 3], + [1, 0], [1, 7], + [2, 1], + [3, 9]], + values=[-1.0, 0.0, + -1., 2., + 1., + -2.0], + shape=[4, 10]) + input_labels = [0, 1, 2, 3] + + params = tensor_forest.ForestHParams( + num_classes=4, num_features=10, num_trees=10, max_nodes=1000, + split_after_samples=25).fill() + + graph_builder = tensor_forest.RandomForestGraphs(params) + graph = graph_builder.training_graph(input_data, input_labels) + self.assertTrue(isinstance(graph, tf.Operation)) + + def testInferenceConstructionSparse(self): + input_data = tf.SparseTensor( + indices=[[0, 0], [0, 3], + [1, 0], [1, 7], + [2, 1], + [3, 9]], + values=[-1.0, 0.0, + -1., 2., + 1., + -2.0], + shape=[4, 10]) + + params = tensor_forest.ForestHParams( + num_classes=4, num_features=10, num_trees=10, max_nodes=1000, + split_after_samples=25).fill() + + graph_builder = tensor_forest.RandomForestGraphs(params) + graph = graph_builder.inference_graph(input_data) + self.assertTrue(isinstance(graph, tf.Tensor)) + if __name__ == '__main__': googletest.main() -- cgit v1.2.3