aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h
blob: 5e316538cefed30b2867252c9ebc4754216db329 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_

#include <algorithm>
#include <functional>
#include <ostream>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Bounded staging buffer for a stream of weighted values.
//
// Entries are appended unsorted through PushEntry() and drained with
// GenerateEntryList(), which sorts them with CompareFn and merges entries
// holding equal values by summing their weights. The capacity is
// min(2 * block_size, max_elements); callers are expected to drain the buffer
// once IsFull() reports true, since pushing into a full buffer is fatal.
template <typename ValueType, typename WeightType,
          typename CompareFn = std::less<ValueType>>
class WeightedQuantilesBuffer {
 public:
  // A single (value, weight) pair stored in the buffer.
  struct BufferEntry {
    BufferEntry(const ValueType& v, const WeightType& w)
        : value(v), weight(w) {}
    // Zero-initialized entry; needed so vectors of entries can be resized.
    BufferEntry() : value(0), weight(0) {}

    // Orders entries by value only (via CompareFn); weights are ignored.
    bool operator<(const BufferEntry& other) const {
      return kCompFn(value, other.value);
    }
    // Entries are equal only when both value and weight match.
    bool operator==(const BufferEntry& other) const {
      return value == other.value && weight == other.weight;
    }
    // Streams the entry as "{value, weight}".
    friend std::ostream& operator<<(std::ostream& strm,
                                    const BufferEntry& entry) {
      return strm << "{" << entry.value << ", " << entry.weight << "}";
    }
    ValueType value;
    WeightType weight;
  };

  // Creates a buffer holding at most min(2 * block_size, max_elements)
  // entries. Both arguments must be strictly positive; anything else is a
  // fatal error.
  explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
      : max_size_(0) {
    // Validate while the arguments are still signed. Converting a negative
    // value to size_t first would wrap it to a huge positive number that
    // slips past the check and only fails later inside reserve().
    QCHECK(block_size > 0 && max_elements > 0)
        << "Invalid buffer specification: (" << block_size << ", "
        << max_elements << ")";
    // Equivalent to min(2 * block_size, max_elements), arranged so that the
    // doubling is only evaluated when it cannot overflow.
    const int64 capacity =
        block_size > max_elements / 2 ? max_elements : block_size * 2;
    max_size_ = static_cast<size_t>(capacity);
    vec_.reserve(max_size_);
  }

  // Disallow copying as it's semantically nonsensical in the Squawd algorithm
  // but enable move semantics.
  WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
  WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
  WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
  WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;

  // Appends (value, weight) to the buffer. Entries with zero or negative
  // weight are silently dropped. Calling this on a full buffer is fatal, even
  // when the entry would have been dropped.
  void PushEntry(const ValueType& value, const WeightType& weight) {
    // Callers are expected to act on a full compacted buffer after the
    // PushEntry call returns.
    QCHECK(!IsFull()) << "Buffer already full: " << max_size_;

    // Ignore zero and negative weight entries.
    if (weight <= 0) {
      return;
    }

    // Push back the entry to the buffer.
    vec_.push_back(BufferEntry(value, weight));
  }

  // Drains the buffer and returns its contents sorted by value, with entries
  // of equal value merged into one entry carrying the summed weight. The
  // buffer is left empty with its capacity re-reserved. Returns an empty
  // vector when the buffer holds nothing.
  // Callers should minimize how often this is called, ideally only right after
  // the buffer becomes full.
  std::vector<BufferEntry> GenerateEntryList() {
    std::vector<BufferEntry> ret;
    if (vec_.empty()) {
      return ret;
    }
    // Take ownership of the stored entries; vec_ becomes the empty vector.
    ret.swap(vec_);
    vec_.reserve(max_size_);
    std::sort(ret.begin(), ret.end());

    // In-place compaction: `last_kept` is the index of the most recently
    // retained unique entry; everything after it is either folded into it or
    // copied down to the next free slot.
    size_t last_kept = 0;
    for (size_t i = 1; i < ret.size(); ++i) {
      if (ret[i].value != ret[i - 1].value) {
        ++last_kept;
        // Nothing to copy while no duplicates have been collapsed yet.
        if (last_kept != i) {
          ret[last_kept] = ret[i];
        }
      } else {
        ret[last_kept].weight += ret[i].weight;
      }
    }
    ret.resize(last_kept + 1);
    return ret;
  }

  // Number of entries currently buffered.
  int64 Size() const { return vec_.size(); }
  // True once the buffer has reached its capacity.
  bool IsFull() const { return vec_.size() >= max_size_; }
  // Discards all buffered entries.
  void Clear() { vec_.clear(); }

 private:
  using BufferVector = typename std::vector<BufferEntry>;

  // Comparison function.
  static constexpr decltype(CompareFn()) kCompFn = CompareFn();

  // Maximum number of entries the buffer may hold.
  size_t max_size_;
  // Base buffer.
  BufferVector vec_;
};

// Out-of-class definition of the static constexpr comparator declared inside
// WeightedQuantilesBuffer. The in-class declaration supplies the initializer;
// this namespace-scope definition is what allows the member to be odr-used
// under pre-C++17 rules, where constexpr static data members are not
// implicitly inline.
template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;

}  // namespace quantiles
}  // namespace boosted_trees
}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_