diff options
author | Michael Banfield <micban@google.com> | 2018-07-31 13:32:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-31 13:36:10 -0700 |
commit | 87446c32de440b425eaae9306664c9d6b6f7ffd1 (patch) | |
tree | abb385c60d1f86d414baf60028658bbf64610099 /tensorflow/core/platform | |
parent | 20f56d1057aaacd6f2cc2601793689e00bc9561b (diff) |
Add ZoneProvider interface and GoogleZoneProvider class to detect which Google Cloud Engine zone tensorflow is running in.
PiperOrigin-RevId: 206817614
Diffstat (limited to 'tensorflow/core/platform')
13 files changed, 539 insertions, 63 deletions
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 67651349ea..549996aaf8 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -73,6 +73,7 @@ cc_library( linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 visibility = ["//visibility:public"], deps = [ + ":compute_engine_metadata_client", ":curl_http_request", ":expiring_lru_cache", ":file_block_cache", @@ -144,7 +145,7 @@ cc_library( copts = tf_copts(), visibility = ["//tensorflow:__subpackages__"], deps = [ - ":curl_http_request", + ":compute_engine_metadata_client", ":oauth_client", ":retrying_utils", "//tensorflow/core:lib", @@ -154,6 +155,43 @@ cc_library( ) cc_library( + name = "compute_engine_metadata_client", + srcs = [ + "compute_engine_metadata_client.cc", + ], + hdrs = [ + "compute_engine_metadata_client.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":curl_http_request", + ":http_request", + ":retrying_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "compute_engine_zone_provider", + srcs = [ + "compute_engine_zone_provider.cc", + ], + hdrs = [ + "compute_engine_zone_provider.h", + "zone_provider.h", + ], + copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], + deps = [ + ":compute_engine_metadata_client", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( name = "now_seconds_env", testonly = 1, hdrs = ["now_seconds_env.h"], @@ -345,6 +383,34 @@ tf_cc_test( ) tf_cc_test( + name = "compute_engine_metadata_client_test", + size = "small", + srcs = ["compute_engine_metadata_client_test.cc"], + deps = [ + ":compute_engine_metadata_client", + ":http_request_fake", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( + name = "compute_engine_zone_provider_test", + size = "small", + srcs = ["compute_engine_zone_provider_test.cc"], + deps = [ + ":compute_engine_zone_provider", + ":http_request_fake", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_cc_test( name = "retrying_file_system_test", size = "small", srcs = ["retrying_file_system_test.cc"], diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc new file mode 100644 index 0000000000..f41b83ac34 --- /dev/null +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" + +#include <utility> +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/retrying_utils.h" + +namespace tensorflow { + +namespace { + +// The URL to retrieve metadata when running in Google Compute Engine. +constexpr char kGceMetadataBaseUrl[] = "http://metadata/computeMetadata/v1/"; +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +} // namespace + +ComputeEngineMetadataClient::ComputeEngineMetadataClient( + std::shared_ptr<HttpRequest::Factory> http_request_factory) + : ComputeEngineMetadataClient(std::move(http_request_factory), + kInitialRetryDelayUsec) {} + +ComputeEngineMetadataClient::ComputeEngineMetadataClient( + std::shared_ptr<HttpRequest::Factory> http_request_factory, + int64 initial_retry_delay_usec) + : http_request_factory_(std::move(http_request_factory)), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + +Status ComputeEngineMetadataClient::GetMetadata( + const string& path, std::vector<char>* response_buffer) { + const auto get_metadata_from_gce = [path, response_buffer, this]() { + std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); + request->SetUri(kGceMetadataBaseUrl + path); + request->AddHeader("Metadata-Flavor", "Google"); + request->SetResultBuffer(response_buffer); + TF_RETURN_IF_ERROR(request->Send()); + return Status::OK(); + }; + + return RetryingUtils::CallWithRetries(get_metadata_from_gce, + initial_retry_delay_usec_); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h new file mode 100644 index 0000000000..534ccf30b2 --- /dev/null +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h @@ -0,0 +1,64 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/cloud/http_request.h" + +namespace tensorflow { + +/// \brief A client that accesses to the metadata server running on GCE hosts. +/// +/// Uses the provided HttpRequest::Factory to make requests to the local +/// metadata service +/// (https://cloud.google.com/compute/docs/storing-retrieving-metadata). +/// Retries on recoverable failures using exponential backoff with the initial +/// retry wait configurable via initial_retry_delay_usec. +class ComputeEngineMetadataClient { + public: + explicit ComputeEngineMetadataClient( + std::shared_ptr<HttpRequest::Factory> http_request_factory); + ComputeEngineMetadataClient( + std::shared_ptr<HttpRequest::Factory> http_request_factory, + int64 initial_retry_delay_usec); + virtual ~ComputeEngineMetadataClient() {} + + /// \brief Get the metadata value for a given attribute of the metadata + /// service. + /// + /// Given a metadata path relative + /// to http://metadata.google.internal/computeMetadata/v1/, + /// fills response_buffer with the metadata. Returns OK if the server returns + /// the response for the given metadata path successfully. + /// + /// Example usage: + /// To get the zone of an instance: + /// compute_engine_metadata_client.GetMetadata( + /// "instance/zone", response_buffer); + virtual Status GetMetadata(const string& path, + std::vector<char>* response_buffer); + + private: + std::shared_ptr<HttpRequest::Factory> http_request_factory_; + const int64 initial_retry_delay_usec_; + + TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineMetadataClient); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_ diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc new file mode 100644 index 0000000000..4c41ccaa0e --- /dev/null +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" +#include "tensorflow/core/platform/cloud/http_request_fake.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(ComputeEngineMetadataClientTest, GetMetadata) { + const string example_response = "example response"; + + std::vector<HttpRequest*> requests({new FakeHttpRequest( + "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" + "/default/token\n" + "Header Metadata-Flavor: Google\n", + example_response)}); + + std::shared_ptr<HttpRequest::Factory> http_factory = + std::make_shared<FakeHttpRequestFactory>(&requests); + ComputeEngineMetadataClient client(http_factory, 0); + + std::vector<char> result; + TF_EXPECT_OK( + client.GetMetadata("instance/service-accounts/default/token", &result)); + std::vector<char> expected(example_response.begin(), example_response.end()); + EXPECT_EQ(expected, result); +} + +TEST(ComputeEngineMetadataClientTest, RetryOnFailure) { + const string example_response = "example response"; + + std::vector<HttpRequest*> requests( + {new FakeHttpRequest( + "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" + "/default/token\n" + "Header Metadata-Flavor: Google\n", + "", errors::Unavailable("503"), 503), + new FakeHttpRequest( + "Uri: http://metadata/computeMetadata/v1/instance/service-accounts" + "/default/token\n" + "Header Metadata-Flavor: Google\n", + example_response)}); + + std::shared_ptr<HttpRequest::Factory> http_factory = + std::make_shared<FakeHttpRequestFactory>(&requests); + ComputeEngineMetadataClient client(http_factory, 0); + + std::vector<char> result; + TF_EXPECT_OK( + client.GetMetadata("instance/service-accounts/default/token", &result)); + std::vector<char> expected(example_response.begin(), example_response.end()); + EXPECT_EQ(expected, result); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc new file mode 100644 index 0000000000..dacf56187c --- /dev/null +++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h" + +#include <utility> +#include "tensorflow/core/lib/strings/str_util.h" +namespace tensorflow { + +namespace { +constexpr char kGceMetadataZonePath[] = "instance/zone"; +} // namespace + +ComputeEngineZoneProvider::ComputeEngineZoneProvider( + std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client) + : google_metadata_client_(std::move(google_metadata_client)) {} + +Status ComputeEngineZoneProvider::GetZone(string* zone) { + if (!cached_zone.empty()) { + *zone = cached_zone; + return Status::OK(); + } + std::vector<char> response_buffer; + TF_RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath, + &response_buffer)); + StringPiece location(&response_buffer[0], response_buffer.size()); + + std::vector<string> elems = str_util::Split(location, "/"); + if (elems.size() == 4) { + cached_zone = elems.back(); + *zone = cached_zone; + } else { + LOG(ERROR) << "Failed to parse the zone name from location: " + << location.ToString(); + } + + return Status::OK(); +} +ComputeEngineZoneProvider::~ComputeEngineZoneProvider() {} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.h b/tensorflow/core/platform/cloud/compute_engine_zone_provider.h new file mode 100644 index 0000000000..614b688e6f --- /dev/null +++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_ + +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" +#include "tensorflow/core/platform/cloud/zone_provider.h" + +namespace tensorflow { + +class ComputeEngineZoneProvider : public ZoneProvider { + public: + explicit ComputeEngineZoneProvider( + std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client); + virtual ~ComputeEngineZoneProvider(); + + Status GetZone(string* zone) override; + + private: + std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client_; + string cached_zone; + TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineZoneProvider); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_ diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc new file mode 100644 index 0000000000..f7477eca23 --- /dev/null +++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h" +#include "tensorflow/core/platform/cloud/http_request_fake.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class ComputeEngineZoneProviderTest : public ::testing::Test { + protected: + void SetUp() override {} + + void TearDown() override {} +}; + +TEST_F(ComputeEngineZoneProviderTest, GetZone) { + std::vector<HttpRequest*> requests({new FakeHttpRequest( + "Uri: http://metadata/computeMetadata/v1/instance/zone\n" + "Header Metadata-Flavor: Google\n", + "projects/123456789/zones/us-west1-b")}); + + auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); + + auto metadata_client = + std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0); + + ComputeEngineZoneProvider provider(metadata_client); + + string zone; + + TF_EXPECT_OK(provider.GetZone(&zone)); + EXPECT_EQ("us-west1-b", zone); + // Test caching, should be no further requests + TF_EXPECT_OK(provider.GetZone(&zone)); +} + +TEST_F(ComputeEngineZoneProviderTest, InvalidZoneString) { + std::vector<HttpRequest*> requests({new FakeHttpRequest( + "Uri: http://metadata/computeMetadata/v1/instance/zone\n" + "Header Metadata-Flavor: Google\n", + "invalidresponse")}); + + auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); + + auto metadata_client = + std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0); + + ComputeEngineZoneProvider provider(metadata_client); + + string zone; + + TF_EXPECT_OK(provider.GetZone(&zone)); + EXPECT_EQ("", zone); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index aa35e8a116..2e8d13acd5 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -605,13 +605,18 @@ bool StringPieceIdentity(StringPiece str, StringPiece* value) { } // namespace -GcsFileSystem::GcsFileSystem() - : auth_provider_(new GoogleAuthProvider()), - http_request_factory_(new CurlHttpRequest::Factory()) { +GcsFileSystem::GcsFileSystem() { uint64 value; size_t block_size = kDefaultBlockSize; size_t max_bytes = kDefaultMaxCacheSize; uint64 max_staleness = kDefaultMaxStaleness; + + http_request_factory_ = std::make_shared<CurlHttpRequest::Factory>(); + compute_engine_metadata_client_ = + std::make_shared<ComputeEngineMetadataClient>(http_request_factory_); + auth_provider_ = std::unique_ptr<AuthProvider>( + new GoogleAuthProvider(compute_engine_metadata_client_)); + // Apply the sys env override for the readahead buffer size if it's provided. if (GetEnvVar(kReadaheadBufferSize, strings::safe_strtou64, &value)) { block_size = value; diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 74768c98b5..a0372286f5 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/auth_provider.h" +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" #include "tensorflow/core/platform/cloud/expiring_lru_cache.h" #include "tensorflow/core/platform/cloud/file_block_cache.h" #include "tensorflow/core/platform/cloud/gcs_dns_cache.h" @@ -275,12 +276,13 @@ class GcsFileSystem : public FileSystem { mutex mu_; std::unique_ptr<AuthProvider> auth_provider_ GUARDED_BY(mu_); - std::unique_ptr<HttpRequest::Factory> http_request_factory_; + std::shared_ptr<HttpRequest::Factory> http_request_factory_; // block_cache_lock_ protects the file_block_cache_ pointer (Note that // FileBlockCache instances are themselves threadsafe). mutex block_cache_lock_; std::unique_ptr<FileBlockCache> file_block_cache_ GUARDED_BY(block_cache_lock_); + std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_; std::unique_ptr<GcsDnsCache> dns_cache_; GcsThrottle throttle_; diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc index 7e39b63e3e..6ffe51e897 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider.cc @@ -21,11 +21,11 @@ limitations under the License. #include <sys/types.h> #endif #include <fstream> +#include <utility> #include "include/json/json.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/base64.h" -#include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/cloud/retrying_utils.h" #include "tensorflow/core/platform/env.h" @@ -63,16 +63,11 @@ constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; // The URL to retrieve the auth bearer token when running in Google Compute // Engine. -constexpr char kGceTokenUrl[] = - "http://metadata/computeMetadata/v1/instance/service-accounts/default/" - "token"; +constexpr char kGceTokenPath[] = "instance/service-accounts/default/token"; // The authentication token scope to request. constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; -// The default initial delay between retries with exponential backoff. -constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec - /// Returns whether the given path points to a readable file. bool IsFile(const string& filename) { std::ifstream fstream(filename.c_str()); @@ -121,20 +116,20 @@ Status GetWellKnownFileName(string* filename) { } // namespace -GoogleAuthProvider::GoogleAuthProvider() - : GoogleAuthProvider( - std::unique_ptr<OAuthClient>(new OAuthClient()), - std::unique_ptr<HttpRequest::Factory>(new CurlHttpRequest::Factory()), - Env::Default(), kInitialRetryDelayUsec) {} +GoogleAuthProvider::GoogleAuthProvider( + std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client) + : GoogleAuthProvider(std::unique_ptr<OAuthClient>(new OAuthClient()), + std::move(compute_engine_metadata_client), + Env::Default()) {} GoogleAuthProvider::GoogleAuthProvider( std::unique_ptr<OAuthClient> oauth_client, - std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env, - int64 initial_retry_delay_usec) + std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client, + Env* env) : oauth_client_(std::move(oauth_client)), - http_request_factory_(std::move(http_request_factory)), - env_(env), - initial_retry_delay_usec_(initial_retry_delay_usec) {} + compute_engine_metadata_client_( + std::move(compute_engine_metadata_client)), + env_(env) {} Status GoogleAuthProvider::GetToken(string* t) { mutex_lock lock(mu_); @@ -207,24 +202,19 @@ Status GoogleAuthProvider::GetTokenFromFiles() { } Status GoogleAuthProvider::GetTokenFromGce() { - const auto get_token_from_gce = [this]() { - std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); - std::vector<char> response_buffer; - const uint64 request_timestamp_sec = env_->NowSeconds(); - request->SetUri(kGceTokenUrl); - request->AddHeader("Metadata-Flavor", "Google"); - request->SetResultBuffer(&response_buffer); - TF_RETURN_IF_ERROR(request->Send()); - StringPiece response = - StringPiece(&response_buffer[0], response_buffer.size()); - - TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse( - response, request_timestamp_sec, ¤t_token_, - &expiration_timestamp_sec_)); - return Status::OK(); - }; - return RetryingUtils::CallWithRetries(get_token_from_gce, - initial_retry_delay_usec_); + std::vector<char> response_buffer; + const uint64 request_timestamp_sec = env_->NowSeconds(); + + TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata( + kGceTokenPath, &response_buffer)); + StringPiece response = + StringPiece(&response_buffer[0], response_buffer.size()); + + TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse( + response, request_timestamp_sec, ¤t_token_, + &expiration_timestamp_sec_)); + + return Status::OK(); } Status GoogleAuthProvider::GetTokenForTesting() { diff --git a/tensorflow/core/platform/cloud/google_auth_provider.h b/tensorflow/core/platform/cloud/google_auth_provider.h index 00da25a959..58a785fd60 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider.h +++ b/tensorflow/core/platform/cloud/google_auth_provider.h @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include "tensorflow/core/platform/cloud/auth_provider.h" +#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h" #include "tensorflow/core/platform/cloud/oauth_client.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -27,11 +28,12 @@ namespace tensorflow { /// Implementation based on Google Application Default Credentials. class GoogleAuthProvider : public AuthProvider { public: - GoogleAuthProvider(); - explicit GoogleAuthProvider( - std::unique_ptr<OAuthClient> oauth_client, - std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env, - int64 initial_retry_delay_usec); + GoogleAuthProvider(std::shared_ptr<ComputeEngineMetadataClient> + compute_engine_metadata_client); + explicit GoogleAuthProvider(std::unique_ptr<OAuthClient> oauth_client, + std::shared_ptr<ComputeEngineMetadataClient> + compute_engine_metadata_client, + Env* env); virtual ~GoogleAuthProvider() {} /// \brief Returns the short-term authentication bearer token. @@ -53,13 +55,11 @@ class GoogleAuthProvider : public AuthProvider { Status GetTokenForTesting() EXCLUSIVE_LOCKS_REQUIRED(mu_); std::unique_ptr<OAuthClient> oauth_client_; - std::unique_ptr<HttpRequest::Factory> http_request_factory_; + std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_; Env* env_; mutex mu_; string current_token_ GUARDED_BY(mu_); uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; - // The initial delay for exponential backoffs when retrying failed calls. - const int64 initial_retry_delay_usec_; TF_DISALLOW_COPY_AND_ASSIGN(GoogleAuthProvider); }; diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc index 4281c6c737..07b88a880f 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc @@ -90,10 +90,13 @@ TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) { std::vector<HttpRequest*> requests; FakeEnv env; + + std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = + std::make_shared<FakeHttpRequestFactory>(&requests); + auto metadataClient = + std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - &env, 0); + metadataClient, &env); oauth_client->return_token = "fake-token"; oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600; @@ -124,10 +127,13 @@ TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) { std::vector<HttpRequest*> requests; FakeEnv env; + std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = + std::make_shared<FakeHttpRequestFactory>(&requests); + auto metadataClient = + std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); + GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - &env, 0); + metadataClient, &env); oauth_client->return_token = "fake-token"; oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600; @@ -170,10 +176,12 @@ TEST_F(GoogleAuthProviderTest, RunningOnGCE) { })")}); FakeEnv env; + std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = + std::make_shared<FakeHttpRequestFactory>(&requests); + auto metadataClient = + std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - &env, 0); + metadataClient, &env); string token; TF_EXPECT_OK(provider.GetToken(&token)); @@ -196,10 +204,12 @@ TEST_F(GoogleAuthProviderTest, OverrideForTesting) { auto oauth_client = new FakeOAuthClient; std::vector<HttpRequest*> empty_requests; FakeEnv env; + std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = + std::make_shared<FakeHttpRequestFactory>(&empty_requests); + auto metadataClient = + std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&empty_requests)), - &env, 0); + metadataClient, &env); string token; TF_EXPECT_OK(provider.GetToken(&token)); @@ -216,10 +226,12 @@ TEST_F(GoogleAuthProviderTest, NothingAvailable) { "", errors::NotFound("404"), 404)}); FakeEnv env; + std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = + std::make_shared<FakeHttpRequestFactory>(&requests); + auto metadataClient = + std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - &env, 0); + metadataClient, &env); string token; TF_EXPECT_OK(provider.GetToken(&token)); diff --git a/tensorflow/core/platform/cloud/zone_provider.h b/tensorflow/core/platform/cloud/zone_provider.h new file mode 100644 index 0000000000..421b6a7e1a --- /dev/null +++ b/tensorflow/core/platform/cloud/zone_provider.h @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ +#define TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ + +#include <string> +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +/// Interface for a provider of cloud instance zone +class ZoneProvider { + public: + virtual ~ZoneProvider() {} + + /// \brief Gets the zone of the Cloud instance and set the result in `zone`. + /// Returns OK if success. + /// + /// Returns an empty string in the case where the zone does not match the + /// expected format + /// Safe for concurrent use by multiple threads. + virtual Status GetZone(string* zone) = 0; + + static Status GetZone(ZoneProvider* provider, string* zone) { + if (!provider) { + return errors::Internal("Zone provider is required."); + } + return provider->GetZone(zone); + } +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_ |