aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-12 07:37:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-12 07:40:56 -0700
commit786bf6cd656d0d67e56bf50047ff116bae884b9e (patch)
tree048e396094d11a3cd71525defb2d0e40e00140d6
parenteb1fe50da445d3880b588215f6fadcc7f48dd3ff (diff)
Refactor some of TensorForest V4 to make the tree model valid during training time, instead of only after FinalizeTreeOp.
PiperOrigin-RevId: 161663317
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/model_ops.cc183
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc27
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/stats_ops.cc101
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc2
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc9
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h8
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc26
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h11
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc33
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h23
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc13
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc10
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h8
-rw-r--r--tensorflow/contrib/tensor_forest/ops/model_ops.cc52
-rw-r--r--tensorflow/contrib/tensor_forest/ops/stats_ops.cc2
-rw-r--r--tensorflow/contrib/tensor_forest/python/ops/model_ops.py11
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py27
17 files changed, 374 insertions, 172 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
index 0f92f05e2c..221f8d969b 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <functional>
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h"
#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
@@ -26,6 +27,7 @@
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
namespace tensorforest {
@@ -46,7 +48,7 @@ class CreateTreeVariableOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(tree_config_t->shape()),
errors::InvalidArgument("Tree config must be a scalar."));
- auto* result = new DecisionTreeResource();
+ auto* result = new DecisionTreeResource(param_proto_);
if (!ParseProtoUnlimited(result->mutable_decision_tree(),
tree_config_t->scalar<string>()())) {
result->Unref();
@@ -142,6 +144,16 @@ class TreeSizeOp : public OpKernel {
}
};
+void TraverseTree(const DecisionTreeResource* tree_resource,
+ const std::unique_ptr<TensorDataSet>& data, int32 start,
+ int32 end,
+ const std::function<void(int32, int32)>& set_leaf_id) {
+ for (int i = start; i < end; ++i) {
+ const int32 id = tree_resource->TraverseTree(data, i, nullptr);
+ set_leaf_id(i, id);
+ }
+}
+
// Op for tree inference.
class TreePredictionsV4Op : public OpKernel {
public:
@@ -176,22 +188,49 @@ class TreePredictionsV4Op : public OpKernel {
mutex_lock l(*decision_tree_resource->get_mutex());
core::ScopedUnref unref_me(decision_tree_resource);
+ const int num_data = data_set_->NumItems();
+ const int32 num_outputs = param_proto_.num_outputs();
+
Tensor* output_predictions = nullptr;
TensorShape output_shape;
- output_shape.AddDim(data_set_->NumItems());
- output_shape.AddDim(param_proto_.num_outputs());
+ output_shape.AddDim(num_data);
+ output_shape.AddDim(num_outputs);
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
&output_predictions));
+ TTypes<float, 2>::Tensor out = output_predictions->tensor<float, 2>();
+
+ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+ int num_threads = worker_threads->num_threads;
+ const int64 costPerTraverse = 500;
+ auto traverse = [this, &out, decision_tree_resource, num_data](int64 start,
+ int64 end) {
+ CHECK(start <= end);
+ CHECK(end <= num_data);
+ TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
+ static_cast<int32>(end),
+ std::bind(&TreePredictionsV4Op::set_output_value, this,
+ std::placeholders::_1, std::placeholders::_2,
+ decision_tree_resource, &out));
+ };
+ Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
+ traverse);
+ }
+
+ void set_output_value(int32 i, int32 id,
+ DecisionTreeResource* decision_tree_resource,
+ TTypes<float, 2>::Tensor* out) {
+ const decision_trees::Leaf& leaf = decision_tree_resource->get_leaf(id);
- auto out = output_predictions->tensor<float, 2>();
- for (int i = 0; i < data_set_->NumItems(); ++i) {
- const int32 leaf_id =
- decision_tree_resource->TraverseTree(data_set_, i, nullptr);
- const decision_trees::Leaf& leaf =
- decision_tree_resource->get_leaf(leaf_id);
+ float sum = 0;
+ for (int j = 0; j < param_proto_.num_outputs(); ++j) {
+ const float count = model_op_->GetOutputValue(leaf, j);
+ (*out)(i, j) = count;
+ sum += count;
+ }
+
+ if (!param_proto_.is_regression() && sum > 0 && sum != 1) {
for (int j = 0; j < param_proto_.num_outputs(); ++j) {
- const float count = model_op_->GetOutputValue(leaf, j);
- out(i, j) = count;
+ (*out)(i, j) /= sum;
}
}
}
@@ -203,6 +242,122 @@ class TreePredictionsV4Op : public OpKernel {
TensorForestParams param_proto_;
};
+// Outputs leaf ids for the given examples.
+class TraverseTreeV4Op : public OpKernel {
+ public:
+ explicit TraverseTreeV4Op(OpKernelConstruction* context) : OpKernel(context) {
+ string serialized_params;
+ OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+ ParseProtoUnlimited(&param_proto_, serialized_params);
+
+ string serialized_proto;
+ OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
+ input_spec_.ParseFromString(serialized_proto);
+
+ data_set_ =
+ std::unique_ptr<TensorDataSet>(new TensorDataSet(input_spec_, 0));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input_data = context->input(1);
+ const Tensor& sparse_input_indices = context->input(2);
+ const Tensor& sparse_input_values = context->input(3);
+ const Tensor& sparse_input_shape = context->input(4);
+
+ data_set_->set_input_tensors(input_data, sparse_input_indices,
+ sparse_input_values, sparse_input_shape);
+
+ DecisionTreeResource* decision_tree_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &decision_tree_resource));
+ mutex_lock l(*decision_tree_resource->get_mutex());
+ core::ScopedUnref unref_me(decision_tree_resource);
+
+ const int num_data = data_set_->NumItems();
+
+ Tensor* output_predictions = nullptr;
+ TensorShape output_shape;
+ output_shape.AddDim(num_data);
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
+ &output_predictions));
+
+ auto leaf_ids = output_predictions->tensor<int32, 1>();
+
+ auto set_leaf_ids = [&leaf_ids](int32 i, int32 id) { leaf_ids(i) = id; };
+
+ auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
+ int num_threads = worker_threads->num_threads;
+ const int64 costPerTraverse = 500;
+ auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data](
+ int64 start, int64 end) {
+ CHECK(start <= end);
+ CHECK(end <= num_data);
+ TraverseTree(decision_tree_resource, data_set_, static_cast<int32>(start),
+ static_cast<int32>(end), set_leaf_ids);
+ };
+ Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
+ traverse);
+ }
+
+ private:
+ tensorforest::TensorForestDataSpec input_spec_;
+ std::unique_ptr<TensorDataSet> data_set_;
+ TensorForestParams param_proto_;
+};
+
+// Update the given leaf models using the batch of labels.
+class UpdateModelV4Op : public OpKernel {
+ public:
+ explicit UpdateModelV4Op(OpKernelConstruction* context) : OpKernel(context) {
+ string serialized_params;
+ OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
+ ParseProtoUnlimited(&param_proto_, serialized_params);
+
+ model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& leaf_ids = context->input(1);
+ const Tensor& input_labels = context->input(2);
+ const Tensor& input_weights = context->input(3);
+
+ DecisionTreeResource* decision_tree_resource;
+ OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
+ &decision_tree_resource));
+ mutex_lock l(*decision_tree_resource->get_mutex());
+ core::ScopedUnref unref_me(decision_tree_resource);
+
+ const int num_data = input_labels.shape().dim_size(0);
+ const int32 label_dim =
+ input_labels.shape().dims() <= 1
+ ? 0
+ : static_cast<int>(input_labels.shape().dim_size(1));
+ const int32 num_targets =
+ param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
+
+ TensorInputTarget target(input_labels, input_weights, num_targets);
+
+ // TODO(gilberth): Make this thread safe and multi-thread.
+ UpdateModel(leaf_ids, target, 0, num_data, decision_tree_resource);
+ }
+
+ void UpdateModel(const Tensor& leaf_ids, const TensorInputTarget& target,
+ int32 start, int32 end,
+ DecisionTreeResource* decision_tree_resource) {
+ const auto leaves = leaf_ids.unaligned_flat<int32>();
+ for (int i = start; i < end; ++i) {
+ model_op_->UpdateModel(
+ decision_tree_resource->get_mutable_tree_node(leaves(i))
+ ->mutable_leaf(),
+ &target, i);
+ }
+ }
+
+ private:
+ std::unique_ptr<LeafModelOperator> model_op_;
+ TensorForestParams param_proto_;
+};
+
// Op for getting feature usage counts.
class FeatureUsageCountsOp : public OpKernel {
public:
@@ -286,8 +441,14 @@ REGISTER_KERNEL_BUILDER(Name("TreeSize").Device(DEVICE_CPU), TreeSizeOp);
REGISTER_KERNEL_BUILDER(Name("TreePredictionsV4").Device(DEVICE_CPU),
TreePredictionsV4Op);
+REGISTER_KERNEL_BUILDER(Name("TraverseTreeV4").Device(DEVICE_CPU),
+ TraverseTreeV4Op);
+
REGISTER_KERNEL_BUILDER(Name("FeatureUsageCounts").Device(DEVICE_CPU),
FeatureUsageCountsOp);
+REGISTER_KERNEL_BUILDER(Name("UpdateModelV4").Device(DEVICE_CPU),
+ UpdateModelV4Op);
+
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
index cece61a54c..0fdab8e6e0 100644
--- a/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/model_ops_test.cc
@@ -64,6 +64,33 @@ TEST(ModelOpsTest, TreePredictionsV4_ShapeFn) {
INFER_OK(op, "?;?;?;?;[10,11]", "[?,?]");
}
+TEST(ModelOpsTest, TraverseTreeV4_ShapeFn) {
+ ShapeInferenceTestOp op("TraverseTreeV4");
+ TF_ASSERT_OK(NodeDefBuilder("test", "TraverseTreeV4")
+ .Input("a", 0, DT_RESOURCE)
+ .Input("b", 1, DT_FLOAT)
+ .Input("c", 2, DT_INT64)
+ .Input("d", 3, DT_FLOAT)
+ .Input("e", 5, DT_INT64)
+ .Attr("input_spec", "")
+ .Attr("params", "")
+ .Finalize(&op.node_def));
+
+ // num_points = 2, sparse shape not known
+ INFER_OK(op, "?;[2,3];?;?;?", "[d1_0]");
+
+ // num_points = 2, sparse and dense shape rank known and > 1
+ INFER_OK(op, "?;[2,3];?;?;[10,11]", "[d1_0]");
+
+ // num_points = 2, sparse shape rank known and > 1
+ INFER_OK(op, "?;?;?;?;[10,11]", "[?]");
+}
+
+TEST(ModelOpsTest, UpdateModelV4_ShapeFn) {
+ ShapeInferenceTestOp op("UpdateModelV4");
+ INFER_OK(op, "[1];?;?;?", "");
+}
+
TEST(ModelOpsTest, FeatureUsageCounts_ShapeFn) {
ShapeInferenceTestOp op("FeatureUsageCounts");
INFER_OK(op, "[1]", "[?]");
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
index 260e03df26..b6d57ef952 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc
@@ -141,18 +141,6 @@ class FertileStatsDeserializeOp : public OpKernel {
TensorForestParams param_proto_;
};
-void TraverseTree(const DecisionTreeResource* tree_resource,
- const std::unique_ptr<TensorDataSet>& data, int32 start,
- int32 end, std::vector<int32>* leaf_ids,
- std::vector<int32>* leaf_depths) {
- for (int i = start; i < end; ++i) {
- int32 depth;
- const int32 leaf_id = tree_resource->TraverseTree(data, i, &depth);
- (*leaf_ids)[i] = leaf_id;
- (*leaf_depths)[i] = depth;
- }
-}
-
// Try to update a leaf's stats by acquiring its lock. If it can't be
// acquired, put it in a waiting queue to come back to later and try the next
// one. Once all leaf_ids have been visited, cycle through the waiting ids
@@ -160,28 +148,27 @@ void TraverseTree(const DecisionTreeResource* tree_resource,
void UpdateStats(FertileStatsResource* fertile_stats_resource,
const std::unique_ptr<TensorDataSet>& data,
const TensorInputTarget& target, int num_targets,
- const std::vector<int32>& leaf_ids,
- const std::vector<int32>& leaf_depths,
+ const Tensor& leaf_ids_tensor,
std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
mutex* set_lock, int32 start, int32 end,
std::unordered_set<int32>* ready_to_split) {
+ const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
+
// Stores leaf_id, leaf_depth, example_id for examples that are waiting
// on another to finish.
- std::queue<std::tuple<int32, int32, int32>> waiting;
+ std::queue<std::tuple<int32, int32>> waiting;
int32 i = start;
while (i < end || !waiting.empty()) {
int32 leaf_id;
- int32 leaf_depth;
int32 example_id;
bool was_waiting = false;
if (i >= end) {
- std::tie(leaf_id, leaf_depth, example_id) = waiting.front();
+ std::tie(leaf_id, example_id) = waiting.front();
waiting.pop();
was_waiting = true;
} else {
- leaf_id = leaf_ids[i];
- leaf_depth = leaf_depths[i];
+ leaf_id = leaf_ids(i);
example_id = i;
++i;
}
@@ -190,14 +177,14 @@ void UpdateStats(FertileStatsResource* fertile_stats_resource,
leaf_lock->lock();
} else {
if (!leaf_lock->try_lock()) {
- waiting.emplace(leaf_id, leaf_depth, example_id);
+ waiting.emplace(leaf_id, example_id);
continue;
}
}
bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize(
- data, &target, {example_id}, leaf_id, leaf_depth, &is_finished);
+ data, &target, {example_id}, leaf_id, &is_finished);
leaf_lock->unlock();
if (is_finished) {
set_lock->lock();
@@ -214,8 +201,8 @@ void UpdateStatsCollated(
const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
int num_targets,
const std::unordered_map<int32, std::vector<int>>& leaf_examples,
- const std::vector<int32>& leaf_depths, mutex* set_lock, int32 start,
- int32 end, std::unordered_set<int32>* ready_to_split) {
+ mutex* set_lock, int32 start, int32 end,
+ std::unordered_set<int32>* ready_to_split) {
auto it = leaf_examples.begin();
std::advance(it, start);
auto end_it = leaf_examples.begin();
@@ -224,8 +211,7 @@ void UpdateStatsCollated(
int32 leaf_id = it->first;
bool is_finished;
fertile_stats_resource->AddExampleToStatsAndInitialize(
- data, &target, it->second, leaf_id, leaf_depths[it->second[0]],
- &is_finished);
+ data, &target, it->second, leaf_id, &is_finished);
if (is_finished) {
set_lock->lock();
ready_to_split->insert(leaf_id);
@@ -261,6 +247,7 @@ class ProcessInputOp : public OpKernel {
const Tensor& sparse_input_shape = context->input(5);
const Tensor& input_labels = context->input(6);
const Tensor& input_weights = context->input(7);
+ const Tensor& leaf_ids_tensor = context->input(8);
data_set_->set_input_tensors(input_data, sparse_input_indices,
sparse_input_values, sparse_input_shape);
@@ -281,22 +268,7 @@ class ProcessInputOp : public OpKernel {
auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
int num_threads = worker_threads->num_threads;
- // First find the leaf ids for each example.
- std::vector<int32> leaf_ids(num_data);
-
- // The depth of the leaf for example i.
- std::vector<int32> leaf_depths(num_data);
-
- const int64 costPerTraverse = 500;
- auto traverse = [this, &leaf_ids, &leaf_depths, tree_resource, num_data](
- int64 start, int64 end) {
- CHECK(start <= end);
- CHECK(end <= num_data);
- TraverseTree(tree_resource, data_set_, static_cast<int32>(start),
- static_cast<int32>(end), &leaf_ids, &leaf_depths);
- };
- Shard(num_threads, worker_threads->workers, num_data, costPerTraverse,
- traverse);
+ const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
// Create one mutex per leaf. We need to protect access to leaf pointers,
// so instead of grouping examples by leaf, we spread examples out among
@@ -306,10 +278,11 @@ class ProcessInputOp : public OpKernel {
std::unordered_map<int32, std::vector<int>> leaf_examples;
if (param_proto_.collate_examples()) {
for (int i = 0; i < num_data; ++i) {
- leaf_examples[leaf_ids[i]].push_back(i);
+ leaf_examples[leaf_ids(i)].push_back(i);
}
} else {
- for (const int32 id : leaf_ids) {
+ for (int i = 0; i < num_data; ++i) {
+ const int32 id = leaf_ids(i);
if (FindOrNull(locks, id) == nullptr) {
// TODO(gilberth): Consider using a memory pool for these.
locks[id] = std::unique_ptr<mutex>(new mutex);
@@ -335,27 +308,26 @@ class ProcessInputOp : public OpKernel {
// from a digits run on local desktop. Heuristics might be necessary
// if it really matters that much.
const int64 costPerUpdate = 1000;
- auto update = [this, &target, &leaf_ids, &leaf_depths, &num_targets,
+ auto update = [this, &target, &leaf_ids_tensor, &num_targets,
fertile_stats_resource, &locks, &set_lock, &ready_to_split,
num_data](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_data);
UpdateStats(fertile_stats_resource, data_set_, target, num_targets,
- leaf_ids, leaf_depths, &locks, &set_lock,
- static_cast<int32>(start), static_cast<int32>(end),
- &ready_to_split);
+ leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start),
+ static_cast<int32>(end), &ready_to_split);
};
- auto update_collated = [this, &target, &leaf_ids, &num_targets,
- &leaf_depths, fertile_stats_resource, tree_resource,
- &leaf_examples, &set_lock, &ready_to_split,
+ auto update_collated = [this, &target, &num_targets, fertile_stats_resource,
+ tree_resource, &leaf_examples, &set_lock,
+ &ready_to_split,
num_leaves](int64 start, int64 end) {
CHECK(start <= end);
CHECK(end <= num_leaves);
UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_,
- target, num_targets, leaf_examples, leaf_depths,
- &set_lock, static_cast<int32>(start),
- static_cast<int32>(end), &ready_to_split);
+ target, num_targets, leaf_examples, &set_lock,
+ static_cast<int32>(start), static_cast<int32>(end),
+ &ready_to_split);
};
if (param_proto_.collate_examples()) {
@@ -411,7 +383,8 @@ class GrowTreeOp : public OpKernel {
const int32 num_nodes =
static_cast<int32>(finished_nodes.shape().dim_size(0));
- // TODO(gilberth): distribute this work over a number of threads.
+ // This op takes so little of the time for one batch that it isn't worth
+ // threading this.
for (int i = 0;
i < num_nodes &&
tree_resource->decision_tree().decision_tree().nodes_size() <
@@ -420,16 +393,14 @@ class GrowTreeOp : public OpKernel {
const int32 node = finished(i);
std::unique_ptr<SplitCandidate> best(new SplitCandidate);
int32 parent_depth;
+ // TODO(gilberth): Pushing these to an output would allow the complete
+ // decoupling of tree from resource.
bool found =
fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
if (found) {
std::vector<int32> new_children;
tree_resource->SplitNode(node, best.get(), &new_children);
fertile_stats_resource->Allocate(parent_depth, new_children);
- fertile_stats_resource->set_leaf_stat(best->left_stats(),
- new_children[0]);
- fertile_stats_resource->set_leaf_stat(best->right_stats(),
- new_children[1]);
// We are done with best, so it is now safe to clear node.
fertile_stats_resource->Clear(node);
CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
@@ -444,20 +415,17 @@ class GrowTreeOp : public OpKernel {
TensorForestParams param_proto_;
};
-void FinalizeLeaf(const LeafStat& leaf_stats, bool is_regression,
- bool drop_final_class,
+void FinalizeLeaf(bool is_regression, bool drop_final_class,
const std::unique_ptr<LeafModelOperator>& leaf_op,
decision_trees::Leaf* leaf) {
- leaf_op->ExportModel(leaf_stats, leaf);
-
- // TODO(thomaswc): Move the rest of this into ExportModel.
-
// regression models are already stored in leaf in normalized form.
if (is_regression) {
return;
}
- float sum = leaf_stats.weight_sum();
+ // TODO(gilberth): Calculate the leaf's sum.
+ float sum = 0;
+ LOG(FATAL) << "FinalizeTreeOp is disabled for now.";
if (sum <= 0.0) {
LOG(WARNING) << "Leaf with sum " << sum << " has stats "
<< leaf->ShortDebugString();
@@ -517,8 +485,7 @@ class FinalizeTreeOp : public OpKernel {
->mutable_decision_tree()
->mutable_nodes(i);
if (node->has_leaf()) {
- const auto& leaf_stats = fertile_stats_resource->leaf_stat(i);
- FinalizeLeaf(leaf_stats, param_proto_.is_regression(),
+ FinalizeLeaf(param_proto_.is_regression(),
param_proto_.drop_final_class(), model_op_,
node->mutable_leaf());
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc
index e5b86b0520..b3aa3a96f4 100644
--- a/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops_test.cc
@@ -45,7 +45,7 @@ TEST(StatsOpsTest, GrowTreeV4_ShapeFn) {
TEST(StatsOpsTest, ProcessInputV4_ShapeFn) {
ShapeInferenceTestOp op("ProcessInputV4");
- INFER_OK(op, "[1];[1];?;?;?;?;?;?", "[?]");
+ INFER_OK(op, "[1];[1];?;?;?;?;?;?;?", "[?]");
}
TEST(StatsOpsTest, FinalizeTree_ShapeFn) {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
index 165685ca53..881e4339a7 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc
@@ -18,6 +18,7 @@ namespace tensorflow {
namespace tensorforest {
using decision_trees::DecisionTree;
+using decision_trees::Leaf;
using decision_trees::TreeNode;
int32 DecisionTreeResource::TraverseTree(
@@ -51,13 +52,15 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
new_children->push_back(newid);
TreeNode* new_left = tree->add_nodes();
new_left->mutable_node_id()->set_value(newid++);
- new_left->mutable_leaf();
+ Leaf* left_leaf = new_left->mutable_leaf();
+ model_op_->ExportModel(best->left_stats(), left_leaf);
// right
new_children->push_back(newid);
TreeNode* new_right = tree->add_nodes();
new_right->mutable_node_id()->set_value(newid);
- new_right->mutable_leaf();
+ Leaf* right_leaf = new_right->mutable_leaf();
+ model_op_->ExportModel(best->right_stats(), right_leaf);
node->clear_leaf();
node->mutable_binary_node()->Swap(best->mutable_split());
@@ -72,7 +75,7 @@ void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best,
void DecisionTreeResource::MaybeInitialize() {
DecisionTree* tree = decision_tree_->mutable_decision_tree();
if (tree->nodes_size() == 0) {
- tree->add_nodes()->mutable_leaf();
+ model_op_->InitModel(tree->add_nodes()->mutable_leaf());
} else if (node_evaluators_.empty()) { // reconstruct evaluators
for (const auto& node : tree->nodes()) {
if (node.has_leaf()) {
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
index c8f09d8e07..438d3d817c 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h
@@ -31,8 +31,10 @@ namespace tensorforest {
class DecisionTreeResource : public ResourceBase {
public:
// Constructor.
- explicit DecisionTreeResource()
- : decision_tree_(new decision_trees::Model()) {}
+ explicit DecisionTreeResource(const TensorForestParams& params)
+ : params_(params), decision_tree_(new decision_trees::Model()) {
+ model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_);
+ }
string DebugString() override {
return strings::StrCat("DecisionTree[size=",
@@ -79,7 +81,9 @@ class DecisionTreeResource : public ResourceBase {
private:
mutex mu_;
+ const TensorForestParams params_;
std::unique_ptr<decision_trees::Model> decision_tree_;
+ std::shared_ptr<LeafModelOperator> model_op_;
std::vector<std::unique_ptr<DecisionNodeEvaluator>> node_evaluators_;
};
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc
index 5c1b7454ae..7f914aac31 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc
@@ -20,14 +20,8 @@ namespace tensorflow {
namespace tensorforest {
void FertileStatsResource::AddExampleToStatsAndInitialize(
- const std::unique_ptr<TensorDataSet>& input_data,
- const InputTarget* target, const std::vector<int>& examples,
- int32 node_id, int32 node_depth, bool* is_finished) {
- // Set leaf's counts for calculating probabilities.
- for (int example : examples) {
- model_op_->UpdateModel(&leaf_stats_[node_id], target, example);
- }
-
+ const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
+ const std::vector<int>& examples, int32 node_id, bool* is_finished) {
// Update stats or initialize if needed.
if (collection_op_->IsInitialized(node_id)) {
collection_op_->AddExample(input_data, target, examples, node_id);
@@ -47,8 +41,6 @@ void FertileStatsResource::AddExampleToStatsAndInitialize(
}
void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) {
- leaf_stats_[node_id] = LeafStat();
- model_op_->InitModel(&leaf_stats_[node_id]);
collection_op_->InitializeSlot(node_id, depth);
}
@@ -62,7 +54,6 @@ void FertileStatsResource::Allocate(int32 parent_depth,
void FertileStatsResource::Clear(int32 node) {
collection_op_->ClearSlot(node);
- leaf_stats_.erase(node);
}
bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
@@ -71,27 +62,16 @@ bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
}
void FertileStatsResource::MaybeInitialize() {
- if (leaf_stats_.empty()) {
- AllocateNode(0, 0);
- }
+ collection_op_->MaybeInitialize();
}
void FertileStatsResource::ExtractFromProto(const FertileStats& stats) {
collection_op_ =
SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_);
collection_op_->ExtractFromProto(stats);
- for (int i = 0; i < stats.node_to_slot_size(); ++i) {
- const auto& slot = stats.node_to_slot(i);
- leaf_stats_[slot.node_id()] = slot.leaf_stats();
- }
}
void FertileStatsResource::PackToProto(FertileStats* stats) const {
- for (const auto& entry : leaf_stats_) {
- auto* slot = stats->add_node_to_slot();
- *slot->mutable_leaf_stats() = entry.second;
- slot->set_node_id(entry.first);
- }
collection_op_->PackToProto(stats);
}
} // namespace tensorforest
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
index 34ec945e84..dacf033d99 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h
@@ -51,7 +51,6 @@ class FertileStatsResource : public ResourceBase {
// Resets the resource and frees the proto.
// Caller needs to hold the mutex lock while calling this.
void Reset() {
- leaf_stats_.clear();
}
// Reset the stats for a node, but leave the leaf_stats intact.
@@ -71,7 +70,7 @@ class FertileStatsResource : public ResourceBase {
void AddExampleToStatsAndInitialize(
const std::unique_ptr<TensorDataSet>& input_data,
const InputTarget* target, const std::vector<int>& examples,
- int32 node_id, int32 node_depth, bool* is_finished);
+ int32 node_id, bool* is_finished);
// Allocate a fertile slot for each ready node, then new children up to
// max_fertile_nodes_.
@@ -85,19 +84,11 @@ class FertileStatsResource : public ResourceBase {
// was found.
bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth);
- const LeafStat& leaf_stat(int32 node_id) {
- return leaf_stats_[node_id];
- }
-
- void set_leaf_stat(const LeafStat& stat, int32 node_id) {
- leaf_stats_[node_id] = stat;
- }
private:
mutex mu_;
std::shared_ptr<LeafModelOperator> model_op_;
std::unique_ptr<SplitCollectionOperator> collection_op_;
- std::unordered_map<int32, LeafStat> leaf_stats_;
const TensorForestParams params_;
void AllocateNode(int32 node_id, int32 depth);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
index 49e425642d..d43c068e46 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc
@@ -17,6 +17,8 @@
namespace tensorflow {
namespace tensorforest {
+using decision_trees::Leaf;
+
std::unique_ptr<LeafModelOperator>
LeafModelOperatorFactory::CreateLeafModelOperator(
const TensorForestParams& params) {
@@ -50,24 +52,21 @@ float DenseClassificationLeafModelOperator::GetOutputValue(
}
void DenseClassificationLeafModelOperator::UpdateModel(
- LeafStat* leaf, const InputTarget* target,
- int example) const {
+ Leaf* leaf, const InputTarget* target, int example) const {
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
QCHECK_LT(int_label, params_.num_outputs())
<< "Got label greater than indicated number of classes. Is "
"params.num_classes set correctly?";
QCHECK_GE(int_label, 0);
- auto* val = leaf->mutable_classification()->mutable_dense_counts()
- ->mutable_value(int_label);
+ auto* val = leaf->mutable_vector()->mutable_value(int_label);
+
float weight = target->GetTargetWeight(example);
val->set_float_value(val->float_value() + weight);
- leaf->set_weight_sum(leaf->weight_sum() + weight);
}
-void DenseClassificationLeafModelOperator::InitModel(
- LeafStat* leaf) const {
+void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
for (int i = 0; i < params_.num_outputs(); ++i) {
- leaf->mutable_classification()->mutable_dense_counts()->add_value();
+ leaf->mutable_vector()->add_value();
}
}
@@ -88,17 +87,15 @@ float SparseClassificationLeafModelOperator::GetOutputValue(
}
void SparseClassificationLeafModelOperator::UpdateModel(
- LeafStat* leaf, const InputTarget* target,
- int example) const {
+ Leaf* leaf, const InputTarget* target, int example) const {
const int32 int_label = target->GetTargetAsClassIndex(example, 0);
QCHECK_LT(int_label, params_.num_outputs())
<< "Got label greater than indicated number of classes. Is "
"params.num_classes set correctly?";
QCHECK_GE(int_label, 0);
const float weight = target->GetTargetWeight(example);
- leaf->set_weight_sum(leaf->weight_sum() + weight);
- auto value_map = leaf->mutable_classification()->mutable_sparse_counts()
- ->mutable_sparse_value();
+
+ auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
auto it = value_map->find(int_label);
if (it == value_map->end()) {
(*value_map)[int_label].set_float_value(weight);
@@ -123,8 +120,8 @@ float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
}
void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
- LeafStat* leaf, const InputTarget* target, int example) const {
- if (leaf->classification().has_dense_counts()) {
+ Leaf* leaf, const InputTarget* target, int example) const {
+ if (leaf->has_vector()) {
return dense_->UpdateModel(leaf, target, example);
} else {
return sparse_->UpdateModel(leaf, target, example);
@@ -146,15 +143,15 @@ float RegressionLeafModelOperator::GetOutputValue(
return leaf.vector().value(o).float_value();
}
-void RegressionLeafModelOperator::InitModel(
- LeafStat* leaf) const {
+void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
for (int i = 0; i < params_.num_outputs(); ++i) {
- leaf->mutable_regression()->mutable_mean_output()->add_value();
+ leaf->mutable_vector()->add_value();
}
}
void RegressionLeafModelOperator::ExportModel(
const LeafStat& stat, decision_trees::Leaf* leaf) const {
+ leaf->clear_vector();
for (int i = 0; i < params_.num_outputs(); ++i) {
const float new_val =
stat.regression().mean_output().value(i).float_value() /
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
index 8aadefc403..946a648f22 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h
@@ -42,12 +42,11 @@ class LeafModelOperator {
int32 o) const = 0;
// Update the given Leaf's model with the given example.
- virtual void UpdateModel(LeafStat* leaf,
- const InputTarget* target,
- int example) const = 0;
+ virtual void UpdateModel(decision_trees::Leaf* leaf,
+ const InputTarget* target, int example) const = 0;
// Initialize an empty Leaf model.
- virtual void InitModel(LeafStat* leaf) const = 0;
+ virtual void InitModel(decision_trees::Leaf* leaf) const = 0;
virtual void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const = 0;
@@ -65,10 +64,10 @@ class DenseClassificationLeafModelOperator : public LeafModelOperator {
float GetOutputValue(const decision_trees::Leaf& leaf,
int32 o) const override;
- void UpdateModel(LeafStat* leaf, const InputTarget* target,
+ void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override;
- void InitModel(LeafStat* leaf) const override;
+ void InitModel(decision_trees::Leaf* leaf) const override;
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
@@ -84,10 +83,10 @@ class SparseClassificationLeafModelOperator : public LeafModelOperator {
float GetOutputValue(const decision_trees::Leaf& leaf,
int32 o) const override;
- void UpdateModel(LeafStat* leaf, const InputTarget* target,
+ void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override;
- void InitModel(LeafStat* leaf) const override {}
+ void InitModel(decision_trees::Leaf* leaf) const override {}
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
@@ -103,10 +102,10 @@ class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
float GetOutputValue(const decision_trees::Leaf& leaf,
int32 o) const override;
- void UpdateModel(LeafStat* leaf, const InputTarget* target,
+ void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override;
- void InitModel(LeafStat* leaf) const override {}
+ void InitModel(decision_trees::Leaf* leaf) const override {}
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
@@ -129,10 +128,10 @@ class RegressionLeafModelOperator : public LeafModelOperator {
// updating model and just using the seeded values. Can add this in
// with additional_data, though protobuf::Any is slow. Maybe make it
// optional. Maybe make any update optional.
- void UpdateModel(LeafStat* leaf, const InputTarget* target,
+ void UpdateModel(decision_trees::Leaf* leaf, const InputTarget* target,
int example) const override {}
- void InitModel(LeafStat* leaf) const override;
+ void InitModel(decision_trees::Leaf* leaf) const override;
void ExportModel(const LeafStat& stat,
decision_trees::Leaf* leaf) const override;
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
index 35268d15d3..ffd92c01f9 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
@@ -63,12 +63,8 @@ constexpr char kRegressionStatProto[] =
"}";
void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
- std::unique_ptr<LeafStat> leaf(new LeafStat);
- op->InitModel(leaf.get());
-
Leaf l;
- op->ExportModel(*leaf, &l);
-
+ op->InitModel(&l);
// Make sure it was initialized correctly.
for (int i = 0; i < kNumClasses; ++i) {
EXPECT_EQ(op->GetOutputValue(l, i), 0);
@@ -80,11 +76,10 @@ void TestClassificationNormalUse(const std::unique_ptr<LeafModelOperator>& op) {
new TestableInputTarget(labels, weights, 1));
// Update and check value.
- op->UpdateModel(leaf.get(), target.get(), 0);
- op->UpdateModel(leaf.get(), target.get(), 1);
- op->UpdateModel(leaf.get(), target.get(), 2);
+ op->UpdateModel(&l, target.get(), 0);
+ op->UpdateModel(&l, target.get(), 1);
+ op->UpdateModel(&l, target.get(), 2);
- op->ExportModel(*leaf, &l);
EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4);
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
index 632408fd71..ccc412600c 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc
@@ -71,13 +71,13 @@ void SplitCollectionOperator::ExtractFromProto(
}
void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const {
- for (int i = 0; i < stats_proto->node_to_slot_size(); ++i) {
- auto* new_slot = stats_proto->mutable_node_to_slot(i);
- const auto& stats = stats_.at(new_slot->node_id());
+ for (const auto& pair : stats_) {
+ auto* new_slot = stats_proto->add_node_to_slot();
+ new_slot->set_node_id(pair.first);
if (params_.checkpoint_stats()) {
- stats->PackToProto(new_slot);
+ pair.second->PackToProto(new_slot);
}
- new_slot->set_depth(stats->depth());
+ new_slot->set_depth(pair.second->depth());
}
}
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
index 6990e82678..6c21c0bd34 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h
@@ -62,6 +62,14 @@ class SplitCollectionOperator {
// Create a new GrowStats for the given node id and initialize it.
virtual void InitializeSlot(int32 node_id, int32 depth);
+ // Called when the resource is deserialized, possibly needing an
+ // initialization.
+ virtual void MaybeInitialize() {
+ if (stats_.empty()) {
+ InitializeSlot(0, 0);
+ }
+ }
+
// Perform any necessary cleanup for any tracked state for the slot.
virtual void ClearSlot(int32 node_id) {
stats_.erase(node_id);
diff --git a/tensorflow/contrib/tensor_forest/ops/model_ops.cc b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
index 168f079f52..1227a70a2e 100644
--- a/tensorflow/contrib/tensor_forest/ops/model_ops.cc
+++ b/tensorflow/contrib/tensor_forest/ops/model_ops.cc
@@ -115,6 +115,58 @@ sparse_input_shape: The shape tensor from the SparseTensor input.
predictions: `predictions[i][j]` is the probability that input i is class j.
)doc");
+REGISTER_OP("TraverseTreeV4")
+ .Attr("input_spec: string")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("input_data: float")
+ .Input("sparse_input_indices: int64")
+ .Input("sparse_input_values: float")
+ .Input("sparse_input_shape: int64")
+ .Output("leaf_ids: int32")
+ .SetShapeFn([](InferenceContext* c) {
+ DimensionHandle num_points = c->UnknownDim();
+
+ if (c->RankKnown(c->input(1)) && c->Rank(c->input(1)) > 0 &&
+ c->Value(c->Dim(c->input(1), 0)) > 0) {
+ num_points = c->Dim(c->input(1), 0);
+ }
+
+ c->set_output(0, c->Vector(num_points));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the leaf ids for the given input data.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
+ gives the j-th feature of the i-th input.
+sparse_input_indices: The indices tensor from the SparseTensor input.
+sparse_input_values: The values tensor from the SparseTensor input.
+sparse_input_shape: The shape tensor from the SparseTensor input.
+leaf_ids: `leaf_ids[i]` is the leaf id for input i.
+)doc");
+
+REGISTER_OP("UpdateModelV4")
+ .Attr("params: string")
+ .Input("tree_handle: resource")
+ .Input("leaf_ids: int32")
+ .Input("input_labels: float")
+ .Input("input_weights: float")
+ .SetShapeFn(tensorflow::shape_inference::NoOutputs)
+ .Doc(R"doc(
+Updates the given leaves for each example with the new labels.
+
+params: A serialized TensorForestParams proto.
+tree_handle: The handle to the tree.
+leaf_ids: `leaf_ids[i]` is the leaf id for input i.
+input_labels: The training batch's labels as a 1 or 2-d tensor.
+ 'input_labels[i][j]' gives the j-th label/target for the i-th input.
+input_weights: The training batch's eample weights as a 1-d tensor.
+ 'input_weights[i]' gives the weight for the i-th input.
+)doc");
+
REGISTER_OP("FeatureUsageCounts")
.Attr("params: string")
.Input("tree_handle: resource")
diff --git a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
index 9652749768..e8b5c5d8a6 100644
--- a/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
+++ b/tensorflow/contrib/tensor_forest/ops/stats_ops.cc
@@ -98,6 +98,7 @@ REGISTER_OP("ProcessInputV4")
.Input("sparse_input_shape: int64")
.Input("input_labels: float")
.Input("input_weights: float")
+ .Input("leaf_ids: int32")
.Output("finished_nodes: int32")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(c->UnknownDim()));
@@ -122,6 +123,7 @@ input_weights: The training batch's eample weights as a 1-d tensor.
'input_weights[i]' gives the weight for the i-th input.
finished_nodes: A 1-d tensor of node ids that have finished and are ready to
grow.
+leaf_ids: `leaf_ids[i]` is the leaf id for input i.
)doc");
REGISTER_OP("FinalizeTree")
diff --git a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
index 4c7218305b..d240e2f6de 100644
--- a/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
+++ b/tensorflow/contrib/tensor_forest/python/ops/model_ops.py
@@ -18,12 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.tensor_forest.python.ops import gen_model_ops
-from tensorflow.contrib.tensor_forest.python.ops import stats_ops
# pylint: disable=unused-import
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import feature_usage_counts
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import traverse_tree_v4
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_predictions_v4
from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import tree_size
+from tensorflow.contrib.tensor_forest.python.ops.gen_model_ops import update_model_v4
# pylint: enable=unused-import
from tensorflow.contrib.util import loader
@@ -59,13 +60,7 @@ class TreeVariableSavable(saver.BaseSaverBuilder.SaveableObject):
name: the name to save the tree variable under.
"""
self.params = params
- deps = []
- if stats_handle is not None:
- deps.append(stats_ops.finalize_tree(
- tree_handle, stats_handle,
- params=params.serialized_params_proto))
- with ops.control_dependencies(deps):
- tensor = gen_model_ops.tree_serialize(tree_handle)
+ tensor = gen_model_ops.tree_serialize(tree_handle)
# slice_spec is useful for saving a slice from a variable.
# It's not meaningful the tree variable. So we just pass an empty value.
slice_spec = ""
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
index 7e6f00a13d..8198c228dd 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest_v4.py
@@ -27,6 +27,7 @@ from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.contrib.tensor_forest.python.ops import model_ops
from tensorflow.contrib.tensor_forest.python.ops import stats_ops
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import tf_logging as logging
@@ -240,6 +241,22 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
if input_data is None:
input_data = []
+ leaf_ids = model_ops.traverse_tree_v4(
+ self.variables.tree,
+ input_data,
+ sparse_indices,
+ sparse_values,
+ sparse_shape,
+ input_spec=data_spec.SerializeToString(),
+ params=self.params.serialized_params_proto)
+
+ update_model = model_ops.update_model_v4(
+ self.variables.tree,
+ leaf_ids,
+ input_labels,
+ input_weights,
+ params=self.params.serialized_params_proto)
+
finished_nodes = stats_ops.process_input_v4(
self.variables.tree,
self.variables.stats,
@@ -249,13 +266,17 @@ class RandomTreeGraphsV4(tensor_forest.RandomTreeGraphs):
sparse_shape,
input_labels,
input_weights,
+ leaf_ids,
input_spec=data_spec.SerializeToString(),
random_seed=random_seed,
params=self.params.serialized_params_proto)
- return stats_ops.grow_tree_v4(self.variables.tree, self.variables.stats,
- finished_nodes,
- params=self.params.serialized_params_proto)
+ with ops.control_dependencies([update_model]):
+ return stats_ops.grow_tree_v4(
+ self.variables.tree,
+ self.variables.stats,
+ finished_nodes,
+ params=self.params.serialized_params_proto)
def inference_graph(self, input_data, data_spec, sparse_features=None):
sparse_indices = []