aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc
blob: ce67db797ded54f5023eaa89369d4781aad31a7c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"

#include <iterator>
#include <numeric>
#include <unordered_set>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/logging.h"

using tensorflow::Status;
using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
using tensorflow::random::PhiloxRandom;
using tensorflow::random::SimplePhilox;

namespace tensorflow {
namespace boosted_trees {
namespace utils {

// Randomly selects which trees to drop out for one boosting iteration.
//
// Params:
//   seed: seeds the Philox generator, so identical seed and inputs always
//     yield the same selection.
//   config: provides dropout_probability (per-tree chance of being dropped)
//     and probability_of_skipping_dropout (chance that this call performs no
//     dropout at all). Both must lie in [0, 1].
//   trees_not_to_drop: tree indices that are never dropped.
//   weights: current weight of every tree; its size is the number of trees.
//   dropped_trees: output. It is cleared, then filled with the indices of the
//     dropped trees in ascending order.
//   original_weights: output. It is cleared, then filled with weights[t] for
//     each t in *dropped_trees, in the same order.
//
// Returns:
//   - Internal if dropped_trees is null.
//   - InvalidArgument if original_weights is null, or if either probability
//     is outside [0, 1] (NaN included).
//   - OK otherwise, including when dropout is a no-op or is skipped; in those
//     cases both outputs are left empty.
Status DropoutUtils::DropOutTrees(
    const uint64 seed, const LearningRateDropoutDrivenConfig& config,
    const std::unordered_set<int32>& trees_not_to_drop,
    const std::vector<float>& weights, std::vector<int32>* dropped_trees,
    std::vector<float>* original_weights) {
  // Verify params.
  if (dropped_trees == nullptr) {
    return errors::Internal("Dropped trees is nullptr.");
  }
  if (original_weights == nullptr) {
    return errors::InvalidArgument("Original weights is nullptr.");
  }
  const float dropout_probability = config.dropout_probability();
  // The range test is negated so that NaN (for which every comparison is
  // false) is rejected as well.
  if (!(dropout_probability >= 0 && dropout_probability <= 1)) {
    return errors::InvalidArgument(
        "Dropout probability must be in [0,1] range");
  }
  const float probability_of_skipping_dropout =
      config.probability_of_skipping_dropout();
  if (!(probability_of_skipping_dropout >= 0 &&
        probability_of_skipping_dropout <= 1)) {
    return errors::InvalidArgument(
        "Probability of skipping dropout must be in [0,1] range");
  }

  dropped_trees->clear();
  original_weights->clear();

  // If dropout is no op, return.
  if (dropout_probability == 0 || probability_of_skipping_dropout == 1.0) {
    return Status::OK();
  }

  PhiloxRandom philox(seed);
  SimplePhilox rng(&philox);

  // Optionally skip dropout for this whole call. This roll is consumed only
  // when the skip probability is non-zero. The order of RandDouble() calls
  // must stay exactly as is to keep results reproducible for a given seed.
  if (probability_of_skipping_dropout != 0 &&
      rng.RandDouble() < probability_of_skipping_dropout) {
    return Status::OK();
  }

  // Roll the dice for each tree. Indices are visited in increasing order, so
  // *dropped_trees comes out sorted without an explicit sort.
  const int32 num_trees = static_cast<int32>(weights.size());
  for (int32 i = 0; i < num_trees; ++i) {
    // We can't drop some of the trees: for example, bias tree in batch mode,
    // or current tree that is built, in the batch mode.
    if (trees_not_to_drop.count(i) > 0) {
      continue;
    }
    if (rng.RandDouble() < dropout_probability) {
      dropped_trees->push_back(i);
      original_weights->push_back(weights[i]);
    }
  }

  return Status::OK();
}

void DropoutUtils::GetTreesWeightsForAddingTrees(
    const std::vector<int32>& dropped_trees,
    const std::vector<float>& dropped_trees_original_weights,
    const int32 new_trees_first_index, const int32 num_trees_to_add,
    std::vector<float>* current_weights, std::vector<int32>* num_updates) {
  CHECK(num_updates->size() == current_weights->size());
  // combined weight of trees that were dropped out

  const float dropped_sum =
      std::accumulate(dropped_trees_original_weights.begin(),
                      dropped_trees_original_weights.end(), 0.0);

  const int num_dropped = dropped_trees.size();

  // Allocate additional weight for the new tree
  const float total_new_trees_weight = dropped_sum / (num_dropped + 1);

  for (int i = 0; i < num_trees_to_add; ++i) {
    const int32 new_tree_index = new_trees_first_index + i;
    if (new_tree_index < current_weights->size()) {
      // We have the entries in weights and updates for this tree already
      (*current_weights)[new_tree_index] =
          total_new_trees_weight / num_trees_to_add;
      (*num_updates)[new_tree_index]++;
    } else {
      // We need to add a new entry. This is non-batch mode.
      current_weights->push_back(total_new_trees_weight / num_trees_to_add);
      num_updates->push_back(1);
    }
  }

  for (int32 i = 0; i < dropped_trees.size(); ++i) {
    const int32 dropped = dropped_trees[i];
    const float original_weight = dropped_trees_original_weights[i];
    const float new_weight = original_weight * num_dropped / (num_dropped + 1);
    (*current_weights)[dropped] = new_weight;
    // Update the number of updates per tree.
    ++(*num_updates)[dropped];
  }
}

}  // namespace utils
}  // namespace boosted_trees
}  // namespace tensorflow