aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/decode_compressed_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-22 12:42:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-22 12:46:28 -0800
commite4532d20973c4c00854492362665317551661c18 (patch)
tree398527e29bd30d39237adb4785be5069fdb646fa /tensorflow/core/kernels/decode_compressed_op.cc
parent673641c2d6a27fa97ee05453d671853731a4c602 (diff)
Merge changes from github.
PiperOrigin-RevId: 179953488
Diffstat (limited to 'tensorflow/core/kernels/decode_compressed_op.cc')
-rw-r--r--tensorflow/core/kernels/decode_compressed_op.cc125
1 files changed, 125 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/decode_compressed_op.cc b/tensorflow/core/kernels/decode_compressed_op.cc
new file mode 100644
index 0000000000..3c3d49e1f8
--- /dev/null
+++ b/tensorflow/core/kernels/decode_compressed_op.cc
@@ -0,0 +1,125 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/parse_ops.cc.
+
+#include <algorithm>
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
+
+namespace tensorflow {
+namespace {
+// Wrap memory buffer into InputStreamInterface
+class MemoryInputStream : public io::InputStreamInterface {
+ public:
+ explicit MemoryInputStream(const char* buffer, size_t length)
+ : buf_(buffer), len_(length), pos_(0) {}
+
+ ~MemoryInputStream() override {}
+
+ Status ReadNBytes(int64 bytes_to_read, string* result) override {
+ result->clear();
+ if (bytes_to_read < 0) {
+ return errors::InvalidArgument("Can't read a negative number of bytes: ",
+ bytes_to_read);
+ }
+ int64 bytes = bytes_to_read;
+ Status s = Status::OK();
+ if (pos_ + bytes_to_read > len_) {
+ bytes = len_ - pos_;
+ s = errors::OutOfRange("reached end of file");
+ }
+ if (bytes > 0) {
+ result->resize(bytes);
+ memcpy(&(*result)[0], &buf_[pos_], bytes);
+ pos_ += bytes;
+ }
+ return s;
+ }
+
+ int64 Tell() const override { return pos_; }
+
+ Status Reset() override {
+ pos_ = 0;
+ return Status::OK();
+ }
+
+ private:
+ const char* buf_; // Not owned.
+ int64 len_;
+ int64 pos_ = 0; // Tracks where we are in the file.
+};
+} // namespace
+
+class DecodeCompressedOp : public OpKernel {
+ public:
+ explicit DecodeCompressedOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("compression_type", &compression_type_));
+ OP_REQUIRES(context,
+ (compression_type_.empty() || compression_type_ == "ZLIB" ||
+ compression_type_ == "GZIP"),
+ errors::InvalidArgument(
+ "Only ZLIB, GZIP or NONE are supported compressions"));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* bytes_tensor;
+ OP_REQUIRES_OK(context, context->input("bytes", &bytes_tensor));
+ const auto& bytes_flat = bytes_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", bytes_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<string>();
+ if (compression_type_.empty()) {
+ for (int64 i = 0; i < bytes_flat.size(); i++) {
+ output_flat(i) = bytes_flat(i);
+ }
+ } else {
+ const io::ZlibCompressionOptions zlib_options =
+ compression_type_ == "ZLIB" ? io::ZlibCompressionOptions::DEFAULT()
+ : io::ZlibCompressionOptions::GZIP();
+ for (int64 i = 0; i < bytes_flat.size(); i++) {
+ std::unique_ptr<MemoryInputStream> input_stream(
+ new MemoryInputStream(bytes_flat(i).data(), bytes_flat(i).size()));
+ std::unique_ptr<io::ZlibInputStream> zlib_stream(
+ new io::ZlibInputStream(
+ input_stream.get(), static_cast<size_t>(kBufferSize),
+ static_cast<size_t>(kBufferSize), zlib_options));
+ string output_string;
+ Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string);
+ OP_REQUIRES(context, (s.ok() || errors::IsOutOfRange(s)), s);
+ output_flat(i) = output_string;
+ }
+ }
+ }
+
+ private:
+ enum { kBufferSize = 256 << 10 /* 256 kB */ };
+ string compression_type_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DecodeCompressed").Device(DEVICE_CPU),
+ DecodeCompressedOp)
+
+} // namespace tensorflow