Diffstat (limited to 'tensorflow/contrib/s3/s3_file_system.cc')
-rw-r--r--  tensorflow/contrib/s3/s3_file_system.cc  575
1 file changed, 575 insertions, 0 deletions
diff --git a/tensorflow/contrib/s3/s3_file_system.cc b/tensorflow/contrib/s3/s3_file_system.cc
new file mode 100644
index 0000000000..b09cf81d46
--- /dev/null
+++ b/tensorflow/contrib/s3/s3_file_system.cc
@@ -0,0 +1,575 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/s3/s3_file_system.h"
+#include "tensorflow/contrib/s3/s3_crypto.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#include <aws/core/Aws.h>
+#include <aws/core/utils/FileSystemUtils.h>
+#include <aws/s3/S3Client.h>
+#include <aws/s3/S3Errors.h>
+#include <aws/s3/model/CopyObjectRequest.h>
+#include <aws/s3/model/DeleteObjectRequest.h>
+#include <aws/s3/model/GetObjectRequest.h>
+#include <aws/s3/model/HeadBucketRequest.h>
+#include <aws/s3/model/HeadObjectRequest.h>
+#include <aws/s3/model/ListObjectsRequest.h>
+#include <aws/s3/model/PutObjectRequest.h>
+
+#include <cstdlib>
+#include <sstream>
+
+namespace tensorflow {
+
+static const char* kS3FileSystemAllocationTag = "S3FileSystemAllocation";
+static const size_t kS3ReadAppendableFileBufferSize = 1024 * 1024;
+static const int kS3GetChildrenMaxKeys = 100;
+
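+// Returns the client configuration shared by every S3 request. The values
+// are initialized once, under a mutex, from the S3_ENDPOINT, S3_REGION,
+// S3_USE_HTTPS, and S3_VERIFY_SSL environment variables.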
+Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
+ static mutex cfg_lock;
+ static bool init(false);
+ static Aws::Client::ClientConfiguration cfg;
+
+ std::lock_guard<mutex> lock(cfg_lock);
+
+ if (!init) {
+ const char* endpoint = getenv("S3_ENDPOINT");
+ if (endpoint) {
+ cfg.endpointOverride = Aws::String(endpoint);
+ }
+ const char* region = getenv("S3_REGION");
+ if (region) {
+ cfg.region = Aws::String(region);
+ }
+ const char* use_https = getenv("S3_USE_HTTPS");
+ if (use_https) {
+ if (use_https[0] == '0') {
+ cfg.scheme = Aws::Http::Scheme::HTTP;
+ } else {
+ cfg.scheme = Aws::Http::Scheme::HTTPS;
+ }
+ }
+ const char* verify_ssl = getenv("S3_VERIFY_SSL");
+ if (verify_ssl) {
+ if (verify_ssl[0] == '0') {
+ cfg.verifySSL = false;
+ } else {
+ cfg.verifySSL = true;
+ }
+ }
+
+ init = true;
+ }
+
+ return cfg;
+}
+
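+// Splits an "s3://bucket/object" URI into its bucket and object components.
+// When empty_object_ok is false, a missing object name is an error.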
+Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
+ string* object) {
+ if (!bucket || !object) {
+ return errors::Internal("bucket and object cannot be null.");
+ }
+ StringPiece scheme, bucketp, objectp;
+ io::ParseURI(fname, &scheme, &bucketp, &objectp);
+ if (scheme != "s3") {
+ return errors::InvalidArgument("S3 path doesn't start with 's3://': ",
+ fname);
+ }
+ *bucket = bucketp.ToString();
+ if (bucket->empty() || *bucket == ".") {
+ return errors::InvalidArgument("S3 path doesn't contain a bucket name: ",
+ fname);
+ }
+ objectp.Consume("/");
+ *object = objectp.ToString();
+ if (!empty_object_ok && object->empty()) {
+ return errors::InvalidArgument("S3 path doesn't contain an object name: ",
+ fname);
+ }
+ return Status::OK();
+}
+
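+// A RandomAccessFile backed by ranged S3 GetObject requests; a fresh
+// S3Client is constructed for every Read call.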
+class S3RandomAccessFile : public RandomAccessFile {
+ public:
+ S3RandomAccessFile(const string& bucket, const string& object)
+ : bucket_(bucket), object_(object) {}
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override {
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+ Aws::S3::Model::GetObjectRequest getObjectRequest;
+ getObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
+ string bytes = strings::StrCat("bytes=", offset, "-", offset + n - 1);
+ getObjectRequest.SetRange(bytes.c_str());
+ getObjectRequest.SetResponseStreamFactory([]() {
+ return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag);
+ });
+ auto getObjectOutcome = s3Client.GetObject(getObjectRequest);
+    if (!getObjectOutcome.IsSuccess()) {
+      *result = StringPiece(scratch, 0);
+      return Status(error::OUT_OF_RANGE, "Read fewer bytes than requested");
+    }
+    size_t bytes_read = getObjectOutcome.GetResult().GetContentLength();
+    std::stringstream ss;
+    ss << getObjectOutcome.GetResult().GetBody().rdbuf();
+    ss.read(scratch, bytes_read);
+
+    *result = StringPiece(scratch, bytes_read);
+    if (bytes_read < n) {
+      // A short read means we hit the end of the object. The
+      // RandomAccessFile contract reports this as OUT_OF_RANGE while still
+      // returning the bytes that were read in *result.
+      return Status(error::OUT_OF_RANGE, "Read fewer bytes than requested");
+    }
+    return Status::OK();
+ }
+
+ private:
+ string bucket_;
+ string object_;
+};
+
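+// A WritableFile that buffers all writes in a local temporary file and
+// uploads the whole buffer to S3 with a single PutObject request whenever
+// Sync(), Flush(), or Close() is called.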
+class S3WritableFile : public WritableFile {
+ public:
+ S3WritableFile(const string& bucket, const string& object)
+ : bucket_(bucket),
+ object_(object),
+ sync_needed_(true),
+ outfile_(Aws::MakeShared<Aws::Utils::TempFile>(
+ kS3FileSystemAllocationTag, "/tmp/s3_filesystem_XXXXXX",
+ std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
+ std::ios_base::out)) {}
+
+ Status Append(const StringPiece& data) override {
+ if (!outfile_) {
+ return errors::FailedPrecondition(
+ "The internal temporary file is not writable.");
+ }
+ sync_needed_ = true;
+ outfile_->write(data.data(), data.size());
+ if (!outfile_->good()) {
+ return errors::Internal(
+ "Could not append to the internal temporary file.");
+ }
+ return Status::OK();
+ }
+
+ Status Close() override {
+ if (outfile_) {
+ TF_RETURN_IF_ERROR(Sync());
+ outfile_.reset();
+ }
+ return Status::OK();
+ }
+
+ Status Flush() override { return Sync(); }
+
+ Status Sync() override {
+ if (!outfile_) {
+ return errors::FailedPrecondition(
+ "The internal temporary file is not writable.");
+ }
+ if (!sync_needed_) {
+ return Status::OK();
+ }
+ Aws::Client::ClientConfiguration clientConfig = GetDefaultClientConfig();
+ clientConfig.connectTimeoutMs = 300000;
+ clientConfig.requestTimeoutMs = 600000;
+ Aws::S3::S3Client s3Client(clientConfig);
+ Aws::S3::Model::PutObjectRequest putObjectRequest;
+ putObjectRequest.WithBucket(bucket_.c_str()).WithKey(object_.c_str());
+ long offset = outfile_->tellp();
+ outfile_->seekg(0);
+ putObjectRequest.SetBody(outfile_);
+ putObjectRequest.SetContentLength(offset);
+ auto putObjectOutcome = s3Client.PutObject(putObjectRequest);
+ outfile_->clear();
+ outfile_->seekp(offset);
+ if (!putObjectOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ putObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
+ putObjectOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+ return Status::OK();
+ }
+
+ private:
+ string bucket_;
+ string object_;
+ bool sync_needed_;
+ std::shared_ptr<Aws::Utils::TempFile> outfile_;
+};
+
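+// Holds an in-memory copy of an S3 object for read-only memory-region
+// access.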
+class S3ReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
+ public:
+ S3ReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
+ : data_(std::move(data)), length_(length) {}
+ const void* data() override { return reinterpret_cast<void*>(data_.get()); }
+ uint64 length() override { return length_; }
+
+ private:
+ std::unique_ptr<char[]> data_;
+ uint64 length_;
+};
+
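+// Initializes the AWS SDK and installs the custom SHA256 and SHA256-HMAC
+// factories declared in s3_crypto.h. The matching ShutdownAPI call happens
+// in the destructor.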
+S3FileSystem::S3FileSystem() {
+ Aws::SDKOptions options;
+ options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Info;
+ options.cryptoOptions.sha256Factory_create_fn = []() {
+ return Aws::MakeShared<S3SHA256Factory>(S3CryptoAllocationTag);
+ };
+ options.cryptoOptions.sha256HMACFactory_create_fn = []() {
+ return Aws::MakeShared<S3SHA256HmacFactory>(S3CryptoAllocationTag);
+ };
+ Aws::InitAPI(options);
+}
+
+S3FileSystem::~S3FileSystem() {
+ Aws::SDKOptions options;
+ options.loggingOptions.logLevel = Aws::Utils::Logging::LogLevel::Info;
+ Aws::ShutdownAPI(options);
+}
+
+Status S3FileSystem::NewRandomAccessFile(
+ const string& fname, std::unique_ptr<RandomAccessFile>* result) {
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
+ result->reset(new S3RandomAccessFile(bucket, object));
+ return Status::OK();
+}
+
+Status S3FileSystem::NewWritableFile(const string& fname,
+ std::unique_ptr<WritableFile>* result) {
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
+ result->reset(new S3WritableFile(bucket, object));
+ return Status::OK();
+}
+
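+// S3 objects are immutable, so appends are emulated: the existing object is
+// read back in kS3ReadAppendableFileBufferSize chunks into a new writable
+// file, which is then re-uploaded in full on the next Sync().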
+Status S3FileSystem::NewAppendableFile(const string& fname,
+ std::unique_ptr<WritableFile>* result) {
+ std::unique_ptr<RandomAccessFile> reader;
+ TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader));
+ std::unique_ptr<char[]> buffer(new char[kS3ReadAppendableFileBufferSize]);
+ Status status;
+ uint64 offset = 0;
+ StringPiece read_chunk;
+
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
+ result->reset(new S3WritableFile(bucket, object));
+
+ while (true) {
+ status = reader->Read(offset, kS3ReadAppendableFileBufferSize, &read_chunk,
+ buffer.get());
+    if (status.ok()) {
+      TF_RETURN_IF_ERROR((*result)->Append(read_chunk));
+      offset += read_chunk.size();
+    } else if (status.code() == error::OUT_OF_RANGE) {
+      // A short final read reports OUT_OF_RANGE; append whatever was read
+      // and stop.
+      TF_RETURN_IF_ERROR((*result)->Append(read_chunk));
+      break;
+ } else {
+ (*result).reset();
+ return status;
+ }
+ }
+
+ return Status::OK();
+}
+
+Status S3FileSystem::NewReadOnlyMemoryRegionFromFile(
+ const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
+ uint64 size;
+ TF_RETURN_IF_ERROR(GetFileSize(fname, &size));
+ std::unique_ptr<char[]> data(new char[size]);
+
+ std::unique_ptr<RandomAccessFile> file;
+ TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &file));
+
+ StringPiece piece;
+ TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get()));
+
+ result->reset(new S3ReadOnlyMemoryRegion(std::move(data), size));
+ return Status::OK();
+}
+
+Status S3FileSystem::FileExists(const string& fname) {
+ FileStatistics stats;
+ TF_RETURN_IF_ERROR(this->Stat(fname, &stats));
+ return Status::OK();
+}
+
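+// Lists the immediate children of a directory with a delimited ListObjects
+// request, following the pagination marker until the listing is no longer
+// truncated. Common prefixes become subdirectory entries; keys become files.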
+Status S3FileSystem::GetChildren(const string& dir,
+ std::vector<string>* result) {
+ string bucket, prefix;
+ TF_RETURN_IF_ERROR(ParseS3Path(dir, false, &bucket, &prefix));
+
+ if (prefix.back() != '/') {
+ prefix.push_back('/');
+ }
+
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+ Aws::S3::Model::ListObjectsRequest listObjectsRequest;
+ listObjectsRequest.WithBucket(bucket.c_str())
+ .WithPrefix(prefix.c_str())
+ .WithMaxKeys(kS3GetChildrenMaxKeys)
+ .WithDelimiter("/");
+ listObjectsRequest.SetResponseStreamFactory(
+ []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
+
+ Aws::S3::Model::ListObjectsResult listObjectsResult;
+ do {
+ auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ if (!listObjectsOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ",
+ listObjectsOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+
+ listObjectsResult = listObjectsOutcome.GetResult();
+ for (const auto& object : listObjectsResult.GetCommonPrefixes()) {
+ Aws::String s = object.GetPrefix();
+ s.erase(s.length() - 1);
+      Aws::String entry = s.substr(prefix.size());
+ if (entry.length() > 0) {
+ result->push_back(entry.c_str());
+ }
+ }
+ for (const auto& object : listObjectsResult.GetContents()) {
+ Aws::String s = object.GetKey();
+      Aws::String entry = s.substr(prefix.size());
+ if (entry.length() > 0) {
+ result->push_back(entry.c_str());
+ }
+ }
+ listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
+ } while (listObjectsResult.GetIsTruncated());
+
+ return Status::OK();
+}
+
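+// Stats a path in three steps: HeadBucket for a bare "s3://bucket" URI,
+// HeadObject for a plain object, and a one-key ListObjects on "object/" to
+// detect a directory.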
+Status S3FileSystem::Stat(const string& fname, FileStatistics* stats) {
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(fname, true, &bucket, &object));
+
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+ if (object.empty()) {
+ Aws::S3::Model::HeadBucketRequest headBucketRequest;
+ headBucketRequest.WithBucket(bucket.c_str());
+ auto headBucketOutcome = s3Client.HeadBucket(headBucketRequest);
+ if (!headBucketOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ headBucketOutcome.GetError().GetExceptionName().c_str(), ": ",
+ headBucketOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+ stats->length = 0;
+ stats->is_directory = 1;
+ return Status::OK();
+ }
+
+ bool found = false;
+
+ Aws::S3::Model::HeadObjectRequest headObjectRequest;
+ headObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
+ headObjectRequest.SetResponseStreamFactory(
+ []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
+ auto headObjectOutcome = s3Client.HeadObject(headObjectRequest);
+ if (headObjectOutcome.IsSuccess()) {
+ stats->length = headObjectOutcome.GetResult().GetContentLength();
+ stats->is_directory = 0;
+    stats->mtime_nsec =
+        headObjectOutcome.GetResult().GetLastModified().Millis() * 1000000;
+ found = true;
+ }
+ string prefix = object;
+ if (prefix.back() != '/') {
+ prefix.push_back('/');
+ }
+ Aws::S3::Model::ListObjectsRequest listObjectsRequest;
+ listObjectsRequest.WithBucket(bucket.c_str())
+ .WithPrefix(prefix.c_str())
+ .WithMaxKeys(1);
+ listObjectsRequest.SetResponseStreamFactory(
+ []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
+ auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ if (listObjectsOutcome.IsSuccess()) {
+ if (listObjectsOutcome.GetResult().GetContents().size() > 0) {
+ stats->length = 0;
+ stats->is_directory = 1;
+ found = true;
+ }
+ }
+ if (!found) {
+ return errors::NotFound("Object ", fname, " does not exist");
+ }
+ return Status::OK();
+}
+
+Status S3FileSystem::DeleteFile(const string& fname) {
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(fname, false, &bucket, &object));
+
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+ Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
+ deleteObjectRequest.WithBucket(bucket.c_str()).WithKey(object.c_str());
+
+ auto deleteObjectOutcome = s3Client.DeleteObject(deleteObjectRequest);
+ if (!deleteObjectOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
+ deleteObjectOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+ return Status::OK();
+}
+
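+// A directory is represented by an empty object whose key ends in '/';
+// creating a directory writes that marker object.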
+Status S3FileSystem::CreateDir(const string& dirname) {
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(dirname, true, &bucket, &object));
+
+ if (object.empty()) {
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+ Aws::S3::Model::HeadBucketRequest headBucketRequest;
+ headBucketRequest.WithBucket(bucket.c_str());
+ auto headBucketOutcome = s3Client.HeadBucket(headBucketRequest);
+ if (!headBucketOutcome.IsSuccess()) {
+ return errors::NotFound("The bucket ", bucket, " was not found.");
+ }
+ return Status::OK();
+ }
+ string filename = dirname;
+ if (filename.back() != '/') {
+ filename.push_back('/');
+ }
+ std::unique_ptr<WritableFile> file;
+ TF_RETURN_IF_ERROR(NewWritableFile(filename, &file));
+ TF_RETURN_IF_ERROR(file->Close());
+ return Status::OK();
+}
+
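+// Deletes a directory only if it is empty, i.e. if listing under "object/"
+// returns at most the directory marker itself.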
+Status S3FileSystem::DeleteDir(const string& dirname) {
+ string bucket, object;
+ TF_RETURN_IF_ERROR(ParseS3Path(dirname, false, &bucket, &object));
+
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+ string prefix = object;
+ if (prefix.back() != '/') {
+ prefix.push_back('/');
+ }
+ Aws::S3::Model::ListObjectsRequest listObjectsRequest;
+ listObjectsRequest.WithBucket(bucket.c_str())
+ .WithPrefix(prefix.c_str())
+ .WithMaxKeys(2);
+ listObjectsRequest.SetResponseStreamFactory(
+ []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
+ auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ if (listObjectsOutcome.IsSuccess()) {
+ auto contents = listObjectsOutcome.GetResult().GetContents();
+ if (contents.size() > 1 ||
+ (contents.size() == 1 && contents[0].GetKey() != prefix.c_str())) {
+ return errors::FailedPrecondition("Cannot delete a non-empty directory.");
+ }
+ if (contents.size() == 1 && contents[0].GetKey() == prefix.c_str()) {
+ string filename = dirname;
+ if (filename.back() != '/') {
+ filename.push_back('/');
+ }
+ return DeleteFile(filename);
+ }
+ }
+ return Status::OK();
+}
+
+Status S3FileSystem::GetFileSize(const string& fname, uint64* file_size) {
+ FileStatistics stats;
+ TF_RETURN_IF_ERROR(this->Stat(fname, &stats));
+ *file_size = stats.length;
+ return Status::OK();
+}
+
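+// S3 has no native rename, so every object under the source prefix is
+// copied to the target prefix and then deleted, one ListObjects page at a
+// time.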
+Status S3FileSystem::RenameFile(const string& src, const string& target) {
+ string src_bucket, src_object, target_bucket, target_object;
+ TF_RETURN_IF_ERROR(ParseS3Path(src, false, &src_bucket, &src_object));
+ TF_RETURN_IF_ERROR(
+ ParseS3Path(target, false, &target_bucket, &target_object));
+ if (src_object.back() == '/') {
+ if (target_object.back() != '/') {
+ target_object.push_back('/');
+ }
+ } else {
+ if (target_object.back() == '/') {
+ target_object.pop_back();
+ }
+ }
+
+ Aws::S3::S3Client s3Client(GetDefaultClientConfig());
+
+ Aws::S3::Model::CopyObjectRequest copyObjectRequest;
+ Aws::S3::Model::DeleteObjectRequest deleteObjectRequest;
+
+ Aws::S3::Model::ListObjectsRequest listObjectsRequest;
+ listObjectsRequest.WithBucket(src_bucket.c_str())
+ .WithPrefix(src_object.c_str())
+ .WithMaxKeys(kS3GetChildrenMaxKeys);
+ listObjectsRequest.SetResponseStreamFactory(
+ []() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
+
+ Aws::S3::Model::ListObjectsResult listObjectsResult;
+ do {
+ auto listObjectsOutcome = s3Client.ListObjects(listObjectsRequest);
+ if (!listObjectsOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ listObjectsOutcome.GetError().GetExceptionName().c_str(), ": ",
+ listObjectsOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+
+ listObjectsResult = listObjectsOutcome.GetResult();
+ for (const auto& object : listObjectsResult.GetContents()) {
+ Aws::String src_key = object.GetKey();
+ Aws::String target_key = src_key;
+ target_key.replace(0, src_object.length(), target_object.c_str());
+ Aws::String source = Aws::String(src_bucket.c_str()) + "/" + src_key;
+
+ copyObjectRequest.SetBucket(target_bucket.c_str());
+ copyObjectRequest.SetKey(target_key);
+ copyObjectRequest.SetCopySource(source);
+
+ auto copyObjectOutcome = s3Client.CopyObject(copyObjectRequest);
+ if (!copyObjectOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ copyObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
+ copyObjectOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+
+ deleteObjectRequest.SetBucket(src_bucket.c_str());
+ deleteObjectRequest.SetKey(src_key.c_str());
+
+ auto deleteObjectOutcome = s3Client.DeleteObject(deleteObjectRequest);
+ if (!deleteObjectOutcome.IsSuccess()) {
+ string error = strings::StrCat(
+ deleteObjectOutcome.GetError().GetExceptionName().c_str(), ": ",
+ deleteObjectOutcome.GetError().GetMessage().c_str());
+ return errors::Internal(error);
+ }
+ }
+ listObjectsRequest.SetMarker(listObjectsResult.GetNextMarker());
+ } while (listObjectsResult.GetIsTruncated());
+
+ return Status::OK();
+}
+
+REGISTER_FILE_SYSTEM("s3", S3FileSystem);
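+
+// A minimal usage sketch once the "s3" scheme is registered; the bucket and
+// object names below are placeholders:
+//
+//   std::unique_ptr<RandomAccessFile> file;
+//   TF_CHECK_OK(Env::Default()->NewRandomAccessFile(
+//       "s3://my-bucket/path/to/object", &file));
+//   StringPiece result;
+//   char scratch[1024];
+//   Status s = file->Read(0, sizeof(scratch), &result, scratch);
+//   // s is OUT_OF_RANGE when the object is shorter than the request.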
+
+} // namespace tensorflow