diff options
author | Rohan Jain <rohanj@google.com> | 2016-08-12 07:16:39 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-12 08:32:30 -0700 |
commit | f907eeb9f781b00a24afbccd732fdecafb9433c5 (patch) | |
tree | 0278a01b983716e480a9598a204042015adf782c /tensorflow/core/lib/io | |
parent | 84cefad9ccd0cab0c8d594597492feec1f3ebcda (diff) |
Creating an InputStreamInterface for sequentially streaming files.
Change: 130101575
Diffstat (limited to 'tensorflow/core/lib/io')
-rw-r--r-- | tensorflow/core/lib/io/inputstream_interface.cc | 42 | ||||
-rw-r--r-- | tensorflow/core/lib/io/inputstream_interface.h | 47 | ||||
-rw-r--r-- | tensorflow/core/lib/io/inputstream_interface_test.cc | 59 |
3 files changed, 148 insertions, 0 deletions
diff --git a/tensorflow/core/lib/io/inputstream_interface.cc b/tensorflow/core/lib/io/inputstream_interface.cc new file mode 100644 index 0000000000..b18d833e06 --- /dev/null +++ b/tensorflow/core/lib/io/inputstream_interface.cc @@ -0,0 +1,42 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/io/inputstream_interface.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace io { + +// To limit memory usage, the default implementation of SkipNBytes() only reads +// 8MB at a time. +static constexpr int64 kMaxSkipSize = 8 * 1024 * 1024; + +Status InputStreamInterface::SkipNBytes(int64 bytes_to_skip) { + if (bytes_to_skip < 0) { + return errors::InvalidArgument("Can't skip a negative number of bytes"); + } + string unused; + // Read up to kMaxSkipSize bytes at a time until bytes_to_skip are consumed. 
+ while (bytes_to_skip > 0) { + int64 bytes_to_read = std::min<int64>(kMaxSkipSize, bytes_to_skip); + TF_RETURN_IF_ERROR(ReadNBytes(bytes_to_read, &unused)); + bytes_to_skip -= bytes_to_read; + } + return Status::OK(); +} + +} // namespace io +} // namespace tensorflow diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h new file mode 100644 index 0000000000..3d5d690a69 --- /dev/null +++ b/tensorflow/core/lib/io/inputstream_interface.h @@ -0,0 +1,47 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_ +#define TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_ + +#include <string> +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace io { + +// An interface that defines input streaming operations. +class InputStreamInterface { + public: + InputStreamInterface() {} + virtual ~InputStreamInterface() {} + + // Reads the next bytes_to_read from the file. Typical return codes: + // * OK - in case of success. + // * OUT_OF_RANGE - not enough bytes remaining before end of file. + virtual Status ReadNBytes(int64 bytes_to_read, string* result) = 0; + + // Skips bytes_to_skip before next ReadNBytes. bytes_to_skip should be >= 0. 
+ // Typical return codes: + // * OK - in case of success. + // * OUT_OF_RANGE - not enough bytes remaining before end of file. + virtual Status SkipNBytes(int64 bytes_to_skip); +}; + +} // namespace io +} // namespace tensorflow + +#endif // TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_ diff --git a/tensorflow/core/lib/io/inputstream_interface_test.cc b/tensorflow/core/lib/io/inputstream_interface_test.cc new file mode 100644 index 0000000000..2b4454bedb --- /dev/null +++ b/tensorflow/core/lib/io/inputstream_interface_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/io/inputstream_interface.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace io { +namespace { + +class TestStringStream : public InputStreamInterface { + public: + TestStringStream(const string& content) : content_(content) {} + + Status ReadNBytes(int64 bytes_to_read, string* result) override { + result->clear(); + if (pos_ + bytes_to_read > content_.size()) { + return errors::OutOfRange("limit reached"); + } + *result = content_.substr(pos_, bytes_to_read); + pos_ += bytes_to_read; + return Status::OK(); + } + + private: + string content_; + int64 pos_ = 0; +}; + +TEST(InputStreamInterface, Basic) { + TestStringStream ss("This is a test string"); + string res; + TF_ASSERT_OK(ss.ReadNBytes(4, &res)); + EXPECT_EQ("This", res); + TF_ASSERT_OK(ss.SkipNBytes(6)); + TF_ASSERT_OK(ss.ReadNBytes(11, &res)); + EXPECT_EQ("test string", res); + // Skipping past end of the file causes OutOfRange error. + EXPECT_TRUE(errors::IsOutOfRange(ss.SkipNBytes(1))); +} + +} // anonymous namespace +} // namespace io +} // namespace tensorflow |