diff options
author | 2017-12-22 12:42:59 -0800 | |
---|---|---|
committer | 2017-12-22 12:46:28 -0800 | |
commit | e4532d20973c4c00854492362665317551661c18 (patch) | |
tree | 398527e29bd30d39237adb4785be5069fdb646fa /tensorflow/core/kernels/decode_compressed_op.cc | |
parent | 673641c2d6a27fa97ee05453d671853731a4c602 (diff) |
Merge changes from github.
PiperOrigin-RevId: 179953488
Diffstat (limited to 'tensorflow/core/kernels/decode_compressed_op.cc')
-rw-r--r-- | tensorflow/core/kernels/decode_compressed_op.cc | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/decode_compressed_op.cc b/tensorflow/core/kernels/decode_compressed_op.cc new file mode 100644 index 0000000000..3c3d49e1f8 --- /dev/null +++ b/tensorflow/core/kernels/decode_compressed_op.cc @@ -0,0 +1,125 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parse_ops.cc. + +#include <algorithm> +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" + +namespace tensorflow { +namespace { +// Wrap memory buffer into InputStreamInterface +class MemoryInputStream : public io::InputStreamInterface { + public: + explicit MemoryInputStream(const char* buffer, size_t length) + : buf_(buffer), len_(length), pos_(0) {} + + ~MemoryInputStream() override {} + + Status ReadNBytes(int64 bytes_to_read, string* result) override { + result->clear(); + if (bytes_to_read < 0) { + return errors::InvalidArgument("Can't read a negative number of bytes: ", + bytes_to_read); + } + int64 bytes = bytes_to_read; + Status s = Status::OK(); + if (pos_ + bytes_to_read > len_) { + bytes = len_ - pos_; + s = errors::OutOfRange("reached end of file"); + } + if (bytes > 0) { + result->resize(bytes); + memcpy(&(*result)[0], &buf_[pos_], bytes); + pos_ += bytes; + } + return s; + } + + int64 Tell() const override { return pos_; } + + Status Reset() override { + pos_ = 0; + return Status::OK(); + } + + private: + const char* buf_; // Not owned. + int64 len_; + int64 pos_ = 0; // Tracks where we are in the file. +}; +} // namespace + +class DecodeCompressedOp : public OpKernel { + public: + explicit DecodeCompressedOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("compression_type", &compression_type_)); + OP_REQUIRES(context, + (compression_type_.empty() || compression_type_ == "ZLIB" || + compression_type_ == "GZIP"), + errors::InvalidArgument( + "Only ZLIB, GZIP or NONE are supported compressions")); + } + + void Compute(OpKernelContext* context) override { + const Tensor* bytes_tensor; + OP_REQUIRES_OK(context, context->input("bytes", &bytes_tensor)); + const auto& bytes_flat = bytes_tensor->flat<string>(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", bytes_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<string>(); + if (compression_type_.empty()) { + for (int64 i = 0; i < bytes_flat.size(); i++) { + output_flat(i) = bytes_flat(i); + } + } else { + const io::ZlibCompressionOptions zlib_options = + compression_type_ == "ZLIB" ? io::ZlibCompressionOptions::DEFAULT() + : io::ZlibCompressionOptions::GZIP(); + for (int64 i = 0; i < bytes_flat.size(); i++) { + std::unique_ptr<MemoryInputStream> input_stream( + new MemoryInputStream(bytes_flat(i).data(), bytes_flat(i).size())); + std::unique_ptr<io::ZlibInputStream> zlib_stream( + new io::ZlibInputStream( + input_stream.get(), static_cast<size_t>(kBufferSize), + static_cast<size_t>(kBufferSize), zlib_options)); + string output_string; + Status s = zlib_stream->ReadNBytes(INT_MAX, &output_string); + OP_REQUIRES(context, (s.ok() || errors::IsOutOfRange(s)), s); + output_flat(i) = output_string; + } + } + } + + private: + enum { kBufferSize = 256 << 10 /* 256 kB */ }; + string compression_type_; +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeCompressed").Device(DEVICE_CPU), + DecodeCompressedOp) + +} // namespace tensorflow |