author     Sung Jin Hwang <sjhwang@google.com>        2018-04-13 14:51:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>        2018-04-13 14:53:52 -0700
commit     aedc409605be54f9c7cb67f7b49bdc123d65a8fb (patch)
tree       50e9dea2433a6a37eca1d0fd5e8c4f0eab72fa38 /tensorflow/contrib/coder
parent     8600d918a63c658b9b79ba96ee821c903ba3ee94 (diff)
Added PmfToQuantizedCdf op to contrib/coder in TensorFlow.
The added op transforms probability mass functions (PMFs) into quantized cumulative distribution functions (CDFs), which can be used by the range coder ops in contrib/coder. The op takes a greedy approach to ensure that the post-quantization probability masses do not sum to more than the maximum quantized value (2^precision). The op makes no adjustment when the post-quantization probability masses already sum to less than the maximum value.

PiperOrigin-RevId: 192827779
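For illustration, the transform described above can be sketched as a small standalone routine (a hypothetical helper, not the kernel added by this commit, and limited to a single 1-D PMF): round each mass onto the 2^precision grid, clamp every entry to at least 1, and, whenever the total exceeds 2^precision, greedily decrement the entry whose decrement increases the expected code length the least.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <numeric>
#include <vector>

// Quantizes `pmf` onto a grid of 1 << precision and returns the cumulative
// sums prefixed with 0, so the result has pmf.size() + 1 entries.
std::vector<int32_t> PmfToQuantizedCdf(const std::vector<float>& pmf,
                                       int precision) {
  const int32_t normalizer = 1 << precision;

  // Round each mass to the quantized grid, but keep every entry >= 1 so that
  // no symbol ends up with zero probability after quantization.
  std::vector<int32_t> quantized(pmf.size());
  std::transform(pmf.begin(), pmf.end(), quantized.begin(),
                 [normalizer](float mass) {
                   return std::max<int32_t>(std::rint(mass * normalizer), 1);
                 });

  // Greedy correction: while the total exceeds 2^precision, decrement the
  // entry whose decrement increases the expected code length the least,
  // i.e. minimizes pmf[i] * (log2(q[i]) - log2(q[i] - 1)). Entries already
  // at 1 are never decremented, and nothing is done when the sum already
  // fits, matching the commit description.
  int32_t sum = std::accumulate(quantized.begin(), quantized.end(), 0);
  while (sum > normalizer) {
    double best_penalty = std::numeric_limits<double>::infinity();
    int best = -1;
    for (int i = 0; i < static_cast<int>(quantized.size()); ++i) {
      if (quantized[i] <= 1) continue;
      const double penalty =
          pmf[i] * (std::log2(quantized[i]) - std::log2(quantized[i] - 1));
      if (penalty < best_penalty) {
        best_penalty = penalty;
        best = i;
      }
    }
    if (best < 0) break;  // Every entry is already at 1; cannot reduce more.
    --quantized[best];
    --sum;
  }

  // The CDF is the prefix sum with a leading 0: cdf[0] == 0, cdf.back() == sum.
  std::vector<int32_t> cdf(pmf.size() + 1, 0);
  std::partial_sum(quantized.begin(), quantized.end(), cdf.begin() + 1);
  return cdf;
}

The penalty pmf[i] * (log2(q[i]) - log2(q[i] - 1)) approximates the growth in expected code length caused by shrinking symbol i's quantized mass by one unit, which is why the greedy step always picks the entry where that quantity is smallest. The kernel added in this commit applies the same idea per row of the flattened input, maintaining a sorted queue of candidates instead of rescanning on every decrement.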
Diffstat (limited to 'tensorflow/contrib/coder')
-rw-r--r--  tensorflow/contrib/coder/BUILD                          |  34
-rw-r--r--  tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc       | 150
-rw-r--r--  tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc  | 140
-rw-r--r--  tensorflow/contrib/coder/ops/coder_ops.cc               |  32
4 files changed, 355 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD
index ce12e38248..9ca4ce8a9c 100644
--- a/tensorflow/contrib/coder/BUILD
+++ b/tensorflow/contrib/coder/BUILD
@@ -92,6 +92,34 @@ tf_cc_test(
],
)
+tf_kernel_library(
+ name = "pmf_to_cdf_op",
+ srcs = ["kernels/pmf_to_cdf_op.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":coder_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "pmf_to_cdf_op_test",
+ size = "small",
+ srcs = ["kernels/pmf_to_cdf_op_test.cc"],
+ deps = [
+ ":pmf_to_cdf_op",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ ],
+)
+
cc_library(
name = "all_ops",
deps = [":coder_ops_op_lib"],
@@ -99,12 +127,16 @@ cc_library(
cc_library(
name = "all_kernels",
- deps = [":range_coder_ops"],
+ deps = [
+ ":pmf_to_cdf_op",
+ ":range_coder_ops",
+ ],
)
tf_custom_op_library(
name = "python/ops/_coder_ops.so",
srcs = [
+ "kernels/pmf_to_cdf_op.cc",
"kernels/range_coder.cc",
"kernels/range_coder.h",
"kernels/range_coder_ops.cc",
diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc
new file mode 100644
index 0000000000..c787e8eded
--- /dev/null
+++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc
@@ -0,0 +1,150 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+using errors::InvalidArgument;
+
+class PmfToCdfOp : public OpKernel {
+ public:
+ explicit PmfToCdfOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_));
+ OP_REQUIRES(
+ context, 0 < precision_ && precision_ <= 16,
+ InvalidArgument("`precision` must be in [1, 16]: ", precision_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& pmf_tensor = context->input(0);
+
+ TensorShape shape = pmf_tensor.shape();
+ OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(shape),
+ InvalidArgument("`pmf` should be at least 1-D."));
+ OP_REQUIRES(
+ context, shape.dim_size(shape.dims() - 1) > 1,
+ InvalidArgument("`pmf` size should be at least 2 in the last axis."));
+ shape.set_dim(shape.dims() - 1, shape.dim_size(shape.dims() - 1) + 1);
+
+ Tensor* cdf_tensor;
+ OP_REQUIRES_OK(context, context->allocate_output(0, shape, &cdf_tensor));
+
+ auto pmf = pmf_tensor.flat_inner_dims<float, 2>();
+ auto cdf = cdf_tensor->flat_inner_dims<int32, 2>();
+ CHECK_EQ(pmf.dimension(0), cdf.dimension(0));
+ CHECK_EQ(pmf.dimension(1) + 1, cdf.dimension(1));
+
+ const double n = pmf.dimension(1);
+ const int64 cost_per_unit = static_cast<int64>(50.0 * n * std::log2(n));
+ thread::ThreadPool* thread_pool =
+ context->device()->tensorflow_cpu_worker_threads()->workers;
+ thread_pool->ParallelFor(
+ pmf.dimension(0), cost_per_unit,
+ [this, pmf, &cdf](int64 start, int64 limit) {
+ const gtl::ArraySlice<float>::size_type pmf_size = pmf.dimension(1);
+ for (int64 i = start; i < limit; ++i) {
+ cdf(i, 0) = 0;
+ PerShard({&pmf(i, 0), pmf_size}, {&cdf(i, 1), pmf_size});
+ }
+ });
+ }
+
+ private:
+ struct Item {
+ Item(int32* p, double mass) : pointer(p), mass(mass) {
+ penalty = ComputeNextPenalty();
+ }
+
+ void Decrease() {
+ CHECK_GT(*pointer, 1);
+ --*pointer;
+ penalty = ComputeNextPenalty();
+ }
+
+ friend bool operator<(const Item& lhs, const Item& rhs) {
+ return lhs.penalty < rhs.penalty;
+ }
+
+ double ComputeNextPenalty() {
+ if (*pointer <= 1) {
+ return std::numeric_limits<double>::infinity();
+ }
+ return mass * (std::log2(*pointer) - std::log2(*pointer - 1));
+ }
+
+ int32* pointer;
+ double mass;
+ double penalty;
+ };
+
+ void PerShard(gtl::ArraySlice<float> pmf,
+ gtl::MutableArraySlice<int32> cdf) const {
+ CHECK_EQ(pmf.size(), cdf.size());
+
+ const int32 normalizer = 1 << precision_;
+ std::transform(pmf.begin(), pmf.end(), cdf.begin(),
+ [normalizer](float mass) {
+ int32 value = std::rint(mass * normalizer);
+ // NOTE: Consider checking if mass > 0.
+ value = std::max(value, 1);
+ return value;
+ });
+
+ int32 sum = std::accumulate(cdf.begin(), cdf.end(), 0);
+ if (sum > normalizer) {
+ std::vector<Item> queue;
+ queue.reserve(cdf.size());
+ for (int i = 0; i < cdf.size(); ++i) {
+ queue.emplace_back(&cdf[i], pmf[i]);
+ }
+
+ std::sort(queue.begin(), queue.end());
+ while (sum-- > normalizer) {
+ queue[0].Decrease();
+ // Performs a linear search because this find_if is likely to return an
+ // iterator very close to the beginning.
+ auto iter =
+ std::find_if(std::next(queue.begin()), queue.end(),
+ [&queue](const Item& rhs) { return queue[0] < rhs; });
+ std::rotate(queue.begin(), std::next(queue.begin()), iter);
+ }
+ }
+ std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
+ }
+
+ int precision_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PmfToQuantizedCdf").Device(DEVICE_CPU),
+ PmfToCdfOp);
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc
new file mode 100644
index 0000000000..c70e38faab
--- /dev/null
+++ b/tensorflow/contrib/coder/kernels/pmf_to_cdf_op_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <limits>
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+class PmfToQuantizedCdfOpTest : public OpsTestBase {
+ protected:
+ void SetupOp(int precision, Tensor* input) {
+ TF_ASSERT_OK(NodeDefBuilder("pmf_to_cdf", "PmfToQuantizedCdf")
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("precision", precision)
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+
+ inputs_.clear();
+ inputs_.emplace_back(input);
+ }
+
+ void GenerateData(random::SimplePhilox* rand,
+ gtl::MutableArraySlice<float> slice) {
+ constexpr float minimum = std::numeric_limits<float>::epsilon();
+ float sum = 0;
+ for (float& value : slice) {
+ value = std::max(rand->RandFloat(), minimum);
+ sum += value;
+ }
+ for (float& value : slice) {
+ value /= sum;
+ }
+ }
+
+ void Verify(int precision, const Tensor& pmf_tensor,
+ const Tensor& cdf_tensor) {
+ ASSERT_EQ(pmf_tensor.dims(), cdf_tensor.dims());
+ const int n = pmf_tensor.dims();
+
+ for (int i = 0; i < n - 1; ++i) {
+ EXPECT_EQ(pmf_tensor.dim_size(i), cdf_tensor.dim_size(i));
+ }
+
+ auto pmf = pmf_tensor.flat_inner_dims<float, 2>();
+ auto cdf = cdf_tensor.flat_inner_dims<int32, 2>();
+ EXPECT_EQ(pmf.dimension(1) + 1, cdf.dimension(1));
+
+ const int normalizer = 1 << precision;
+ for (int i = 0; i < pmf.dimension(0); ++i) {
+ EXPECT_EQ(0, cdf(i, 0));
+
+ TTypes<int32>::UnalignedConstVec cdf_slice(&cdf(i, 0), cdf.dimension(1));
+
+ for (int j = 1; j < cdf_slice.size(); ++j) {
+ const int32 diff = cdf_slice(j) - cdf_slice(j - 1);
+ EXPECT_GT(diff, 0);
+ }
+
+ EXPECT_LE(cdf_slice(cdf_slice.size() - 1), normalizer);
+ }
+ }
+};
+
+TEST_F(PmfToQuantizedCdfOpTest, UnderSum) {
+ Tensor pmf(DT_FLOAT, {1, 10, 1, 32});
+ auto matrix = pmf.flat_inner_dims<float, 2>();
+ const std::size_t n = matrix.dimension(1);
+
+ random::PhiloxRandom gen(random::New64(), random::New64());
+ random::SimplePhilox rand(&gen);
+ for (int64 i = 0; i < matrix.dimension(0); ++i) {
+ GenerateData(&rand, {&matrix(i, 0), n});
+ }
+
+ constexpr int kPrecision = 10;
+ SetupOp(kPrecision, &pmf);
+ TF_ASSERT_OK(RunOpKernel());
+
+ Verify(kPrecision, pmf, *GetOutput(0));
+}
+
+TEST_F(PmfToQuantizedCdfOpTest, OverSum) {
+ Tensor pmf(DT_FLOAT, {10, 1, 1, 100});
+ auto matrix = pmf.flat_inner_dims<float, 2>();
+
+ // Half of each PMF is filled with zeros. The op rounds these zeros up to
+ // ones post-quantization, and the round-ups are likely to push the sum
+ // over the normalizer value.
+ matrix.setZero();
+ const std::size_t n = matrix.dimension(1) / 2;
+
+ random::PhiloxRandom gen;
+ random::SimplePhilox rand(&gen);
+ for (int64 i = 0; i < matrix.dimension(0); ++i) {
+ GenerateData(&rand, {&matrix(i, 0), n});
+ }
+
+ constexpr int kPrecision = 7;
+ SetupOp(kPrecision, &pmf);
+ TF_ASSERT_OK(RunOpKernel());
+
+ Verify(kPrecision, pmf, *GetOutput(0));
+}
+
+TEST_F(PmfToQuantizedCdfOpTest, ShapeFn) {
+ ShapeInferenceTestOp op("PmfToQuantizedCdf");
+
+ INFER_OK(op, "?", "?");
+ INFER_OK(op, "[3]", "[4]");
+ INFER_OK(op, "[3,4]", "[d0_0,5]");
+ INFER_OK(op, "[3,4,5]", "[d0_0,d0_1,6]");
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc
index 9056d1a696..9bb171298f 100644
--- a/tensorflow/contrib/coder/ops/coder_ops.cc
+++ b/tensorflow/contrib/coder/ops/coder_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
+using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
@@ -115,5 +116,36 @@ decoded: An int32 tensor with shape equal to `shape`.
precision: The number of bits for probability quantization. Must be <= 16, and
must match the precision used by RangeEncode that produced `encoded`.
)doc");
+
+REGISTER_OP("PmfToQuantizedCdf")
+ .Input("pmf: float")
+ .Output("cdf: int32")
+ .Attr("precision: int >= 1")
+ .SetShapeFn([] (InferenceContext* c) {
+ ShapeHandle in;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &in));
+ DimensionHandle last;
+ TF_RETURN_IF_ERROR(c->Add(c->Dim(in, -1), 1, &last));
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->ReplaceDim(in, -1, last, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Converts PMF to quantized CDF. This op uses floating-point operations
+internally. Therefore the quantized output may not be consistent across multiple
+platforms. For entropy encoders and decoders to have the same quantized CDF on
+different platforms, the quantized CDF should be produced once and saved, then
+the saved quantized CDF should be used everywhere.
+
+After quantization, if the PMF sums to less than or equal to 2^precision, then
+this is equivalent to cumsum over the last dimension. This op makes no effort
+to bring the sum closer to 2^precision when it is already <= 2^precision.
+
+After quantization, if the PMF sums to greater than 2^precision, then some
+values of the PMF are decreased so that the sum does not exceed 2^precision.
+
+Note that the input PMF is pre-quantization.
+)doc");
// clang-format on
} // namespace tensorflow
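To make the under-sum and over-sum cases documented above concrete, here is a small worked example written against the standalone sketch shown after the commit message (illustrative values only; the sketch is forward-declared so the snippet is self-contained to compile, and must be linked against it to run):

#include <cassert>
#include <cstdint>
#include <vector>

// Forward declaration of the standalone sketch shown after the commit message.
std::vector<int32_t> PmfToQuantizedCdf(const std::vector<float>& pmf,
                                       int precision);

int main() {
  // Under-sum case: at precision 2 the masses {0.5, 0.25, 0.25} quantize to
  // {2, 1, 1}, which already totals 4 == 2^2, so the output is just the
  // cumulative sum with a leading zero.
  assert((PmfToQuantizedCdf({0.5f, 0.25f, 0.25f}, 2) ==
          std::vector<int32_t>({0, 2, 3, 4})));

  // Over-sum case: {0.4, 0.4, 0.1, 0.1} quantizes to {2, 2, 1, 1} (the small
  // masses are rounded up to 1), which totals 6 > 4, so two units are removed
  // from the lowest-penalty entries before the cumulative sum is taken.
  assert((PmfToQuantizedCdf({0.4f, 0.4f, 0.1f, 0.1f}, 2) ==
          std::vector<int32_t>({0, 1, 2, 3, 4})));
  return 0;
}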