aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/platform/s3/s3_file_system.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/platform/s3/s3_file_system.cc')
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc50
1 files changed, 46 insertions, 4 deletions
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 397f26ec0b..ebda3a2065 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/s3/aws_logging.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <aws/core/Aws.h>
+#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/utils/logging/AWSLogging.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
@@ -54,13 +56,37 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
cfg.endpointOverride = Aws::String(endpoint);
}
const char* region = getenv("AWS_REGION");
+ if (!region) {
+ // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
+ region = getenv("S3_REGION");
+ }
if (region) {
cfg.region = Aws::String(region);
} else {
- // TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
- const char* region = getenv("S3_REGION");
- if (region) {
- cfg.region = Aws::String(region);
+ // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
+ // is set with a truthy value.
+ const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
+ string load_config =
+ load_config_env ? str_util::Lowercase(load_config_env) : "";
+ if (load_config == "true" || load_config == "1") {
+ Aws::String config_file;
+ // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
+ const char* config_file_env = getenv("AWS_CONFIG_FILE");
+ if (config_file_env) {
+ config_file = config_file_env;
+ } else {
+ const char* home_env = getenv("HOME");
+ if (home_env) {
+ config_file = home_env;
+ config_file += "/.aws/config";
+ }
+ }
+ Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
+ loader.Load();
+ auto profiles = loader.GetProfiles();
+ if (!profiles["default"].GetRegion().empty()) {
+ cfg.region = profiles["default"].GetRegion();
+ }
}
}
const char* use_https = getenv("S3_USE_HTTPS");
@@ -79,6 +105,22 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
cfg.verifySSL = true;
}
}
+ const char* connect_timeout = getenv("S3_CONNECT_TIMEOUT_MSEC");
+ if (connect_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(connect_timeout, &timeout)) {
+ cfg.connectTimeoutMs = timeout;
+ }
+ }
+ const char* request_timeout = getenv("S3_REQUEST_TIMEOUT_MSEC");
+ if (request_timeout) {
+ int64 timeout;
+
+ if (strings::safe_strto64(request_timeout, &timeout)) {
+ cfg.requestTimeoutMs = timeout;
+ }
+ }
init = true;
}