diff options
Diffstat (limited to 'tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc')
-rw-r--r-- | tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc | 101 |
1 files changed, 36 insertions, 65 deletions
diff --git a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc index 026262e47f..33638ca7e6 100644 --- a/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc +++ b/tensorflow/contrib/tensor_forest/core/ops/update_fertile_slots_op.cc @@ -36,7 +36,7 @@ using tensorforest::Initialize; using tensorforest::WeightedGiniImpurity; REGISTER_OP("UpdateFertileSlots") - .Attr("max_depth: int") + .Attr("max_depth: int") .Attr("regression: bool = False") .Input("finished: int32") .Input("non_fertile_leaves: int32") @@ -45,11 +45,10 @@ REGISTER_OP("UpdateFertileSlots") .Input("tree_depths: int32") .Input("accumulator_sums: float") .Input("node_to_accumulator: int32") + .Input("stale_leaves: int32") .Output("node_map_updates: int32") .Output("accumulators_cleared: int32") .Output("accumulators_allocated: int32") - .Output("new_nonfertile_leaves: int32") - .Output("new_nonfertile_leaves_scores: float") .Doc(R"doc( Updates accumulator slots to reflect finished or newly fertile nodes. @@ -77,6 +76,8 @@ accumulator_sums: For classification, `accumulator_sums[a][c]` records how of training examples that have been seen. node_to_accumulator: `node_to_accumulator[i]` is the accumulator slot used by fertile node i, or -1 if node i isn't fertile. +stale_leaves:= A 1-d int32 tensor containing the indices of all leaves that + have stopped accumulating statistics because they are too old. node_map_updates:= A 2-d int32 tensor describing the changes that need to be applied to the node_to_accumulator map. Intended to be used with `tf.scatter_update(node_to_accumulator, @@ -86,10 +87,7 @@ accumulators_cleared:= A 1-d int32 tensor containing the indices of all the accumulator slots that need to be cleared. accumulators_allocated:= A 1-d int32 tensor containing the indices of all the accumulator slots that need to be allocated. -new_nonfertile_leaves:= A 1-d int32 tensor containing the indices of all the - leaves that are now non-fertile. -new_nonfertile_leaves_scores: `new_nonfertile_leaves_scores[i]` contains the - splitting score for the non-fertile leaf `new_nonfertile_leaves[i]`. + )doc"); class UpdateFertileSlots : public OpKernel { @@ -112,6 +110,7 @@ class UpdateFertileSlots : public OpKernel { const Tensor& accumulator_sums = context->input(5); const Tensor& node_to_accumulator = context->input(6); + const Tensor& stale_leaves = context->input(7); OP_REQUIRES(context, finished.shape().dims() == 1, errors::InvalidArgument( @@ -134,6 +133,9 @@ class UpdateFertileSlots : public OpKernel { OP_REQUIRES(context, node_to_accumulator.shape().dims() == 1, errors::InvalidArgument( "node_to_accumulator should be one-dimensional")); + OP_REQUIRES(context, stale_leaves.shape().dims() == 1, + errors::InvalidArgument( + "stale_leaves should be one-dimensional")); OP_REQUIRES( context, @@ -151,6 +153,7 @@ class UpdateFertileSlots : public OpKernel { if (!CheckTensorBounds(context, tree_depths)) return; if (!CheckTensorBounds(context, accumulator_sums)) return; if (!CheckTensorBounds(context, node_to_accumulator)) return; + if (!CheckTensorBounds(context, stale_leaves)) return; // Read finished accumulators into a set for quick lookup. const auto node_map = node_to_accumulator.unaligned_flat<int32>(); @@ -164,6 +167,16 @@ class UpdateFertileSlots : public OpKernel { errors::InvalidArgument("finished node is outside the valid range")); finished_accumulators.insert(node_map(node)); } + // Stale accumulators are also finished for the purposes of clearing + // and re-allocating. + const auto stale_vec = stale_leaves.unaligned_flat<int32>(); + for (int32 i = 0; i < stale_vec.size(); ++i) { + const int32 node = internal::SubtleMustCopy(stale_vec(i)); + OP_REQUIRES( + context, FastBoundsCheck(node, node_map.size()), + errors::InvalidArgument("stale node is outside the valid range")); + finished_accumulators.insert(node_map(node)); + } // Construct leaf heap to sort leaves to allocate accumulators to. const int32 num_nodes = static_cast<int32>(tree_depths.shape().dim_size(0)); @@ -210,11 +223,10 @@ class UpdateFertileSlots : public OpKernel { } // Construct and fill outputs. - SetNodeMapUpdates(accumulators_to_node, finished, context); + SetNodeMapUpdates(accumulators_to_node, finished, stale_leaves, context); SetAccumulatorsCleared(finished_accumulators, accumulators_to_node, context); SetAccumulatorsAllocated(accumulators_to_node, context); - SetNewNonFertileLeaves(values.get(), i, context); } private: @@ -228,18 +240,20 @@ class UpdateFertileSlots : public OpKernel { typedef TopN<std::pair<int32, float>, OrderBySecondGreater> LeafHeapType; typedef std::vector<std::pair<int32, float>> HeapValuesType; - // Creates an update tensor for node to accumulator map. Sets finished nodes - // to -1 (no accumulator assigned) and newly allocated nodes to their - // accumulator. + // Creates an update tensor for node to accumulator map. Sets finished and + // stale nodes to -1 (no accumulator assigned) and newly allocated nodes to + // their accumulator. void SetNodeMapUpdates( const std::unordered_map<int32, int32>& accumulators_to_node, - const Tensor& finished, OpKernelContext* context) { + const Tensor& finished, const Tensor& stale, OpKernelContext* context) { // Node map updates. Tensor* output_node_map = nullptr; TensorShape node_map_shape; node_map_shape.AddDim(2); - node_map_shape.AddDim(accumulators_to_node.size() + - static_cast<int32>(finished.shape().dim_size(0))); + node_map_shape.AddDim( + accumulators_to_node.size() + + static_cast<int32>(stale.shape().dim_size(0) + + finished.shape().dim_size(0))); OP_REQUIRES_OK(context, context->allocate_output(0, node_map_shape, &output_node_map)); @@ -254,6 +268,13 @@ class UpdateFertileSlots : public OpKernel { out_node(1, output_slot) = -1; ++output_slot; } + // Set stale nodes to -1. + const auto stale_vec = stale.unaligned_flat<int32>(); + for (int32 i = 0; i < stale_vec.size(); ++i) { + out_node(0, output_slot) = stale_vec(i); + out_node(1, output_slot) = -1; + ++output_slot; + } // Set newly allocated nodes to their allocator. for (const auto& node_alloc_pair : accumulators_to_node) { @@ -315,56 +336,6 @@ class UpdateFertileSlots : public OpKernel { } } - // Creates output tensors for non-fertile leaves and non-fertile leaf scores. - // Start indicates the index in values where the leaves that weren't - // allocated this round begin, and should thus be placed in the new - // nonfertile_leaves tensors. - void SetNewNonFertileLeaves(HeapValuesType* values, int32 start, - OpKernelContext* context) { - // Node map updates. - int32 num_values = static_cast<int32>(values->size()) - start; - - // Unfortunately, a zero-sized Variable results in an uninitialized - // error, probably because they check for zero size instead of - // a real inititalization condition. - bool fill_with_garbage = false; - if (num_values == 0) { - num_values = 1; - fill_with_garbage = true; - } - Tensor* output_nonfertile_leaves = nullptr; - TensorShape nonfertile_leaves_shape; - nonfertile_leaves_shape.AddDim(num_values); - OP_REQUIRES_OK(context, - context->allocate_output(3, nonfertile_leaves_shape, - &output_nonfertile_leaves)); - - auto out_nonfertile_leaves = - output_nonfertile_leaves->unaligned_flat<int32>(); - - Tensor* output_nonfertile_leaves_scores = nullptr; - TensorShape nonfertile_leaves_scores_shape; - nonfertile_leaves_scores_shape.AddDim(num_values); - OP_REQUIRES_OK(context, - context->allocate_output(4, nonfertile_leaves_scores_shape, - &output_nonfertile_leaves_scores)); - - auto out_nonfertile_leaves_scores = - output_nonfertile_leaves_scores->unaligned_flat<float>(); - - if (fill_with_garbage) { - out_nonfertile_leaves(0) = -1; - out_nonfertile_leaves_scores(0) = 0.0; - return; - } - - for (int32 i = start; i < values->size(); ++i) { - const std::pair<int32, float>& node = (*values)[i]; - out_nonfertile_leaves(i -start) = node.first; - out_nonfertile_leaves_scores(i - start) = node.second; - } - } - void ConstructLeafHeap(const Tensor& non_fertile_leaves, const Tensor& non_fertile_leaf_scores, const Tensor& tree_depths, int32 end_of_tree, |