aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author A. Unique TensorFlower <gardener@tensorflow.org> 2016-11-04 05:05:37 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-11-04 06:27:06 -0700
commit d05336271cfb38bb93db902e1a92b9284152a19f (patch)
tree dd660f54ec8830e1a885dd75656b5456bf64f863
parent 8544c8b65c846025af75685cd23f4a5d3611e4af (diff)
Speed up TensorForest by limiting the leaves that FinishedNodes looks at.
Change: 138182239
-rw-r--r-- tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc 18
-rw-r--r-- tensorflow/contrib/tensor_forest/python/tensor_forest.py 6
2 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
index 7afaa00fad..e1fc03362c 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/finished_nodes_op.cc
@@ -57,8 +57,8 @@ struct EvaluateParams {
};
void Evaluate(const EvaluateParams& params, mutex* mutex, int32 start,
- int32 end, std::vector<int32>* final_finished_leaves,
- std::vector<int32>* final_stale) {
+ int32 end, std::unordered_set<int32>* final_finished_leaves,
+ std::unordered_set<int32>* final_stale) {
const auto leaves = params.leaves.unaligned_flat<int32>();
const auto node_map = params.node_to_accumulator.unaligned_flat<int32>();
const auto sums = params.accumulator_sums.tensor<float, 2>();
@@ -77,9 +77,10 @@ void Evaluate(const EvaluateParams& params, mutex* mutex, int32 start,
simple_philox.reset(new random::SimplePhilox(&rnd_gen));
}
+ std::unordered_set<int32> visited;
for (int32 i = start; i < end; i++) {
const int32 leaf = internal::SubtleMustCopy(leaves(i));
- if (leaf == -1) {
+ if (leaf == -1 || visited.find(leaf) != visited.end()) {
continue;
}
if (!FastBoundsCheck(leaf, node_map.size())) {
@@ -119,11 +120,12 @@ void Evaluate(const EvaluateParams& params, mutex* mutex, int32 start,
if (finished) {
finished_leaves.push_back(leaf);
}
+
+ visited.insert(leaf);
}
mutex_lock m(*mutex);
- final_finished_leaves->insert(final_finished_leaves->end(),
- finished_leaves.begin(), finished_leaves.end());
- final_stale->insert(final_stale->end(), stale.begin(), stale.end());
+ final_finished_leaves->insert(finished_leaves.begin(), finished_leaves.end());
+ final_stale->insert(stale.begin(), stale.end());
}
} // namespace
@@ -298,8 +300,8 @@ class FinishedNodes : public OpKernel {
}
}
- std::vector<int32> finished_leaves;
- std::vector<int32> stale;
+ std::unordered_set<int32> finished_leaves;
+ std::unordered_set<int32> stale;
mutex m;
// Require at least 100 leaves per thread. I guess that's about 800 cost
// per unit. This isn't well defined.
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index b7b2fb9637..c20279bc50 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -702,8 +702,12 @@ class RandomTreeGraphs(object):
# Calculate finished nodes.
with ops.control_dependencies(splits_update_ops):
+ # Passing input_leaves to finished nodes here means that nodes that
+ # have become stale won't be deallocated until an input reaches them,
+ # because we're trying to avoid considering every fertile node for
+ # performance reasons.
finished, stale = self.training_ops.finished_nodes(
- self.variables.accumulator_to_node_map,
+ input_leaves,
self.variables.node_to_accumulator_map,
self.variables.candidate_split_sums,
self.variables.candidate_split_squares,