aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cloud
diff options
context:
space:
mode:
author Brennan Saeta <saeta@google.com> 2018-05-30 15:25:46 -0700
committer Yifei Feng <yifeif@google.com> 2018-05-30 15:25:46 -0700
commit e469934f1274c7c498e5061995fec425a21c9be8 (patch)
tree ab9c0078f1c1fa5027537096898f560cbd9833fe /tensorflow/contrib/cloud
parent 176754d6cce54a971c98096f55251870708eea3e (diff)
Add GCS configure ops.
PiperOrigin-RevId: 198624285
Diffstat (limited to 'tensorflow/contrib/cloud')
-rw-r--r-- tensorflow/contrib/cloud/BUILD 15
-rw-r--r-- tensorflow/contrib/cloud/__init__.py 8
-rw-r--r-- tensorflow/contrib/cloud/kernels/BUILD 14
-rw-r--r-- tensorflow/contrib/cloud/kernels/gcs_config_ops.cc 203
-rw-r--r-- tensorflow/contrib/cloud/ops/gcs_config_ops.cc 70
-rw-r--r-- tensorflow/contrib/cloud/python/ops/gcs_config_ops.py 176
6 files changed, 484 insertions, 2 deletions
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
index f3a75e8688..42ba368531 100644
--- a/tensorflow/contrib/cloud/BUILD
+++ b/tensorflow/contrib/cloud/BUILD
@@ -15,7 +15,10 @@ load(
)
tf_gen_op_libs(
- op_lib_names = ["bigquery_reader_ops"],
+ op_lib_names = [
+ "bigquery_reader_ops",
+ "gcs_config_ops",
+ ],
deps = [
"//tensorflow/core:lib",
],
@@ -28,15 +31,25 @@ tf_gen_op_wrapper_py(
deps = [":bigquery_reader_ops_op_lib"],
)
+# Generates the Python op wrappers, written to python/ops/gen_gcs_config_ops.py,
+# from the ":gcs_config_ops_op_lib" op library. Shape functions are required
+# for every wrapped op (require_shape_functions = True).
+tf_gen_op_wrapper_py(
+ name = "gen_gcs_config_ops",
+ out = "python/ops/gen_gcs_config_ops.py",
+ require_shape_functions = True,
+ visibility = ["//tensorflow:internal"],
+ deps = [":gcs_config_ops_op_lib"],
+)
+
py_library(
name = "cloud_py",
srcs = [
"__init__.py",
"python/ops/bigquery_reader_ops.py",
+ "python/ops/gcs_config_ops.py",
],
srcs_version = "PY2AND3",
deps = [
":gen_bigquery_reader_ops",
+ ":gen_gcs_config_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:io_ops",
"//tensorflow/python:util",
diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py
index 8870264b95..a6e13ea3ae 100644
--- a/tensorflow/contrib/cloud/__init__.py
+++ b/tensorflow/contrib/cloud/__init__.py
@@ -20,9 +20,15 @@ from __future__ import print_function
# pylint: disable=line-too-long,wildcard-import
from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import *
+from tensorflow.contrib.cloud.python.ops.gcs_config_ops import *
# pylint: enable=line-too-long,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['BigQueryReader']
+_allowed_symbols = [
+ 'BigQueryReader',
+ 'ConfigureColabSession',
+ 'ConfigureGcs',
+ 'ConfigureGcsHook',
+]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD
index ff46f0daa8..40160706f7 100644
--- a/tensorflow/contrib/cloud/kernels/BUILD
+++ b/tensorflow/contrib/cloud/kernels/BUILD
@@ -73,3 +73,17 @@ tf_proto_library(
srcs = ["bigquery_table_partition.proto"],
cc_api_version = 2,
)
+
+# Kernel library built from gcs_config_ops.cc. Besides the core framework it
+# depends on the GCS file system, the OAuth client and jsoncpp, all of which
+# that source file includes directly.
+tf_kernel_library(
+ name = "gcs_config_ops",
+ srcs = ["gcs_config_ops.cc"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/platform/cloud:curl_http_request",
+ "//tensorflow/core/platform/cloud:gcs_file_system",
+ "//tensorflow/core/platform/cloud:oauth_client",
+ "@jsoncpp_git//:jsoncpp",
+ ],
+)
diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc
new file mode 100644
index 0000000000..ef4998212e
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc
@@ -0,0 +1,203 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sstream>
+
+#include "include/json/json.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/cloud/gcs_file_system.h"
+#include "tensorflow/core/platform/cloud/oauth_client.h"
+
+namespace tensorflow {
+namespace {
+
+// The default initial delay between retries with exponential backoff.
+// NOTE(review): in this file the value is only handed to ConstantAuthProvider,
+// which stores it but never reads it -- confirm whether it is still needed.
+constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
+
+// The minimum time delta between now and the token expiration time
+// for the token to be re-used. A cached token that expires sooner than this
+// is refreshed instead of being returned.
+constexpr int kExpirationTimeMarginSec = 60;
+
+// The URL to retrieve the auth bearer token via OAuth with a refresh token.
+constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";
+
+// The URL to retrieve the auth bearer token via OAuth with a private key.
+constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";
+
+// The authentication token scope to request (used on the private-key path).
+constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
+
+// Looks up the file system that the context's Env resolves for a "gs://" path
+// and hands it back through `fs` as a RetryingGcsFileSystem.
+//
+// `fs` must not be null. `*fs` is cleared on entry and only becomes non-null
+// when the function returns OK. Lookup errors are returned unchanged;
+// FailedPrecondition is returned when no file system comes back, and Internal
+// when the one that comes back is not a RetryingGcsFileSystem.
+Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) {
+  DCHECK(fs != nullptr);
+  *fs = nullptr;
+
+  // The path is a placeholder; it is only there to select the "gs://" scheme.
+  FileSystem* registered_fs = nullptr;
+  const Status lookup_status =
+      ctx->env()->GetFileSystemForFile("gs://fake/file.text", &registered_fs);
+  if (!lookup_status.ok()) {
+    return lookup_status;
+  }
+  if (registered_fs == nullptr) {
+    return errors::FailedPrecondition("The GCS file system is not registered.");
+  }
+
+  auto* retrying_fs = dynamic_cast<RetryingGcsFileSystem*>(registered_fs);
+  if (retrying_fs == nullptr) {
+    return errors::Internal(
+        "The filesystem registered under the 'gs://' scheme was not a "
+        "tensorflow::RetryingGcsFileSystem*.");
+  }
+
+  *fs = retrying_fs;
+  return Status::OK();
+}
+
+// Fetches the input called `argument_name` from `ctx` and stores its single
+// element, read as type T, in `*output`.
+//
+// Errors from the input lookup are returned unchanged; InvalidArgument is
+// returned when the input tensor is not a scalar.
+template <typename T>
+Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
+                           T* output) {
+  const Tensor* input_tensor;
+  const Status lookup_status = ctx->input(argument_name, &input_tensor);
+  if (!lookup_status.ok()) {
+    return lookup_status;
+  }
+
+  const bool is_scalar = TensorShapeUtils::IsScalar(input_tensor->shape());
+  if (!is_scalar) {
+    return errors::InvalidArgument(argument_name, " must be a scalar");
+  }
+
+  *output = input_tensor->scalar<T>()();
+  return Status::OK();
+}
+
+// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem.
+//
+// Compute() reads the string input "json", requires it to be a JSON object
+// with a `refresh_token` or `private_key` member, checks that a token can
+// actually be obtained with it, and only then installs the new auth provider
+// on the underlying GCS file system.
+class GcsCredentialsOpKernel : public OpKernel {
+ public:
+  explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // Get a handle to the GCS file system.
+    RetryingGcsFileSystem* gcs = nullptr;
+    OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+    string json_string;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "json", &json_string));
+
+    Json::Value json;
+    Json::Reader reader;
+    std::stringstream json_stream(json_string);
+    // The input carries secrets, so it is deliberately NOT echoed back in the
+    // error message (error statuses end up in logs and client exceptions).
+    OP_REQUIRES(ctx, reader.parse(json_stream, json),
+                errors::InvalidArgument("Could not parse json."));
+
+    // isMember() is only meaningful on JSON objects, so reject any other
+    // top-level value (string, number, array, ...) before probing for fields.
+    OP_REQUIRES(
+        ctx,
+        json.isObject() &&
+            (json.isMember("refresh_token") || json.isMember("private_key")),
+        errors::InvalidArgument("JSON format incompatible; did not find fields "
+                                "`refresh_token` or `private_key`."));
+
+    // Plain std::unique_ptr construction: this file neither includes nor
+    // depends on absl, so absl::make_unique is not available here.
+    std::unique_ptr<ConstantAuthProvider> provider(
+        new ConstantAuthProvider(json, ctx->env()));
+
+    // Test getting a token before replacing the provider, so that unusable
+    // credentials never displace the ones currently installed.
+    string dummy_token;
+    OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token));
+    OP_REQUIRES(ctx, !dummy_token.empty(),
+                errors::InvalidArgument(
+                    "Could not retrieve a token with the given credentials."));
+
+    // Set the provider.
+    gcs->underlying()->SetAuthProvider(std::move(provider));
+  }
+
+ private:
+  // AuthProvider backed by one fixed JSON credentials object. Tokens are
+  // fetched through the OAuth client and cached until they are within
+  // kExpirationTimeMarginSec of expiring.
+  class ConstantAuthProvider : public AuthProvider {
+   public:
+    ConstantAuthProvider(const Json::Value& json,
+                         std::unique_ptr<OAuthClient> oauth_client, Env* env,
+                         int64 initial_retry_delay_usec)
+        : json_(json),
+          oauth_client_(std::move(oauth_client)),
+          env_(env),
+          initial_retry_delay_usec_(initial_retry_delay_usec) {}
+
+    // Convenience constructor using a default OAuthClient and retry delay.
+    ConstantAuthProvider(const Json::Value& json, Env* env)
+        : ConstantAuthProvider(
+              json, std::unique_ptr<OAuthClient>(new OAuthClient()), env,
+              kInitialRetryDelayUsec) {}
+
+    ~ConstantAuthProvider() override {}
+
+    // Stores a bearer token in `*token`. Returns the cached token while it is
+    // still valid for more than kExpirationTimeMarginSec; otherwise requests a
+    // new one using whichever credential field the JSON provides. Errors from
+    // the OAuth client are returned unchanged.
+    Status GetToken(string* token) override {
+      mutex_lock l(mu_);
+      const uint64 now_sec = env_->NowSeconds();
+
+      if (!current_token_.empty() &&
+          now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
+        *token = current_token_;
+        return Status::OK();
+      }
+      if (json_.isMember("refresh_token")) {
+        TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
+            json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
+      } else if (json_.isMember("private_key")) {
+        TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
+            json_, kOAuthV4Url, kOAuthScope, &current_token_,
+            &expiration_timestamp_sec_));
+      } else {
+        return errors::FailedPrecondition(
+            "Unexpected content of the JSON credentials file.");
+      }
+
+      *token = current_token_;
+      return Status::OK();
+    }
+
+   private:
+    Json::Value json_;
+    std::unique_ptr<OAuthClient> oauth_client_;
+    Env* env_;
+
+    mutex mu_;
+    string current_token_ GUARDED_BY(mu_);
+    uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
+
+    // The initial delay for exponential backoffs when retrying failed calls.
+    // NOTE(review): stored but not read anywhere in this class.
+    const int64 initial_retry_delay_usec_;
+    TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider);
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU),
+                        GcsCredentialsOpKernel);
+
+// GcsBlockCacheOpKernel re-configures the block cache of the GCS file system
+// from the three inputs "max_cache_size", "block_size" and "max_staleness".
+// The cache is left untouched when all three already match its current
+// configuration; otherwise it is reset with the new values.
+class GcsBlockCacheOpKernel : public OpKernel {
+ public:
+  explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    // Get a handle to the GCS file system.
+    RetryingGcsFileSystem* gcs = nullptr;
+    OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+    // The op declares its inputs as uint64, so they are read as uint64.
+    // Reading them through size_t only works on platforms where size_t and
+    // tensorflow::uint64 happen to be the same type.
+    uint64 max_cache_size = 0;
+    uint64 block_size = 0;
+    uint64 max_staleness = 0;
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "max_cache_size",
+                                                    &max_cache_size));
+    OP_REQUIRES_OK(ctx,
+                   ParseScalarArgument<uint64>(ctx, "block_size", &block_size));
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<uint64>(ctx, "max_staleness", &max_staleness));
+
+    auto* gcs_fs = gcs->underlying();
+    if (gcs_fs->block_size() == block_size &&
+        gcs_fs->max_bytes() == max_cache_size &&
+        gcs_fs->max_staleness() == max_staleness) {
+      // Nothing changed; keep the existing cache and its contents.
+      LOG(INFO) << "Skipping resetting the GCS block cache.";
+      return;
+    }
+    gcs_fs->ResetFileBlockCache(block_size, max_cache_size, max_staleness);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU),
+                        GcsBlockCacheOpKernel);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
new file mode 100644
index 0000000000..9cf85f5f18
--- /dev/null
+++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// Registers the op that receives the credentials through a string input named
+// `json` and produces no outputs. The JSON samples in the op documentation
+// below are kept syntactically well-formed so they can be used as templates.
+REGISTER_OP("GcsConfigureCredentials")
+    .Input("json: string")
+    .SetShapeFn(shape_inference::NoOutputs)
+    .Doc(R"doc(
+Configures the credentials used by the GCS client of the local TF runtime.
+
+The json input can be of the format:
+
+1. Refresh Token:
+{
+  "client_id": "<redacted>",
+  "client_secret": "<redacted>",
+  "refresh_token": "<redacted>",
+  "type": "authorized_user"
+}
+
+2. Service Account:
+{
+  "type": "service_account",
+  "project_id": "<redacted>",
+  "private_key_id": "<redacted>",
+  "private_key": "-----BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY-----\n",
+  "client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com",
+  "client_id": "<REDACTED>",
+  # Some additional fields elided
+}
+
+Note the credentials established through this method are shared across all
+sessions run on this runtime.
+
+Note be sure to feed the inputs to this op to ensure the credentials are not
+stored in a constant op within the graph that might accidentally be checkpointed
+or in other ways be persisted or exfiltrated.
+)doc");
+
+// Registers the op that re-configures the GCS block cache. It takes three
+// uint64 inputs (max_cache_size, block_size, max_staleness) and produces no
+// outputs.
+// NOTE(review): units are not stated here -- presumably bytes for the two
+// sizes and seconds for max_staleness; confirm against the GCS file system.
+REGISTER_OP("GcsConfigureBlockCache")
+ .Input("max_cache_size: uint64")
+ .Input("block_size: uint64")
+ .Input("max_staleness: uint64")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Re-configures the GCS block cache with the new configuration values.
+
+If the values are the same as already configured values, this op is a no-op. If
+they are different, the current contents of the block cache is dropped, and a
+new block cache is created fresh.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
new file mode 100644
index 0000000000..9ab124ae72
--- /dev/null
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
@@ -0,0 +1,176 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GCS file system configuration for TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.training import training
+
+
+# @tf_export('contrib.cloud.BlockCacheParams')
+class BlockCacheParams(object):
+ """BlockCacheParams is a struct used for configuring the GCS Block Cache."""
+
+ def __init__(self, block_size=None, max_bytes=None, max_staleness=None):
+ self._block_size = block_size or 128 * 1024 * 1024
+ self._max_bytes = max_bytes or 2 * self._block_size
+ self._max_staleness = max_staleness or 0
+
+ @property
+ def block_size(self):
+ return self._block_size
+
+ @property
+ def max_bytes(self):
+ return self._max_bytes
+
+ @property
+ def max_staleness(self):
+ return self._max_staleness
+
+
+# @tf_export('contrib.cloud.ConfigureGcsHook')
+class ConfigureGcsHook(training.SessionRunHook):
+ """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator.
+
+ Example:
+
+ ```
+ sess = tf.Session()
+ refresh_token = raw_input("Refresh token: ")
+ client_secret = raw_input("Client secret: ")
+ client_id = "<REDACTED>"
+ creds = {
+ "client_id": client_id,
+ "refresh_token": refresh_token,
+ "client_secret": client_secret,
+ "type": "authorized_user",
+ }
+ tf.contrib.cloud.configure_gcs(sess, credentials=creds)
+ ```
+
+ """
+
+ def _verify_dictionary(self, creds_dict):
+ if 'refresh_token' in creds_dict or 'private_key' in creds_dict:
+ return True
+ return False
+
+ def __init__(self, credentials=None, block_cache=None):
+ """Constructs a ConfigureGcsHook.
+
+ Args:
+ credentials: A json-formatted string.
+ block_cache: A `BlockCacheParams`
+
+ Raises:
+ ValueError: If credentials is improperly formatted or block_cache is not a
+ BlockCacheParams.
+ """
+ if credentials is not None:
+ if isinstance(credentials, str):
+ try:
+ data = json.loads(credentials)
+ except ValueError as e:
+ raise ValueError('credentials was not a well formed JSON string.', e)
+ if not self._verify_dictionary(data):
+ raise ValueError(
+ 'credentials has neither a "refresh_token" nor a "private_key" '
+ 'field.')
+ elif isinstance(credentials, dict):
+ if not self._verify_dictionary(credentials):
+ raise ValueError('credentials has neither a "refresh_token" nor a '
+ '"private_key" field.')
+ credentials = json.dumps(credentials)
+ else:
+ raise ValueError('credentials is of an unknown type')
+
+ self._credentials = credentials
+
+ if block_cache and not isinstance(block_cache, BlockCacheParams):
+ raise ValueError('block_cache must be an instance of BlockCacheParams.')
+ self._block_cache = block_cache
+
+ def begin(self):
+ if self._credentials:
+ self._credentials_placeholder = array_ops.placeholder(dtypes.string)
+ self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials(
+ self._credentials_placeholder)
+ if self._block_cache:
+ self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache(
+ max_cache_size=self._block_cache.max_bytes,
+ block_size=self._block_cache.block_size,
+ max_staleness=self._block_cache.max_staleness)
+
+ def after_create_session(self, session, coord):
+ del coord
+ if self._credentials_op:
+ session.run(
+ self._credentials_op,
+ feed_dict={self._credentials_placeholder: self._credentials})
+ if self._block_cache_op:
+ session.run(self._block_cache_op)
+
+
+def configure_gcs(session, credentials=None, block_cache=None, device=None):
+ """Configures the GCS file system for a given a session.
+
+ Args:
+ session: A `tf.Session` session that should be used to configure the GCS
+ file system.
+ credentials: [Optional.] A JSON string
+ block_cache: [Optional.] A BlockCacheParams to configure the block cache .
+ device: [Optional.] The device to place the configure ops.
+ """
+
+ def configure(credentials, block_cache):
+ """Helper function to actually configure GCS."""
+ if credentials:
+ if isinstance(credentials, dict):
+ credentials = json.dumps(credentials)
+ placeholder = array_ops.placeholder(dtypes.string)
+ op = gen_gcs_config_ops.gcs_configure_credentials(placeholder)
+ session.run(op, feed_dict={placeholder: credentials})
+ if block_cache:
+ op = gen_gcs_config_ops.gcs_configure_block_cache(
+ max_cache_size=block_cache.max_bytes,
+ block_size=block_cache.block_size,
+ max_staleness=block_cache.max_staleness)
+ session.run(op)
+
+ if device:
+ with ops.device(device):
+ return configure(credentials, block_cache)
+ return configure(credentials, block_cache)
+
+
+def configure_colab_session(session):
+ """ConfigureColabSession configures the GCS file system in Colab.
+
+ Args:
+ session: A `tf.Session` session.
+ """
+ # Read from the application default credentials (adc).
+ with open('/content/datalab/adc.json') as f:
+ data = json.load(f)
+ configure_gcs(session, credentials=data)