aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/platform
diff options
context:
space:
mode:
authorGravatar Michael Banfield <micban@google.com>2018-07-31 13:32:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-31 13:36:10 -0700
commit87446c32de440b425eaae9306664c9d6b6f7ffd1 (patch)
treeabb385c60d1f86d414baf60028658bbf64610099 /tensorflow/core/platform
parent20f56d1057aaacd6f2cc2601793689e00bc9561b (diff)
Add ZoneProvider interface and GoogleZoneProvider class to detect which Google Cloud Engine zone tensorflow is running in.
PiperOrigin-RevId: 206817614
Diffstat (limited to 'tensorflow/core/platform')
-rw-r--r--tensorflow/core/platform/cloud/BUILD68
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client.cc59
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client.h64
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc68
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider.cc53
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider.h40
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc69
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc11
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h4
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider.cc60
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider.h16
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider_test.cc42
-rw-r--r--tensorflow/core/platform/cloud/zone_provider.h48
13 files changed, 539 insertions, 63 deletions
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 67651349ea..549996aaf8 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -73,6 +73,7 @@ cc_library(
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
deps = [
+ ":compute_engine_metadata_client",
":curl_http_request",
":expiring_lru_cache",
":file_block_cache",
@@ -144,7 +145,7 @@ cc_library(
copts = tf_copts(),
visibility = ["//tensorflow:__subpackages__"],
deps = [
- ":curl_http_request",
+ ":compute_engine_metadata_client",
":oauth_client",
":retrying_utils",
"//tensorflow/core:lib",
@@ -154,6 +155,43 @@ cc_library(
)
cc_library(
+ name = "compute_engine_metadata_client",
+ srcs = [
+ "compute_engine_metadata_client.cc",
+ ],
+ hdrs = [
+ "compute_engine_metadata_client.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":curl_http_request",
+ ":http_request",
+ ":retrying_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "compute_engine_zone_provider",
+ srcs = [
+ "compute_engine_zone_provider.cc",
+ ],
+ hdrs = [
+ "compute_engine_zone_provider.h",
+ "zone_provider.h",
+ ],
+ copts = tf_copts(),
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":compute_engine_metadata_client",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
name = "now_seconds_env",
testonly = 1,
hdrs = ["now_seconds_env.h"],
@@ -345,6 +383,34 @@ tf_cc_test(
)
tf_cc_test(
+ name = "compute_engine_metadata_client_test",
+ size = "small",
+ srcs = ["compute_engine_metadata_client_test.cc"],
+ deps = [
+ ":compute_engine_metadata_client",
+ ":http_request_fake",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "compute_engine_zone_provider_test",
+ size = "small",
+ srcs = ["compute_engine_zone_provider_test.cc"],
+ deps = [
+ ":compute_engine_zone_provider",
+ ":http_request_fake",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
name = "retrying_file_system_test",
size = "small",
srcs = ["retrying_file_system_test.cc"],
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
new file mode 100644
index 0000000000..f41b83ac34
--- /dev/null
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
+
+#include <utility>
+#include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/cloud/retrying_utils.h"
+
+namespace tensorflow {
+
+namespace {
+
+// The URL to retrieve metadata when running in Google Compute Engine.
+constexpr char kGceMetadataBaseUrl[] = "http://metadata/computeMetadata/v1/";
+// The default initial delay between retries with exponential backoff.
+constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
+
+} // namespace
+
+ComputeEngineMetadataClient::ComputeEngineMetadataClient(
+ std::shared_ptr<HttpRequest::Factory> http_request_factory)
+ : ComputeEngineMetadataClient(std::move(http_request_factory),
+ kInitialRetryDelayUsec) {}
+
+ComputeEngineMetadataClient::ComputeEngineMetadataClient(
+ std::shared_ptr<HttpRequest::Factory> http_request_factory,
+ int64 initial_retry_delay_usec)
+ : http_request_factory_(std::move(http_request_factory)),
+ initial_retry_delay_usec_(initial_retry_delay_usec) {}
+
+Status ComputeEngineMetadataClient::GetMetadata(
+ const string& path, std::vector<char>* response_buffer) {
+ const auto get_metadata_from_gce = [path, response_buffer, this]() {
+ std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+ request->SetUri(kGceMetadataBaseUrl + path);
+ request->AddHeader("Metadata-Flavor", "Google");
+ request->SetResultBuffer(response_buffer);
+ TF_RETURN_IF_ERROR(request->Send());
+ return Status::OK();
+ };
+
+ return RetryingUtils::CallWithRetries(get_metadata_from_gce,
+ initial_retry_delay_usec_);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
new file mode 100644
index 0000000000..534ccf30b2
--- /dev/null
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
@@ -0,0 +1,64 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/cloud/http_request.h"
+
+namespace tensorflow {
+
+/// \brief A client that accesses to the metadata server running on GCE hosts.
+///
+/// Uses the provided HttpRequest::Factory to make requests to the local
+/// metadata service
+/// (https://cloud.google.com/compute/docs/storing-retrieving-metadata).
+/// Retries on recoverable failures using exponential backoff with the initial
+/// retry wait configurable via initial_retry_delay_usec.
+class ComputeEngineMetadataClient {
+ public:
+ explicit ComputeEngineMetadataClient(
+ std::shared_ptr<HttpRequest::Factory> http_request_factory);
+ ComputeEngineMetadataClient(
+ std::shared_ptr<HttpRequest::Factory> http_request_factory,
+ int64 initial_retry_delay_usec);
+ virtual ~ComputeEngineMetadataClient() {}
+
+ /// \brief Get the metadata value for a given attribute of the metadata
+ /// service.
+ ///
+ /// Given a metadata path relative
+ /// to http://metadata.google.internal/computeMetadata/v1/,
+ /// fills response_buffer with the metadata. Returns OK if the server returns
+ /// the response for the given metadata path successfully.
+ ///
+ /// Example usage:
+ /// To get the zone of an instance:
+ /// compute_engine_metadata_client.GetMetadata(
+ /// "instance/zone", response_buffer);
+ virtual Status GetMetadata(const string& path,
+ std::vector<char>* response_buffer);
+
+ private:
+ std::shared_ptr<HttpRequest::Factory> http_request_factory_;
+ const int64 initial_retry_delay_usec_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineMetadataClient);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_METADATA_CLIENT_H_
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
new file mode 100644
index 0000000000..4c41ccaa0e
--- /dev/null
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
+#include "tensorflow/core/platform/cloud/http_request_fake.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ComputeEngineMetadataClientTest, GetMetadata) {
+ const string example_response = "example response";
+
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
+ "/default/token\n"
+ "Header Metadata-Flavor: Google\n",
+ example_response)});
+
+ std::shared_ptr<HttpRequest::Factory> http_factory =
+ std::make_shared<FakeHttpRequestFactory>(&requests);
+ ComputeEngineMetadataClient client(http_factory, 0);
+
+ std::vector<char> result;
+ TF_EXPECT_OK(
+ client.GetMetadata("instance/service-accounts/default/token", &result));
+ std::vector<char> expected(example_response.begin(), example_response.end());
+ EXPECT_EQ(expected, result);
+}
+
+TEST(ComputeEngineMetadataClientTest, RetryOnFailure) {
+ const string example_response = "example response";
+
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
+ "/default/token\n"
+ "Header Metadata-Flavor: Google\n",
+ "", errors::Unavailable("503"), 503),
+ new FakeHttpRequest(
+ "Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
+ "/default/token\n"
+ "Header Metadata-Flavor: Google\n",
+ example_response)});
+
+ std::shared_ptr<HttpRequest::Factory> http_factory =
+ std::make_shared<FakeHttpRequestFactory>(&requests);
+ ComputeEngineMetadataClient client(http_factory, 0);
+
+ std::vector<char> result;
+ TF_EXPECT_OK(
+ client.GetMetadata("instance/service-accounts/default/token", &result));
+ std::vector<char> expected(example_response.begin(), example_response.end());
+ EXPECT_EQ(expected, result);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
new file mode 100644
index 0000000000..dacf56187c
--- /dev/null
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h"
+
+#include <utility>
+#include "tensorflow/core/lib/strings/str_util.h"
+namespace tensorflow {
+
+namespace {
+constexpr char kGceMetadataZonePath[] = "instance/zone";
+} // namespace
+
+ComputeEngineZoneProvider::ComputeEngineZoneProvider(
+ std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client)
+ : google_metadata_client_(std::move(google_metadata_client)) {}
+
+Status ComputeEngineZoneProvider::GetZone(string* zone) {
+ if (!cached_zone.empty()) {
+ *zone = cached_zone;
+ return Status::OK();
+ }
+ std::vector<char> response_buffer;
+ TF_RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath,
+ &response_buffer));
+ StringPiece location(&response_buffer[0], response_buffer.size());
+
+ std::vector<string> elems = str_util::Split(location, "/");
+ if (elems.size() == 4) {
+ cached_zone = elems.back();
+ *zone = cached_zone;
+ } else {
+ LOG(ERROR) << "Failed to parse the zone name from location: "
+ << location.ToString();
+ }
+
+ return Status::OK();
+}
+ComputeEngineZoneProvider::~ComputeEngineZoneProvider() {}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.h b/tensorflow/core/platform/cloud/compute_engine_zone_provider.h
new file mode 100644
index 0000000000..614b688e6f
--- /dev/null
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.h
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_
+
+#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
+#include "tensorflow/core/platform/cloud/zone_provider.h"
+
+namespace tensorflow {
+
+class ComputeEngineZoneProvider : public ZoneProvider {
+ public:
+ explicit ComputeEngineZoneProvider(
+ std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client);
+ virtual ~ComputeEngineZoneProvider();
+
+ Status GetZone(string* zone) override;
+
+ private:
+ std::shared_ptr<ComputeEngineMetadataClient> google_metadata_client_;
+ string cached_zone;
+ TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineZoneProvider);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_COMPUTE_ENGINE_ZONE_PROVIDER_H_
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
new file mode 100644
index 0000000000..f7477eca23
--- /dev/null
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
@@ -0,0 +1,69 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/compute_engine_zone_provider.h"
+#include "tensorflow/core/platform/cloud/http_request_fake.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+class ComputeEngineZoneProviderTest : public ::testing::Test {
+ protected:
+ void SetUp() override {}
+
+ void TearDown() override {}
+};
+
+TEST_F(ComputeEngineZoneProviderTest, GetZone) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: http://metadata/computeMetadata/v1/instance/zone\n"
+ "Header Metadata-Flavor: Google\n",
+ "projects/123456789/zones/us-west1-b")});
+
+ auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
+
+ auto metadata_client =
+ std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+
+ ComputeEngineZoneProvider provider(metadata_client);
+
+ string zone;
+
+ TF_EXPECT_OK(provider.GetZone(&zone));
+ EXPECT_EQ("us-west1-b", zone);
+ // Test caching, should be no further requests
+ TF_EXPECT_OK(provider.GetZone(&zone));
+}
+
+TEST_F(ComputeEngineZoneProviderTest, InvalidZoneString) {
+ std::vector<HttpRequest*> requests({new FakeHttpRequest(
+ "Uri: http://metadata/computeMetadata/v1/instance/zone\n"
+ "Header Metadata-Flavor: Google\n",
+ "invalidresponse")});
+
+ auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
+
+ auto metadata_client =
+ std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+
+ ComputeEngineZoneProvider provider(metadata_client);
+
+ string zone;
+
+ TF_EXPECT_OK(provider.GetZone(&zone));
+ EXPECT_EQ("", zone);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index aa35e8a116..2e8d13acd5 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -605,13 +605,18 @@ bool StringPieceIdentity(StringPiece str, StringPiece* value) {
} // namespace
-GcsFileSystem::GcsFileSystem()
- : auth_provider_(new GoogleAuthProvider()),
- http_request_factory_(new CurlHttpRequest::Factory()) {
+GcsFileSystem::GcsFileSystem() {
uint64 value;
size_t block_size = kDefaultBlockSize;
size_t max_bytes = kDefaultMaxCacheSize;
uint64 max_staleness = kDefaultMaxStaleness;
+
+ http_request_factory_ = std::make_shared<CurlHttpRequest::Factory>();
+ compute_engine_metadata_client_ =
+ std::make_shared<ComputeEngineMetadataClient>(http_request_factory_);
+ auth_provider_ = std::unique_ptr<AuthProvider>(
+ new GoogleAuthProvider(compute_engine_metadata_client_));
+
// Apply the sys env override for the readahead buffer size if it's provided.
if (GetEnvVar(kReadaheadBufferSize, strings::safe_strtou64, &value)) {
block_size = value;
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 74768c98b5..a0372286f5 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/auth_provider.h"
+#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
#include "tensorflow/core/platform/cloud/expiring_lru_cache.h"
#include "tensorflow/core/platform/cloud/file_block_cache.h"
#include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
@@ -275,12 +276,13 @@ class GcsFileSystem : public FileSystem {
mutex mu_;
std::unique_ptr<AuthProvider> auth_provider_ GUARDED_BY(mu_);
- std::unique_ptr<HttpRequest::Factory> http_request_factory_;
+ std::shared_ptr<HttpRequest::Factory> http_request_factory_;
// block_cache_lock_ protects the file_block_cache_ pointer (Note that
// FileBlockCache instances are themselves threadsafe).
mutex block_cache_lock_;
std::unique_ptr<FileBlockCache> file_block_cache_
GUARDED_BY(block_cache_lock_);
+ std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
std::unique_ptr<GcsDnsCache> dns_cache_;
GcsThrottle throttle_;
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc
index 7e39b63e3e..6ffe51e897 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider.cc
@@ -21,11 +21,11 @@ limitations under the License.
#include <sys/types.h>
#endif
#include <fstream>
+#include <utility>
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/base64.h"
-#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/cloud/retrying_utils.h"
#include "tensorflow/core/platform/env.h"
@@ -63,16 +63,11 @@ constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";
// The URL to retrieve the auth bearer token when running in Google Compute
// Engine.
-constexpr char kGceTokenUrl[] =
- "http://metadata/computeMetadata/v1/instance/service-accounts/default/"
- "token";
+constexpr char kGceTokenPath[] = "instance/service-accounts/default/token";
// The authentication token scope to request.
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
-// The default initial delay between retries with exponential backoff.
-constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
-
/// Returns whether the given path points to a readable file.
bool IsFile(const string& filename) {
std::ifstream fstream(filename.c_str());
@@ -121,20 +116,20 @@ Status GetWellKnownFileName(string* filename) {
} // namespace
-GoogleAuthProvider::GoogleAuthProvider()
- : GoogleAuthProvider(
- std::unique_ptr<OAuthClient>(new OAuthClient()),
- std::unique_ptr<HttpRequest::Factory>(new CurlHttpRequest::Factory()),
- Env::Default(), kInitialRetryDelayUsec) {}
+GoogleAuthProvider::GoogleAuthProvider(
+ std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client)
+ : GoogleAuthProvider(std::unique_ptr<OAuthClient>(new OAuthClient()),
+ std::move(compute_engine_metadata_client),
+ Env::Default()) {}
GoogleAuthProvider::GoogleAuthProvider(
std::unique_ptr<OAuthClient> oauth_client,
- std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env,
- int64 initial_retry_delay_usec)
+ std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client,
+ Env* env)
: oauth_client_(std::move(oauth_client)),
- http_request_factory_(std::move(http_request_factory)),
- env_(env),
- initial_retry_delay_usec_(initial_retry_delay_usec) {}
+ compute_engine_metadata_client_(
+ std::move(compute_engine_metadata_client)),
+ env_(env) {}
Status GoogleAuthProvider::GetToken(string* t) {
mutex_lock lock(mu_);
@@ -207,24 +202,19 @@ Status GoogleAuthProvider::GetTokenFromFiles() {
}
Status GoogleAuthProvider::GetTokenFromGce() {
- const auto get_token_from_gce = [this]() {
- std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
- std::vector<char> response_buffer;
- const uint64 request_timestamp_sec = env_->NowSeconds();
- request->SetUri(kGceTokenUrl);
- request->AddHeader("Metadata-Flavor", "Google");
- request->SetResultBuffer(&response_buffer);
- TF_RETURN_IF_ERROR(request->Send());
- StringPiece response =
- StringPiece(&response_buffer[0], response_buffer.size());
-
- TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
- response, request_timestamp_sec, &current_token_,
- &expiration_timestamp_sec_));
- return Status::OK();
- };
- return RetryingUtils::CallWithRetries(get_token_from_gce,
- initial_retry_delay_usec_);
+ std::vector<char> response_buffer;
+ const uint64 request_timestamp_sec = env_->NowSeconds();
+
+ TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata(
+ kGceTokenPath, &response_buffer));
+ StringPiece response =
+ StringPiece(&response_buffer[0], response_buffer.size());
+
+ TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
+ response, request_timestamp_sec, &current_token_,
+ &expiration_timestamp_sec_));
+
+ return Status::OK();
}
Status GoogleAuthProvider::GetTokenForTesting() {
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.h b/tensorflow/core/platform/cloud/google_auth_provider.h
index 00da25a959..58a785fd60 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.h
+++ b/tensorflow/core/platform/cloud/google_auth_provider.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/platform/cloud/auth_provider.h"
+#include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
#include "tensorflow/core/platform/cloud/oauth_client.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -27,11 +28,12 @@ namespace tensorflow {
/// Implementation based on Google Application Default Credentials.
class GoogleAuthProvider : public AuthProvider {
public:
- GoogleAuthProvider();
- explicit GoogleAuthProvider(
- std::unique_ptr<OAuthClient> oauth_client,
- std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env,
- int64 initial_retry_delay_usec);
+ GoogleAuthProvider(std::shared_ptr<ComputeEngineMetadataClient>
+ compute_engine_metadata_client);
+ explicit GoogleAuthProvider(std::unique_ptr<OAuthClient> oauth_client,
+ std::shared_ptr<ComputeEngineMetadataClient>
+ compute_engine_metadata_client,
+ Env* env);
virtual ~GoogleAuthProvider() {}
/// \brief Returns the short-term authentication bearer token.
@@ -53,13 +55,11 @@ class GoogleAuthProvider : public AuthProvider {
Status GetTokenForTesting() EXCLUSIVE_LOCKS_REQUIRED(mu_);
std::unique_ptr<OAuthClient> oauth_client_;
- std::unique_ptr<HttpRequest::Factory> http_request_factory_;
+ std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
Env* env_;
mutex mu_;
string current_token_ GUARDED_BY(mu_);
uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
- // The initial delay for exponential backoffs when retrying failed calls.
- const int64 initial_retry_delay_usec_;
TF_DISALLOW_COPY_AND_ASSIGN(GoogleAuthProvider);
};
diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
index 4281c6c737..07b88a880f 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
@@ -90,10 +90,13 @@ TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
std::vector<HttpRequest*> requests;
FakeEnv env;
+
+ std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
+ std::make_shared<FakeHttpRequestFactory>(&requests);
+ auto metadataClient =
+ std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- &env, 0);
+ metadataClient, &env);
oauth_client->return_token = "fake-token";
oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;
@@ -124,10 +127,13 @@ TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
std::vector<HttpRequest*> requests;
FakeEnv env;
+ std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
+ std::make_shared<FakeHttpRequestFactory>(&requests);
+ auto metadataClient =
+ std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- &env, 0);
+ metadataClient, &env);
oauth_client->return_token = "fake-token";
oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;
@@ -170,10 +176,12 @@ TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
})")});
FakeEnv env;
+ std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
+ std::make_shared<FakeHttpRequestFactory>(&requests);
+ auto metadataClient =
+ std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- &env, 0);
+ metadataClient, &env);
string token;
TF_EXPECT_OK(provider.GetToken(&token));
@@ -196,10 +204,12 @@ TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
auto oauth_client = new FakeOAuthClient;
std::vector<HttpRequest*> empty_requests;
FakeEnv env;
+ std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
+ std::make_shared<FakeHttpRequestFactory>(&empty_requests);
+ auto metadataClient =
+ std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&empty_requests)),
- &env, 0);
+ metadataClient, &env);
string token;
TF_EXPECT_OK(provider.GetToken(&token));
@@ -216,10 +226,12 @@ TEST_F(GoogleAuthProviderTest, NothingAvailable) {
"", errors::NotFound("404"), 404)});
FakeEnv env;
+ std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
+ std::make_shared<FakeHttpRequestFactory>(&requests);
+ auto metadataClient =
+ std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- &env, 0);
+ metadataClient, &env);
string token;
TF_EXPECT_OK(provider.GetToken(&token));
diff --git a/tensorflow/core/platform/cloud/zone_provider.h b/tensorflow/core/platform/cloud/zone_provider.h
new file mode 100644
index 0000000000..421b6a7e1a
--- /dev/null
+++ b/tensorflow/core/platform/cloud/zone_provider.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_
+
+#include <string>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+/// Interface for a provider of cloud instance zone
+class ZoneProvider {
+ public:
+ virtual ~ZoneProvider() {}
+
+ /// \brief Gets the zone of the Cloud instance and set the result in `zone`.
+ /// Returns OK if success.
+ ///
+ /// Returns an empty string in the case where the zone does not match the
+ /// expected format
+ /// Safe for concurrent use by multiple threads.
+ virtual Status GetZone(string* zone) = 0;
+
+ static Status GetZone(ZoneProvider* provider, string* zone) {
+ if (!provider) {
+ return errors::Internal("Zone provider is required.");
+ }
+ return provider->GetZone(zone);
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_ZONE_PROVIDER_H_