diff options
Diffstat (limited to 'tensorflow/contrib/coder/kernels')
-rw-r--r-- | tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc | 60 | ||||
-rw-r--r-- | tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc | 6 |
2 files changed, 57 insertions, 9 deletions
diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc index c787e8eded..bd5272ee6f 100644 --- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include <algorithm> +#include <functional> #include <iterator> #include <numeric> #include <vector> @@ -79,8 +80,8 @@ class PmfToCdfOp : public OpKernel { } private: - struct Item { - Item(int32* p, double mass) : pointer(p), mass(mass) { + struct PenaltyItem { + PenaltyItem(int32* p, double mass) : pointer(p), mass(mass) { penalty = ComputeNextPenalty(); } @@ -90,7 +91,7 @@ class PmfToCdfOp : public OpKernel { penalty = ComputeNextPenalty(); } - friend bool operator<(const Item& lhs, const Item& rhs) { + friend bool operator<(const PenaltyItem& lhs, const PenaltyItem& rhs) { return lhs.penalty < rhs.penalty; } @@ -106,6 +107,34 @@ class PmfToCdfOp : public OpKernel { double penalty; }; + struct GainItem { + GainItem(int32* p, double mass) : pointer(p), mass(mass) { + gain = ComputeNextGain(); + } + + void Increase() { + CHECK_GT(*pointer, 0); + ++*pointer; + gain = ComputeNextGain(); + } + + friend bool operator>(const GainItem& lhs, const GainItem& rhs) { + return lhs.gain > rhs.gain; + } + + double ComputeNextGain() { + // Never increment zero value to non-zero value. + if (*pointer < 1) { + return -std::numeric_limits<double>::infinity(); + } + return mass * (std::log2(*pointer + 1) - std::log2(*pointer)); + } + + int32* pointer; + double mass; + double gain; + }; + void PerShard(gtl::ArraySlice<float> pmf, gtl::MutableArraySlice<int32> cdf) const { CHECK_EQ(pmf.size(), cdf.size()); @@ -121,7 +150,7 @@ class PmfToCdfOp : public OpKernel { int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0); if (sum > normalizer) { - std::vector<Item> queue; + std::vector<PenaltyItem> queue; queue.reserve(cdf.size()); for (int i = 0; i < cdf.size(); ++i) { queue.emplace_back(&cdf[i], pmf[i]); @@ -132,9 +161,26 @@ class PmfToCdfOp : public OpKernel { queue[0].Decrease(); // Performs a linear search because this find_if is likely to return // iterator very close to the begin. - auto iter = - std::find_if(std::next(queue.begin()), queue.end(), - [&queue](const Item& rhs) { return queue[0] < rhs; }); + auto iter = std::find_if( + std::next(queue.begin()), queue.end(), + [&queue](const PenaltyItem& rhs) { return queue[0] < rhs; }); + std::rotate(queue.begin(), std::next(queue.begin()), iter); + } + } else if (sum < normalizer) { + std::vector<GainItem> queue; + queue.reserve(cdf.size()); + for (int i = 0; i < cdf.size(); ++i) { + queue.emplace_back(&cdf[i], pmf[i]); + } + + std::sort(queue.begin(), queue.end(), std::greater<GainItem>()); + while (sum++ < normalizer) { + queue[0].Increase(); + // Performs a linear search because this find_if is likely to return + // iterator very close to the begin. + auto iter = std::find_if( + std::next(queue.begin()), queue.end(), + [&queue](const GainItem& rhs) { return queue[0] > rhs; }); std::rotate(queue.begin(), std::next(queue.begin()), iter); } } diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc index c70e38faab..3408f6b519 100644 --- a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc +++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc @@ -82,7 +82,7 @@ class PmfToQuantizedCdfOpTest : public OpsTestBase { EXPECT_GT(diff, 0); } - EXPECT_LE(cdf_slice(cdf_slice.size() - 1), normalizer); + EXPECT_EQ(cdf_slice(cdf_slice.size() - 1), normalizer); } } }; @@ -98,6 +98,8 @@ TEST_F(PmfToQuantizedCdfOpTest, UnderSum) { GenerateData(&rand, {&matrix(i, 0), n}); } + pmf.flat<float>() = pmf.flat<float>() * 0.85f; + constexpr int kPrecision = 10; SetupOp(kPrecision, &pmf); TF_ASSERT_OK(RunOpKernel()); @@ -115,7 +117,7 @@ TEST_F(PmfToQuantizedCdfOpTest, OverSum) { matrix.setZero(); const std::size_t n = matrix.dimension(1) / 2; - random::PhiloxRandom gen; + random::PhiloxRandom gen(random::New64(), random::New64()); random::SimplePhilox rand(&gen); for (int64 i = 0; i < matrix.dimension(0); ++i) { GenerateData(&rand, {&matrix(i, 0), n}); |