aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc')
-rw-r--r--tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc101
1 files changed, 36 insertions, 65 deletions
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
index 026262e47f..33638ca7e6 100644
--- a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
+++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc
@@ -36,7 +36,7 @@ using tensorforest::Initialize;
using tensorforest::WeightedGiniImpurity;
REGISTER_OP("UpdateFertileSlots")
- .Attr("max_depth: int")
+ .Attr("max_depth: int")
.Attr("regression: bool = False")
.Input("finished: int32")
.Input("non_fertile_leaves: int32")
@@ -45,11 +45,10 @@ REGISTER_OP("UpdateFertileSlots")
.Input("tree_depths: int32")
.Input("accumulator_sums: float")
.Input("node_to_accumulator: int32")
+ .Input("stale_leaves: int32")
.Output("node_map_updates: int32")
.Output("accumulators_cleared: int32")
.Output("accumulators_allocated: int32")
- .Output("new_nonfertile_leaves: int32")
- .Output("new_nonfertile_leaves_scores: float")
.Doc(R"doc(
Updates accumulator slots to reflect finished or newly fertile nodes.
@@ -77,6 +76,8 @@ accumulator_sums: For classification, `accumulator_sums[a][c]` records how
of training examples that have been seen.
node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by
fertile node i, or -1 if node i isn't fertile.
+stale_leaves:= A 1-d int32 tensor containing the indices of all leaves that
+ have stopped accumulating statistics because they are too old.
node_map_updates:= A 2-d int32 tensor describing the changes that need to
be applied to the node_to_accumulator map. Intended to be used with
`tf.scatter_update(node_to_accumulator,
@@ -86,10 +87,7 @@ accumulators_cleared:= A 1-d int32 tensor containing the indices of all
the accumulator slots that need to be cleared.
accumulators_allocated:= A 1-d int32 tensor containing the indices of all
the accumulator slots that need to be allocated.
-new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the
- leaves that are now non-fertile.
-new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the
- splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`.
+
)doc");
class UpdateFertileSlots : public OpKernel {
@@ -112,6 +110,7 @@ class UpdateFertileSlots : public OpKernel {
const Tensor& accumulator_sums = context->input(5);
const Tensor& node_to_accumulator = context->input(6);
+ const Tensor& stale_leaves = context->input(7);
OP_REQUIRES(context, finished.shape().dims() == 1,
errors::InvalidArgument(
@@ -134,6 +133,9 @@ class UpdateFertileSlots : public OpKernel {
OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1,
errors::InvalidArgument(
"node_to_accumulator should be one-dimensional"));
+ OP_REQUIRES(context, stale_leaves.shape().dims() == 1,
+ errors::InvalidArgument(
+ "stale_leaves should be one-dimensional"));
OP_REQUIRES(
context,
@@ -151,6 +153,7 @@ class UpdateFertileSlots : public OpKernel {
if (!CheckTensorBounds(context, tree_depths)) return;
if (!CheckTensorBounds(context, accumulator_sums)) return;
if (!CheckTensorBounds(context, node_to_accumulator)) return;
+ if (!CheckTensorBounds(context, stale_leaves)) return;
// Read finished accumulators into a set for quick lookup.
const auto node_map = node_to_accumulator.unaligned_flat<int32>();
@@ -164,6 +167,16 @@ class UpdateFertileSlots : public OpKernel {
errors::InvalidArgument("finished node is outside the valid range"));
finished_accumulators.insert(node_map(node));
}
+ // Stale accumulators are also finished for the purposes of clearing
+ // and re-allocating.
+ const auto stale_vec = stale_leaves.unaligned_flat<int32>();
+ for (int32 i = 0; i < stale_vec.size(); ++i) {
+ const int32 node = internal::SubtleMustCopy(stale_vec(i));
+ OP_REQUIRES(
+ context, FastBoundsCheck(node, node_map.size()),
+ errors::InvalidArgument("stale node is outside the valid range"));
+ finished_accumulators.insert(node_map(node));
+ }
// Construct leaf heap to sort leaves to allocate accumulators to.
const int32 num_nodes = static_cast<int32>(tree_depths.shape().dim_size(0));
@@ -210,11 +223,10 @@ class UpdateFertileSlots : public OpKernel {
}
// Construct and fill outputs.
- SetNodeMapUpdates(accumulators_to_node, finished, context);
+ SetNodeMapUpdates(accumulators_to_node, finished, stale_leaves, context);
SetAccumulatorsCleared(finished_accumulators,
accumulators_to_node, context);
SetAccumulatorsAllocated(accumulators_to_node, context);
- SetNewNonFertileLeaves(values.get(), i, context);
}
private:
@@ -228,18 +240,20 @@ class UpdateFertileSlots : public OpKernel {
typedef TopN<std::pair<int32, float>, OrderBySecondGreater> LeafHeapType;
typedef std::vector<std::pair<int32, float>> HeapValuesType;
- // Creates an update tensor for node to accumulator map. Sets finished nodes
- // to -1 (no accumulator assigned) and newly allocated nodes to their
- // accumulator.
+ // Creates an update tensor for node to accumulator map. Sets finished and
+ // stale nodes to -1 (no accumulator assigned) and newly allocated nodes to
+ // their accumulator.
void SetNodeMapUpdates(
const std::unordered_map<int32, int32>& accumulators_to_node,
- const Tensor& finished, OpKernelContext* context) {
+ const Tensor& finished, const Tensor& stale, OpKernelContext* context) {
// Node map updates.
Tensor* output_node_map = nullptr;
TensorShape node_map_shape;
node_map_shape.AddDim(2);
- node_map_shape.AddDim(accumulators_to_node.size() +
- static_cast<int32>(finished.shape().dim_size(0)));
+ node_map_shape.AddDim(
+ accumulators_to_node.size() +
+ static_cast<int32>(stale.shape().dim_size(0) +
+ finished.shape().dim_size(0)));
OP_REQUIRES_OK(context,
context->allocate_output(0, node_map_shape,
&output_node_map));
@@ -254,6 +268,13 @@ class UpdateFertileSlots : public OpKernel {
out_node(1, output_slot) = -1;
++output_slot;
}
+ // Set stale nodes to -1.
+ const auto stale_vec = stale.unaligned_flat<int32>();
+ for (int32 i = 0; i < stale_vec.size(); ++i) {
+ out_node(0, output_slot) = stale_vec(i);
+ out_node(1, output_slot) = -1;
+ ++output_slot;
+ }
// Set newly allocated nodes to their allocator.
for (const auto& node_alloc_pair : accumulators_to_node) {
@@ -315,56 +336,6 @@ class UpdateFertileSlots : public OpKernel {
}
}
- // Creates output tensors for non-fertile leaves and non-fertile leaf scores.
- // Start indicates the index in values where the leaves that weren't
- // allocated this round begin, and should thus be placed in the new
- // nonfertile_leaves tensors.
- void SetNewNonFertileLeaves(HeapValuesType* values, int32 start,
- OpKernelContext* context) {
- // Node map updates.
- int32 num_values = static_cast<int32>(values->size()) - start;
-
- // Unfortunately, a zero-sized Variable results in an uninitialized
- // error, probably because they check for zero size instead of
- // a real inititalization condition.
- bool fill_with_garbage = false;
- if (num_values == 0) {
- num_values = 1;
- fill_with_garbage = true;
- }
- Tensor* output_nonfertile_leaves = nullptr;
- TensorShape nonfertile_leaves_shape;
- nonfertile_leaves_shape.AddDim(num_values);
- OP_REQUIRES_OK(context,
- context->allocate_output(3, nonfertile_leaves_shape,
- &output_nonfertile_leaves));
-
- auto out_nonfertile_leaves =
- output_nonfertile_leaves->unaligned_flat<int32>();
-
- Tensor* output_nonfertile_leaves_scores = nullptr;
- TensorShape nonfertile_leaves_scores_shape;
- nonfertile_leaves_scores_shape.AddDim(num_values);
- OP_REQUIRES_OK(context,
- context->allocate_output(4, nonfertile_leaves_scores_shape,
- &output_nonfertile_leaves_scores));
-
- auto out_nonfertile_leaves_scores =
- output_nonfertile_leaves_scores->unaligned_flat<float>();
-
- if (fill_with_garbage) {
- out_nonfertile_leaves(0) = -1;
- out_nonfertile_leaves_scores(0) = 0.0;
- return;
- }
-
- for (int32 i = start; i < values->size(); ++i) {
- const std::pair<int32, float>& node = (*values)[i];
- out_nonfertile_leaves(i -start) = node.first;
- out_nonfertile_leaves_scores(i - start) = node.second;
- }
- }
-
void ConstructLeafHeap(const Tensor& non_fertile_leaves,
const Tensor& non_fertile_leaf_scores,
const Tensor& tree_depths, int32 end_of_tree,