diff options
Diffstat (limited to 'tensorflow/core/platform/s3/s3_file_system.cc')
-rw-r--r-- | tensorflow/core/platform/s3/s3_file_system.cc | 50 |
1 files changed, 46 insertions, 4 deletions
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index 397f26ec0b..ebda3a2065 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/platform/s3/s3_file_system.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/s3/aws_logging.h" #include "tensorflow/core/platform/s3/s3_crypto.h" #include <aws/core/Aws.h> +#include <aws/core/config/AWSProfileConfigLoader.h> #include <aws/core/utils/FileSystemUtils.h> #include <aws/core/utils/logging/AWSLogging.h> #include <aws/core/utils/logging/LogSystemInterface.h> @@ -54,13 +56,37 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() { cfg.endpointOverride = Aws::String(endpoint); } const char* region = getenv("AWS_REGION"); + if (!region) { + // TODO (yongtang): `S3_REGION` should be deprecated after 2.0. + region = getenv("S3_REGION"); + } if (region) { cfg.region = Aws::String(region); } else { - // TODO (yongtang): `S3_REGION` should be deprecated after 2.0. - const char* region = getenv("S3_REGION"); - if (region) { - cfg.region = Aws::String(region); + // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG + // is set with a truthy value. + const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG"); + string load_config = + load_config_env ? str_util::Lowercase(load_config_env) : ""; + if (load_config == "true" || load_config == "1") { + Aws::String config_file; + // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config. + const char* config_file_env = getenv("AWS_CONFIG_FILE"); + if (config_file_env) { + config_file = config_file_env; + } else { + const char* home_env = getenv("HOME"); + if (home_env) { + config_file = home_env; + config_file += "/.aws/config"; + } + } + Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file); + loader.Load(); + auto profiles = loader.GetProfiles(); + if (!profiles["default"].GetRegion().empty()) { + cfg.region = profiles["default"].GetRegion(); + } } } const char* use_https = getenv("S3_USE_HTTPS"); @@ -79,6 +105,22 @@ Aws::Client::ClientConfiguration& GetDefaultClientConfig() { cfg.verifySSL = true; } } + const char* connect_timeout = getenv("S3_CONNECT_TIMEOUT_MSEC"); + if (connect_timeout) { + int64 timeout; + + if (strings::safe_strto64(connect_timeout, &timeout)) { + cfg.connectTimeoutMs = timeout; + } + } + const char* request_timeout = getenv("S3_REQUEST_TIMEOUT_MSEC"); + if (request_timeout) { + int64 timeout; + + if (strings::safe_strto64(request_timeout, &timeout)) { + cfg.requestTimeoutMs = timeout; + } + } init = true; } |