aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-05-30 15:25:46 -0700
committerGravatar Yifei Feng <yifeif@google.com>2018-05-30 15:25:46 -0700
commite469934f1274c7c498e5061995fec425a21c9be8 (patch)
treeab9c0078f1c1fa5027537096898f560cbd9833fe
parent176754d6cce54a971c98096f55251870708eea3e (diff)
Add GCS configure ops.
PiperOrigin-RevId: 198624285
-rw-r--r--tensorflow/contrib/cloud/BUILD15
-rw-r--r--tensorflow/contrib/cloud/__init__.py8
-rw-r--r--tensorflow/contrib/cloud/kernels/BUILD14
-rw-r--r--tensorflow/contrib/cloud/kernels/gcs_config_ops.cc203
-rw-r--r--tensorflow/contrib/cloud/ops/gcs_config_ops.cc70
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops.py176
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake2
-rw-r--r--tensorflow/core/platform/cloud/BUILD1
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc113
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h48
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc4
12 files changed, 594 insertions, 61 deletions
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD
index f3a75e8688..42ba368531 100644
--- a/tensorflow/contrib/cloud/BUILD
+++ b/tensorflow/contrib/cloud/BUILD
@@ -15,7 +15,10 @@ load(
)
tf_gen_op_libs(
- op_lib_names = ["bigquery_reader_ops"],
+ op_lib_names = [
+ "bigquery_reader_ops",
+ "gcs_config_ops",
+ ],
deps = [
"//tensorflow/core:lib",
],
@@ -28,15 +31,25 @@ tf_gen_op_wrapper_py(
deps = [":bigquery_reader_ops_op_lib"],
)
+tf_gen_op_wrapper_py(
+ name = "gen_gcs_config_ops",
+ out = "python/ops/gen_gcs_config_ops.py",
+ require_shape_functions = True,
+ visibility = ["//tensorflow:internal"],
+ deps = [":gcs_config_ops_op_lib"],
+)
+
py_library(
name = "cloud_py",
srcs = [
"__init__.py",
"python/ops/bigquery_reader_ops.py",
+ "python/ops/gcs_config_ops.py",
],
srcs_version = "PY2AND3",
deps = [
":gen_bigquery_reader_ops",
+ ":gen_gcs_config_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:io_ops",
"//tensorflow/python:util",
diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py
index 8870264b95..a6e13ea3ae 100644
--- a/tensorflow/contrib/cloud/__init__.py
+++ b/tensorflow/contrib/cloud/__init__.py
@@ -20,9 +20,15 @@ from __future__ import print_function
# pylint: disable=line-too-long,wildcard-import
from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import *
+from tensorflow.contrib.cloud.python.ops.gcs_config_ops import *
# pylint: enable=line-too-long,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ['BigQueryReader']
+_allowed_symbols = [
+ 'BigQueryReader',
+ 'ConfigureColabSession',
+ 'ConfigureGcs',
+ 'ConfigureGcsHook',
+]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD
index ff46f0daa8..40160706f7 100644
--- a/tensorflow/contrib/cloud/kernels/BUILD
+++ b/tensorflow/contrib/cloud/kernels/BUILD
@@ -73,3 +73,17 @@ tf_proto_library(
srcs = ["bigquery_table_partition.proto"],
cc_api_version = 2,
)
+
+tf_kernel_library(
+ name = "gcs_config_ops",
+ srcs = ["gcs_config_ops.cc"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/platform/cloud:curl_http_request",
+ "//tensorflow/core/platform/cloud:gcs_file_system",
+ "//tensorflow/core/platform/cloud:oauth_client",
+ "@jsoncpp_git//:jsoncpp",
+ ],
+)
diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc
new file mode 100644
index 0000000000..ef4998212e
--- /dev/null
+++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc
@@ -0,0 +1,203 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sstream>
+
+#include "include/json/json.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/cloud/gcs_file_system.h"
+#include "tensorflow/core/platform/cloud/oauth_client.h"
+
+namespace tensorflow {
+namespace {
+
+// The default initial delay between retries with exponential backoff.
+constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
+
+// The minimum time delta between now and the token expiration time
+// for the token to be re-used.
+constexpr int kExpirationTimeMarginSec = 60;
+
+// The URL to retrieve the auth bearer token via OAuth with a refresh token.
+constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";
+
+// The URL to retrieve the auth bearer token via OAuth with a private key.
+constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";
+
+// The authentication token scope to request.
+constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
+
+Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) {
+ DCHECK(fs != nullptr);
+ *fs = nullptr;
+
+ FileSystem* filesystem = nullptr;
+ TF_RETURN_IF_ERROR(
+ ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem));
+ if (filesystem == nullptr) {
+ return errors::FailedPrecondition("The GCS file system is not registered.");
+ }
+
+ *fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem);
+ if (*fs == nullptr) {
+ return errors::Internal(
+ "The filesystem registered under the 'gs://' scheme was not a "
+ "tensorflow::RetryingGcsFileSystem*.");
+ }
+ return Status::OK();
+}
+
+template <typename T>
+Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name,
+ T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+}
+
+// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem.
+class GcsCredentialsOpKernel : public OpKernel {
+ public:
+ explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) override {
+ // Get a handle to the GCS file system.
+ RetryingGcsFileSystem* gcs = nullptr;
+ OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+ string json_string;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "json", &json_string));
+
+ Json::Value json;
+ Json::Reader reader;
+ std::stringstream json_stream(json_string);
+ OP_REQUIRES(ctx, reader.parse(json_stream, json),
+ errors::InvalidArgument("Could not parse json: ", json_string));
+
+ OP_REQUIRES(
+ ctx, json.isMember("refresh_token") || json.isMember("private_key"),
+ errors::InvalidArgument("JSON format incompatible; did not find fields "
+ "`refresh_token` or `private_key`."));
+
+ auto provider = absl::make_unique<ConstantAuthProvider>(json, ctx->env());
+
+ // Test getting a token
+ string dummy_token;
+ OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token));
+ OP_REQUIRES(ctx, !dummy_token.empty(),
+ errors::InvalidArgument(
+ "Could not retrieve a token with the given credentials."));
+
+ // Set the provider.
+ gcs->underlying()->SetAuthProvider(std::move(provider));
+ }
+
+ private:
+ class ConstantAuthProvider : public AuthProvider {
+ public:
+ ConstantAuthProvider(const Json::Value& json,
+ std::unique_ptr<OAuthClient> oauth_client, Env* env,
+ int64 initial_retry_delay_usec)
+ : json_(json),
+ oauth_client_(std::move(oauth_client)),
+ env_(env),
+ initial_retry_delay_usec_(initial_retry_delay_usec) {}
+
+ ConstantAuthProvider(const Json::Value& json, Env* env)
+ : ConstantAuthProvider(json, absl::make_unique<OAuthClient>(), env,
+ kInitialRetryDelayUsec) {}
+
+ ~ConstantAuthProvider() override {}
+
+ Status GetToken(string* token) override {
+ mutex_lock l(mu_);
+ const uint64 now_sec = env_->NowSeconds();
+
+ if (!current_token_.empty() &&
+ now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
+ *token = current_token_;
+ return Status::OK();
+ }
+ if (json_.isMember("refresh_token")) {
+ TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
+ json_, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
+ } else if (json_.isMember("private_key")) {
+ TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
+ json_, kOAuthV4Url, kOAuthScope, &current_token_,
+ &expiration_timestamp_sec_));
+ } else {
+ return errors::FailedPrecondition(
+ "Unexpected content of the JSON credentials file.");
+ }
+
+ *token = current_token_;
+ return Status::OK();
+ }
+
+ private:
+ Json::Value json_;
+ std::unique_ptr<OAuthClient> oauth_client_;
+ Env* env_;
+
+ mutex mu_;
+ string current_token_ GUARDED_BY(mu_);
+ uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
+
+ // The initial delay for exponential backoffs when retrying failed calls.
+ const int64 initial_retry_delay_usec_;
+ TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider);
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU),
+ GcsCredentialsOpKernel);
+
+class GcsBlockCacheOpKernel : public OpKernel {
+ public:
+ explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) override {
+ // Get a handle to the GCS file system.
+ RetryingGcsFileSystem* gcs = nullptr;
+ OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs));
+
+ size_t max_cache_size, block_size, max_staleness;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size",
+ &max_cache_size));
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<size_t>(ctx, "block_size", &block_size));
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness));
+
+ if (gcs->underlying()->block_size() == block_size &&
+ gcs->underlying()->max_bytes() == max_cache_size &&
+ gcs->underlying()->max_staleness() == max_staleness) {
+ LOG(INFO) << "Skipping resetting the GCS block cache.";
+ return;
+ }
+ gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size,
+ max_staleness);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU),
+ GcsBlockCacheOpKernel);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
new file mode 100644
index 0000000000..9cf85f5f18
--- /dev/null
+++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("GcsConfigureCredentials")
+ .Input("json: string")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Configures the credentials used by the GCS client of the local TF runtime.
+
+The json input can be of the format:
+
+1. Refresh Token:
+{
+ "client_id": "<redacted>",
+ "client_secret": "<redacted>",
+ "refresh_token: "<redacted>",
+ "type": "authorized_user",
+}
+
+2. Service Account:
+{
+ "type": "service_account",
+ "project_id": "<redacted>",
+ "private_key_id": "<redacted>",
+ "private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n",
+ "client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com",
+ "client_id": "<REDACTED>",
+ # Some additional fields elided
+}
+
+Note the credentials established through this method are shared across all
+sessions run on this runtime.
+
+Note be sure to feed the inputs to this op to ensure the credentials are not
+stored in a constant op within the graph that might accidentally be checkpointed
+or in other ways be persisted or exfiltrated.
+)doc");
+
+REGISTER_OP("GcsConfigureBlockCache")
+ .Input("max_cache_size: uint64")
+ .Input("block_size: uint64")
+ .Input("max_staleness: uint64")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Re-configures the GCS block cache with the new configuration values.
+
+If the values are the same as already configured values, this op is a no-op. If
+they are different, the current contents of the block cache is dropped, and a
+new block cache is created fresh.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
new file mode 100644
index 0000000000..9ab124ae72
--- /dev/null
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py
@@ -0,0 +1,176 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GCS file system configuration for TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+
+from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.training import training
+
+
+# @tf_export('contrib.cloud.BlockCacheParams')
+class BlockCacheParams(object):
+ """BlockCacheParams is a struct used for configuring the GCS Block Cache."""
+
+ def __init__(self, block_size=None, max_bytes=None, max_staleness=None):
+ self._block_size = block_size or 128 * 1024 * 1024
+ self._max_bytes = max_bytes or 2 * self._block_size
+ self._max_staleness = max_staleness or 0
+
+ @property
+ def block_size(self):
+ return self._block_size
+
+ @property
+ def max_bytes(self):
+ return self._max_bytes
+
+ @property
+ def max_staleness(self):
+ return self._max_staleness
+
+
+# @tf_export('contrib.cloud.ConfigureGcsHook')
+class ConfigureGcsHook(training.SessionRunHook):
+ """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator.
+
+ Example:
+
+ ```
+ sess = tf.Session()
+ refresh_token = raw_input("Refresh token: ")
+ client_secret = raw_input("Client secret: ")
+ client_id = "<REDACTED>"
+ creds = {
+ "client_id": client_id,
+ "refresh_token": refresh_token,
+ "client_secret": client_secret,
+ "type": "authorized_user",
+ }
+ tf.contrib.cloud.configure_gcs(sess, credentials=creds)
+ ```
+
+ """
+
+ def _verify_dictionary(self, creds_dict):
+ if 'refresh_token' in creds_dict or 'private_key' in creds_dict:
+ return True
+ return False
+
+ def __init__(self, credentials=None, block_cache=None):
+ """Constructs a ConfigureGcsHook.
+
+ Args:
+ credentials: A json-formatted string.
+ block_cache: A `BlockCacheParams`
+
+ Raises:
+ ValueError: If credentials is improperly formatted or block_cache is not a
+ BlockCacheParams.
+ """
+ if credentials is not None:
+ if isinstance(credentials, str):
+ try:
+ data = json.loads(credentials)
+ except ValueError as e:
+ raise ValueError('credentials was not a well formed JSON string.', e)
+ if not self._verify_dictionary(data):
+ raise ValueError(
+ 'credentials has neither a "refresh_token" nor a "private_key" '
+ 'field.')
+ elif isinstance(credentials, dict):
+ if not self._verify_dictionary(credentials):
+ raise ValueError('credentials has neither a "refresh_token" nor a '
+ '"private_key" field.')
+ credentials = json.dumps(credentials)
+ else:
+ raise ValueError('credentials is of an unknown type')
+
+ self._credentials = credentials
+
+ if block_cache and not isinstance(block_cache, BlockCacheParams):
+ raise ValueError('block_cache must be an instance of BlockCacheParams.')
+ self._block_cache = block_cache
+
+ def begin(self):
+ if self._credentials:
+ self._credentials_placeholder = array_ops.placeholder(dtypes.string)
+ self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials(
+ self._credentials_placeholder)
+ if self._block_cache:
+ self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache(
+ max_cache_size=self._block_cache.max_bytes,
+ block_size=self._block_cache.block_size,
+ max_staleness=self._block_cache.max_staleness)
+
+ def after_create_session(self, session, coord):
+ del coord
+ if self._credentials_op:
+ session.run(
+ self._credentials_op,
+ feed_dict={self._credentials_placeholder: self._credentials})
+ if self._block_cache_op:
+ session.run(self._block_cache_op)
+
+
+def configure_gcs(session, credentials=None, block_cache=None, device=None):
+ """Configures the GCS file system for a given a session.
+
+ Args:
+ session: A `tf.Session` session that should be used to configure the GCS
+ file system.
+ credentials: [Optional.] A JSON string
+ block_cache: [Optional.] A BlockCacheParams to configure the block cache .
+ device: [Optional.] The device to place the configure ops.
+ """
+
+ def configure(credentials, block_cache):
+ """Helper function to actually configure GCS."""
+ if credentials:
+ if isinstance(credentials, dict):
+ credentials = json.dumps(credentials)
+ placeholder = array_ops.placeholder(dtypes.string)
+ op = gen_gcs_config_ops.gcs_configure_credentials(placeholder)
+ session.run(op, feed_dict={placeholder: credentials})
+ if block_cache:
+ op = gen_gcs_config_ops.gcs_configure_block_cache(
+ max_cache_size=block_cache.max_bytes,
+ block_size=block_cache.block_size,
+ max_staleness=block_cache.max_staleness)
+ session.run(op)
+
+ if device:
+ with ops.device(device):
+ return configure(credentials, block_cache)
+ return configure(credentials, block_cache)
+
+
+def configure_colab_session(session):
+ """ConfigureColabSession configures the GCS file system in Colab.
+
+ Args:
+ session: A `tf.Session` session.
+ """
+ # Read from the application default credentials (adc).
+ with open('/content/datalab/adc.json') as f:
+ data = json.load(f)
+ configure_gcs(session, credentials=data)
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index e558691de4..bc753333db 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor
GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}")
GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc")
+GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc")
########################################################
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 8d24a7ae38..61651f3007 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -420,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py)
+GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py)
GENERATE_PYTHON_OP_LIB("stateless_random_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py)
GENERATE_PYTHON_OP_LIB("debug_ops"
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
index 0fc1e4ae45..67651349ea 100644
--- a/tensorflow/core/platform/cloud/BUILD
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -174,6 +174,7 @@ cc_library(
"oauth_client.h",
],
copts = tf_copts(),
+ visibility = ["//tensorflow:__subpackages__"],
deps = [
":curl_http_request",
":http_request",
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index dc12c78a4b..632bb32063 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -290,51 +290,24 @@ Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) {
/// A GCS-based implementation of a random access file with an LRU block cache.
class GcsRandomAccessFile : public RandomAccessFile {
public:
- using SignatureGenFun =
- std::function<Status(const string& filename, int64* file_signature)>;
+ using ReadFn =
+ std::function<Status(const string& filename, uint64 offset, size_t n,
+ StringPiece* result, char* scratch)>;
- GcsRandomAccessFile(const string& filename, FileBlockCache* file_block_cache,
- const SignatureGenFun& signature_gen_fun)
- : filename_(filename),
- file_block_cache_(file_block_cache),
- signature_gen_fun_(signature_gen_fun) {}
+ GcsRandomAccessFile(const string& filename, ReadFn read_fn)
+ : filename_(filename), read_fn_(std::move(read_fn)) {}
/// The implementation of reads with an LRU block cache. Thread safe.
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
- if (file_block_cache_->IsCacheEnabled()) {
- int64 signature;
- TF_RETURN_IF_ERROR(signature_gen_fun_(filename_, &signature));
- if (!file_block_cache_->ValidateAndUpdateFileSignature(filename_,
- signature)) {
- VLOG(1) << "File " << filename_
- << " signature has been changed. Refreshing the cache.";
- }
- }
-
- *result = StringPiece();
- size_t bytes_transferred;
- TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch,
- &bytes_transferred));
- *result = StringPiece(scratch, bytes_transferred);
-
- if (bytes_transferred < n) {
- // This is not an error per se. The RandomAccessFile interface expects
- // that Read returns OutOfRange if fewer bytes were read than requested.
- return errors::OutOfRange("EOF reached, ", result->size(),
- " bytes were read out of ", n,
- " bytes requested.");
- }
- return Status::OK();
+ return read_fn_(filename_, offset, n, result, scratch);
}
private:
/// The filename of this file.
const string filename_;
- /// The LRU block cache for this file.
- mutable FileBlockCache* file_block_cache_; // not owned
-
- const SignatureGenFun signature_gen_fun_;
+ /// The implementation of the read operation (provided by the GCSFileSystem).
+ const ReadFn read_fn_;
};
/// \brief GCS-based implementation of a writeable file.
@@ -797,21 +770,50 @@ Status GcsFileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
- result->reset(new GcsRandomAccessFile(
- fname, file_block_cache_.get(),
- [this, bucket, object](const string& fname, int64* signature) {
- GcsFileStat stat;
- TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(
- fname, &stat,
- [this, bucket, object](const string& fname, GcsFileStat* stat) {
- return UncachedStatForObject(fname, bucket, object, stat);
- }));
- *signature = stat.generation_number;
- return Status::OK();
- }));
+ result->reset(new GcsRandomAccessFile(fname, [this, bucket, object](
+ const string& fname,
+ uint64 offset, size_t n,
+ StringPiece* result,
+ char* scratch) {
+ tf_shared_lock l(block_cache_lock_);
+ if (file_block_cache_->IsCacheEnabled()) {
+ GcsFileStat stat;
+ TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute(
+ fname, &stat,
+ [this, bucket, object](const string& fname, GcsFileStat* stat) {
+ return UncachedStatForObject(fname, bucket, object, stat);
+ }));
+ if (!file_block_cache_->ValidateAndUpdateFileSignature(
+ fname, stat.generation_number)) {
+ VLOG(1)
+ << "File signature has been changed. Refreshing the cache. Path: "
+ << fname;
+ }
+ }
+ *result = StringPiece();
+ size_t bytes_transferred;
+ TF_RETURN_IF_ERROR(
+ file_block_cache_->Read(fname, offset, n, scratch, &bytes_transferred));
+ *result = StringPiece(scratch, bytes_transferred);
+ if (bytes_transferred < n) {
+ return errors::OutOfRange("EOF reached, ", result->size(),
+ " bytes were read out of ", n,
+ " bytes requested.");
+ }
+ return Status::OK();
+ }));
return Status::OK();
}
+void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes,
+ size_t max_bytes,
+ uint64 max_staleness_secs) {
+ mutex_lock l(block_cache_lock_);
+ file_block_cache_ =
+ MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs);
+ stats_->Configure(this, &throttle_, file_block_cache_.get());
+}
+
// A helper function to build a FileBlockCache for GcsFileSystem.
std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
size_t block_size, size_t max_bytes, uint64 max_staleness) {
@@ -880,6 +882,7 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
}
void GcsFileSystem::ClearFileCaches(const string& fname) {
+ tf_shared_lock l(block_cache_lock_);
file_block_cache_->RemoveFile(fname);
stat_cache_->Delete(fname);
// TODO(rxsang): Remove the patterns that matche the file in
@@ -1509,6 +1512,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname,
// reclaiming memory once filesystem operations are done (e.g. model is loaded),
// or for resetting the filesystem to a consistent state.
void GcsFileSystem::FlushCaches() {
+ tf_shared_lock l(block_cache_lock_);
file_block_cache_->Flush();
stat_cache_->Clear();
matching_paths_cache_->Clear();
@@ -1517,8 +1521,15 @@ void GcsFileSystem::FlushCaches() {
void GcsFileSystem::SetStats(GcsStatsInterface* stats) {
CHECK(stats_ == nullptr) << "SetStats() has already been called.";
CHECK(stats != nullptr);
+ mutex_lock l(block_cache_lock_);
stats_ = stats;
- stats_->Init(this, &throttle_, file_block_cache_.get());
+ stats_->Configure(this, &throttle_, file_block_cache_.get());
+}
+
+void GcsFileSystem::SetAuthProvider(
+ std::unique_ptr<AuthProvider> auth_provider) {
+ mutex_lock l(mu_);
+ auth_provider_ = std::move(auth_provider);
}
// Creates an HttpRequest and sets several parameters that are common to all
@@ -1531,7 +1542,11 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) {
}
string auth_token;
- TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
+ {
+ tf_shared_lock l(mu_);
+ TF_RETURN_IF_ERROR(
+ AuthProvider::GetToken(auth_provider_.get(), &auth_token));
+ }
new_request->AddAuthBearerHeader(auth_token);
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index d543db1577..74768c98b5 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -43,9 +43,12 @@ class GcsFileSystem;
/// time.
class GcsStatsInterface {
public:
- /// Init is called by the GcsFileSystem immediately after being registered.
- virtual void Init(GcsFileSystem* fs, GcsThrottle* throttle,
- const FileBlockCache* block_cache) = 0;
+ /// Configure is called by the GcsFileSystem to provide instrumentation hooks.
+ ///
+ /// Note: Configure can be called multiple times (e.g. if the block cache is
+ /// re-initialized).
+ virtual void Configure(GcsFileSystem* fs, GcsThrottle* throttle,
+ const FileBlockCache* block_cache) = 0;
/// RecordBlockLoadRequest is called to record a block load request is about
/// to be made.
@@ -132,9 +135,18 @@ class GcsFileSystem : public FileSystem {
/// These accessors are mainly for testing purposes, to verify that the
/// environment variables that control these parameters are handled correctly.
- size_t block_size() const { return file_block_cache_->block_size(); }
- size_t max_bytes() const { return file_block_cache_->max_bytes(); }
- uint64 max_staleness() const { return file_block_cache_->max_staleness(); }
+ size_t block_size() {
+ tf_shared_lock l(block_cache_lock_);
+ return file_block_cache_->block_size();
+ }
+ size_t max_bytes() {
+ tf_shared_lock l(block_cache_lock_);
+ return file_block_cache_->max_bytes();
+ }
+ uint64 max_staleness() {
+ tf_shared_lock l(block_cache_lock_);
+ return file_block_cache_->max_staleness();
+ }
TimeoutConfig timeouts() const { return timeouts_; }
string additional_header_name() const {
return additional_header_ ? additional_header_->first : "";
@@ -190,6 +202,21 @@ class GcsFileSystem : public FileSystem {
Status CreateHttpRequest(std::unique_ptr<HttpRequest>* request);
+ /// \brief Sets a new AuthProvider on the GCS FileSystem.
+ ///
+ /// The new auth provider will be used for all subsequent requests.
+ void SetAuthProvider(std::unique_ptr<AuthProvider> auth_provider);
+
+ /// \brief Resets the block cache and re-instantiates it with the new values.
+ ///
+ /// This method can be used to clear the existing block cache and/or to
+ /// re-configure the block cache for different values.
+ ///
+ /// Note: the existing block cache is not cleaned up until all existing files
+ /// have been closed.
+ void ResetFileBlockCache(size_t block_size_bytes, size_t max_bytes,
+ uint64 max_staleness_secs);
+
private:
// GCS file statistics.
struct GcsFileStat {
@@ -246,9 +273,14 @@ class GcsFileSystem : public FileSystem {
// Clear all the caches related to the file with name `filename`.
void ClearFileCaches(const string& fname);
- std::unique_ptr<AuthProvider> auth_provider_;
+ mutex mu_;
+ std::unique_ptr<AuthProvider> auth_provider_ GUARDED_BY(mu_);
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
- std::unique_ptr<FileBlockCache> file_block_cache_;
+ // block_cache_lock_ protects the file_block_cache_ pointer (Note that
+ // FileBlockCache instances are themselves threadsafe).
+ mutex block_cache_lock_;
+ std::unique_ptr<FileBlockCache> file_block_cache_
+ GUARDED_BY(block_cache_lock_);
std::unique_ptr<GcsDnsCache> dns_cache_;
GcsThrottle throttle_;
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 3f73b238ad..6a28d9162f 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -2946,8 +2946,8 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
class TestGcsStats : public GcsStatsInterface {
public:
- void Init(GcsFileSystem* fs, GcsThrottle* throttle,
- const FileBlockCache* block_cache) override {
+ void Configure(GcsFileSystem* fs, GcsThrottle* throttle,
+ const FileBlockCache* block_cache) override {
CHECK(fs_ == nullptr);
CHECK(throttle_ == nullptr);
CHECK(block_cache_ == nullptr);