about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc')
-rw-r--r-- tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc 60
1 files changed, 53 insertions, 7 deletions
diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc
index c787e8eded..bd5272ee6f 100644
--- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc
+++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include <algorithm>
+#include <functional>
#include <iterator>
#include <numeric>
#include <vector>
@@ -79,8 +80,8 @@ class PmfToCdfOp : public OpKernel {
}
private:
- struct Item {
- Item(int32* p, double mass) : pointer(p), mass(mass) {
+ struct PenaltyItem {
+ PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) {
penalty = ComputeNextPenalty();
}
@@ -90,7 +91,7 @@ class PmfToCdfOp : public OpKernel {
penalty = ComputeNextPenalty();
}
- friend bool operator<(const Item& lhs, const Item& rhs) {
+ friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) {
return lhs.penalty < rhs.penalty;
}
@@ -106,6 +107,34 @@ class PmfToCdfOp : public OpKernel {
double penalty;
};
+ struct GainItem {
+   // Queue entry for the "sum is below the normalizer" case: it tracks one
+   // count (through `pointer`), the probability mass that weights it, and the
+   // gain obtained by adding one more to that count.
+   GainItem(int32* p, double mass)
+       : pointer(p), mass(mass), gain(ComputeNextGain()) {}
+
+   // Adds one to the tracked count and refreshes the cached gain. The count
+   // must already be positive.
+   void Increase() {
+     CHECK_GT(*pointer, 0);
+     *pointer += 1;
+     gain = ComputeNextGain();
+   }
+
+   // Compares two items by their cached gain.
+   friend bool operator>(const GainItem& lhs, const GainItem& rhs) {
+     return rhs.gain < lhs.gain;
+   }
+
+   // Returns mass * (log2(count + 1) - log2(count)) for the current count.
+   // A count below 1 yields -infinity: never increment a zero value to a
+   // non-zero value.
+   double ComputeNextGain() {
+     const int32 count = *pointer;
+     if (count >= 1) {
+       return mass * (std::log2(count + 1) - std::log2(count));
+     }
+     return -std::numeric_limits<double>::infinity();
+   }
+
+   // Declaration order matters: `gain` is initialized from `pointer` and
+   // `mass` in the constructor's member-initializer list.
+   int32* pointer;
+   double mass;
+   double gain;
+ };
+
void PerShard(gtl::ArraySlice<float> pmf,
gtl::MutableArraySlice<int32> cdf) const {
CHECK_EQ(pmf.size(), cdf.size());
@@ -121,7 +150,7 @@ class PmfToCdfOp : public OpKernel {
int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0);
if (sum > normalizer) {
- std::vector<Item> queue;
+ std::vector<PenaltyItem> queue;
queue.reserve(cdf.size());
for (int i = 0; i < cdf.size(); ++i) {
queue.emplace_back(&cdf[i], pmf[i]);
@@ -132,9 +161,26 @@ class PmfToCdfOp : public OpKernel {
queue[0].Decrease();
// Performs a linear search because this find_if is likely to return an
// iterator very close to the beginning.
- auto iter =
- std::find_if(std::next(queue.begin()), queue.end(),
- [&queue](const Item& rhs) { return queue[0] < rhs; });
+ auto iter = std::find_if(
+ std::next(queue.begin()), queue.end(),
+ [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; });
+ std::rotate(queue.begin(), std::next(queue.begin()), iter);
+ }
+ } else if (sum < normalizer) {
+ std::vector<GainItem> queue;
+ queue.reserve(cdf.size());
+ for (int i = 0; i < cdf.size(); ++i) {
+ queue.emplace_back(&cdf[i], pmf[i]);
+ }
+
+ std::sort(queue.begin(), queue.end(), std::greater<GainItem>());
+ while (sum++ < normalizer) {
+ queue[0].Increase();
+ // Performs a linear search because this find_if is likely to return an
+ // iterator very close to the beginning.
+ auto iter = std::find_if(
+ std::next(queue.begin()), queue.end(),
+ [&queue](const GainItem& rhs) { return queue[0] > rhs; });
std::rotate(queue.begin(), std::next(queue.begin()), iter);
}
}