aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-14 11:20:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-14 11:25:12 -0700
commit3d72bc69ee838e8c6b0b801e274aac4c31647b22 (patch)
tree3ffc29a90a5fb07238eb2f6c8ff2f6baa197470c
parent3a487def43800744d1b83e1f1bac4223b5e88a1c (diff)
Better gating of TensorForest pruning code.
PiperOrigin-RevId: 161981809
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
index 81b4534f10..63bfc1aef1 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc
@@ -207,7 +207,8 @@ void ClassificationStats::AddExample(
}
void ClassificationStats::CheckPrune() {
- if (IsFinished() || weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
+ if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
+ weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
return;
}
++prune_sample_epoch_;