diff options
author | Brennan Saeta <saeta@google.com> | 2018-05-30 15:25:46 -0700 |
---|---|---|
committer | Yifei Feng <yifeif@google.com> | 2018-05-30 15:25:46 -0700 |
commit | e469934f1274c7c498e5061995fec425a21c9be8 (patch) | |
tree | ab9c0078f1c1fa5027537096898f560cbd9833fe | |
parent | 176754d6cce54a971c98096f55251870708eea3e (diff) |
Add GCS configure ops.
PiperOrigin-RevId: 198624285
-rw-r--r-- | tensorflow/contrib/cloud/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/contrib/cloud/__init__.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/cloud/kernels/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/cloud/kernels/gcs_config_ops.cc | 203 | ||||
-rw-r--r-- | tensorflow/contrib/cloud/ops/gcs_config_ops.cc | 70 | ||||
-rw-r--r-- | tensorflow/contrib/cloud/python/ops/gcs_config_ops.py | 176 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/tf_core_ops.cmake | 1 | ||||
-rwxr-xr-x | tensorflow/contrib/cmake/tf_python.cmake | 2 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.cc | 113 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.h | 48 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system_test.cc | 4 |
12 files changed, 594 insertions, 61 deletions
diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index f3a75e8688..42ba368531 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -15,7 +15,10 @@ load( ) tf_gen_op_libs( - op_lib_names = ["bigquery_reader_ops"], + op_lib_names = [ + "bigquery_reader_ops", + "gcs_config_ops", + ], deps = [ "//tensorflow/core:lib", ], @@ -28,15 +31,25 @@ tf_gen_op_wrapper_py( deps = [":bigquery_reader_ops_op_lib"], ) +tf_gen_op_wrapper_py( + name = "gen_gcs_config_ops", + out = "python/ops/gen_gcs_config_ops.py", + require_shape_functions = True, + visibility = ["//tensorflow:internal"], + deps = [":gcs_config_ops_op_lib"], +) + py_library( name = "cloud_py", srcs = [ "__init__.py", "python/ops/bigquery_reader_ops.py", + "python/ops/gcs_config_ops.py", ], srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", + ":gen_gcs_config_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:io_ops", "//tensorflow/python:util", diff --git a/tensorflow/contrib/cloud/__init__.py b/tensorflow/contrib/cloud/__init__.py index 8870264b95..a6e13ea3ae 100644 --- a/tensorflow/contrib/cloud/__init__.py +++ b/tensorflow/contrib/cloud/__init__.py @@ -20,9 +20,15 @@ from __future__ import print_function # pylint: disable=line-too-long,wildcard-import from tensorflow.contrib.cloud.python.ops.bigquery_reader_ops import * +from tensorflow.contrib.cloud.python.ops.gcs_config_ops import * # pylint: enable=line-too-long,wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['BigQueryReader'] +_allowed_symbols = [ + 'BigQueryReader', + 'ConfigureColabSession', + 'ConfigureGcs', + 'ConfigureGcsHook', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index ff46f0daa8..40160706f7 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -73,3 +73,17 @@ tf_proto_library( srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) + +tf_kernel_library( + name = "gcs_config_ops", + srcs = ["gcs_config_ops.cc"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/platform/cloud:curl_http_request", + "//tensorflow/core/platform/cloud:gcs_file_system", + "//tensorflow/core/platform/cloud:oauth_client", + "@jsoncpp_git//:jsoncpp", + ], +) diff --git a/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc new file mode 100644 index 0000000000..ef4998212e --- /dev/null +++ b/tensorflow/contrib/cloud/kernels/gcs_config_ops.cc @@ -0,0 +1,203 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <sstream> + +#include "include/json/json.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" +#include "tensorflow/core/platform/cloud/gcs_file_system.h" +#include "tensorflow/core/platform/cloud/oauth_client.h" + +namespace tensorflow { +namespace { + +// The default initial delay between retries with exponential backoff. +constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec + +// The minimum time delta between now and the token expiration time +// for the token to be re-used. +constexpr int kExpirationTimeMarginSec = 60; + +// The URL to retrieve the auth bearer token via OAuth with a refresh token. +constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token"; + +// The URL to retrieve the auth bearer token via OAuth with a private key. +constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token"; + +// The authentication token scope to request. +constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform"; + +Status RetrieveGcsFs(OpKernelContext* ctx, RetryingGcsFileSystem** fs) { + DCHECK(fs != nullptr); + *fs = nullptr; + + FileSystem* filesystem = nullptr; + TF_RETURN_IF_ERROR( + ctx->env()->GetFileSystemForFile("gs://fake/file.text", &filesystem)); + if (filesystem == nullptr) { + return errors::FailedPrecondition("The GCS file system is not registered."); + } + + *fs = dynamic_cast<RetryingGcsFileSystem*>(filesystem); + if (*fs == nullptr) { + return errors::Internal( + "The filesystem registered under the 'gs://' scheme was not a " + "tensorflow::RetryingGcsFileSystem*."); + } + return Status::OK(); +} + +template <typename T> +Status ParseScalarArgument(OpKernelContext* ctx, StringPiece argument_name, + T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar<T>()(); + return Status::OK(); +} + +// GcsCredentialsOpKernel overrides the credentials used by the gcs_filesystem. +class GcsCredentialsOpKernel : public OpKernel { + public: + explicit GcsCredentialsOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + string json_string; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "json", &json_string)); + + Json::Value json; + Json::Reader reader; + std::stringstream json_stream(json_string); + OP_REQUIRES(ctx, reader.parse(json_stream, json), + errors::InvalidArgument("Could not parse json: ", json_string)); + + OP_REQUIRES( + ctx, json.isMember("refresh_token") || json.isMember("private_key"), + errors::InvalidArgument("JSON format incompatible; did not find fields " + "`refresh_token` or `private_key`.")); + + auto provider = absl::make_unique<ConstantAuthProvider>(json, ctx->env()); + + // Test getting a token + string dummy_token; + OP_REQUIRES_OK(ctx, provider->GetToken(&dummy_token)); + OP_REQUIRES(ctx, !dummy_token.empty(), + errors::InvalidArgument( + "Could not retrieve a token with the given credentials.")); + + // Set the provider. + gcs->underlying()->SetAuthProvider(std::move(provider)); + } + + private: + class ConstantAuthProvider : public AuthProvider { + public: + ConstantAuthProvider(const Json::Value& json, + std::unique_ptr<OAuthClient> oauth_client, Env* env, + int64 initial_retry_delay_usec) + : json_(json), + oauth_client_(std::move(oauth_client)), + env_(env), + initial_retry_delay_usec_(initial_retry_delay_usec) {} + + ConstantAuthProvider(const Json::Value& json, Env* env) + : ConstantAuthProvider(json, absl::make_unique<OAuthClient>(), env, + kInitialRetryDelayUsec) {} + + ~ConstantAuthProvider() override {} + + Status GetToken(string* token) override { + mutex_lock l(mu_); + const uint64 now_sec = env_->NowSeconds(); + + if (!current_token_.empty() && + now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { + *token = current_token_; + return Status::OK(); + } + if (json_.isMember("refresh_token")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson( + json_, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_)); + } else if (json_.isMember("private_key")) { + TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson( + json_, kOAuthV4Url, kOAuthScope, ¤t_token_, + &expiration_timestamp_sec_)); + } else { + return errors::FailedPrecondition( + "Unexpected content of the JSON credentials file."); + } + + *token = current_token_; + return Status::OK(); + } + + private: + Json::Value json_; + std::unique_ptr<OAuthClient> oauth_client_; + Env* env_; + + mutex mu_; + string current_token_ GUARDED_BY(mu_); + uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0; + + // The initial delay for exponential backoffs when retrying failed calls. + const int64 initial_retry_delay_usec_; + TF_DISALLOW_COPY_AND_ASSIGN(ConstantAuthProvider); + }; +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureCredentials").Device(DEVICE_CPU), + GcsCredentialsOpKernel); + +class GcsBlockCacheOpKernel : public OpKernel { + public: + explicit GcsBlockCacheOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) override { + // Get a handle to the GCS file system. + RetryingGcsFileSystem* gcs = nullptr; + OP_REQUIRES_OK(ctx, RetrieveGcsFs(ctx, &gcs)); + + size_t max_cache_size, block_size, max_staleness; + OP_REQUIRES_OK(ctx, ParseScalarArgument<size_t>(ctx, "max_cache_size", + &max_cache_size)); + OP_REQUIRES_OK(ctx, + ParseScalarArgument<size_t>(ctx, "block_size", &block_size)); + OP_REQUIRES_OK( + ctx, ParseScalarArgument<size_t>(ctx, "max_staleness", &max_staleness)); + + if (gcs->underlying()->block_size() == block_size && + gcs->underlying()->max_bytes() == max_cache_size && + gcs->underlying()->max_staleness() == max_staleness) { + LOG(INFO) << "Skipping resetting the GCS block cache."; + return; + } + gcs->underlying()->ResetFileBlockCache(block_size, max_cache_size, + max_staleness); + } +}; + +REGISTER_KERNEL_BUILDER(Name("GcsConfigureBlockCache").Device(DEVICE_CPU), + GcsBlockCacheOpKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/ops/gcs_config_ops.cc b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc new file mode 100644 index 0000000000..9cf85f5f18 --- /dev/null +++ b/tensorflow/contrib/cloud/ops/gcs_config_ops.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("GcsConfigureCredentials") + .Input("json: string") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Configures the credentials used by the GCS client of the local TF runtime. + +The json input can be of the format: + +1. Refresh Token: +{ + "client_id": "<redacted>", + "client_secret": "<redacted>", + "refresh_token: "<redacted>", + "type": "authorized_user", +} + +2. Service Account: +{ + "type": "service_account", + "project_id": "<redacted>", + "private_key_id": "<redacted>", + "private_key": "------BEGIN PRIVATE KEY-----\n<REDACTED>\n-----END PRIVATE KEY------\n", + "client_email": "<REDACTED>@<REDACTED>.iam.gserviceaccount.com", + "client_id": "<REDACTED>", + # Some additional fields elided +} + +Note the credentials established through this method are shared across all +sessions run on this runtime. + +Note be sure to feed the inputs to this op to ensure the credentials are not +stored in a constant op within the graph that might accidentally be checkpointed +or in other ways be persisted or exfiltrated. +)doc"); + +REGISTER_OP("GcsConfigureBlockCache") + .Input("max_cache_size: uint64") + .Input("block_size: uint64") + .Input("max_staleness: uint64") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Re-configures the GCS block cache with the new configuration values. + +If the values are the same as already configured values, this op is a no-op. If +they are different, the current contents of the block cache is dropped, and a +new block cache is created fresh. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py new file mode 100644 index 0000000000..9ab124ae72 --- /dev/null +++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops.py @@ -0,0 +1,176 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""GCS file system configuration for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json + +from tensorflow.contrib.cloud.python.ops import gen_gcs_config_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training + + +# @tf_export('contrib.cloud.BlockCacheParams') +class BlockCacheParams(object): + """BlockCacheParams is a struct used for configuring the GCS Block Cache.""" + + def __init__(self, block_size=None, max_bytes=None, max_staleness=None): + self._block_size = block_size or 128 * 1024 * 1024 + self._max_bytes = max_bytes or 2 * self._block_size + self._max_staleness = max_staleness or 0 + + @property + def block_size(self): + return self._block_size + + @property + def max_bytes(self): + return self._max_bytes + + @property + def max_staleness(self): + return self._max_staleness + + +# @tf_export('contrib.cloud.ConfigureGcsHook') +class ConfigureGcsHook(training.SessionRunHook): + """ConfigureGcsHook configures GCS when used with Estimator/TPUEstimator. + + Example: + + ``` + sess = tf.Session() + refresh_token = raw_input("Refresh token: ") + client_secret = raw_input("Client secret: ") + client_id = "<REDACTED>" + creds = { + "client_id": client_id, + "refresh_token": refresh_token, + "client_secret": client_secret, + "type": "authorized_user", + } + tf.contrib.cloud.configure_gcs(sess, credentials=creds) + ``` + + """ + + def _verify_dictionary(self, creds_dict): + if 'refresh_token' in creds_dict or 'private_key' in creds_dict: + return True + return False + + def __init__(self, credentials=None, block_cache=None): + """Constructs a ConfigureGcsHook. + + Args: + credentials: A json-formatted string. + block_cache: A `BlockCacheParams` + + Raises: + ValueError: If credentials is improperly formatted or block_cache is not a + BlockCacheParams. + """ + if credentials is not None: + if isinstance(credentials, str): + try: + data = json.loads(credentials) + except ValueError as e: + raise ValueError('credentials was not a well formed JSON string.', e) + if not self._verify_dictionary(data): + raise ValueError( + 'credentials has neither a "refresh_token" nor a "private_key" ' + 'field.') + elif isinstance(credentials, dict): + if not self._verify_dictionary(credentials): + raise ValueError('credentials has neither a "refresh_token" nor a ' + '"private_key" field.') + credentials = json.dumps(credentials) + else: + raise ValueError('credentials is of an unknown type') + + self._credentials = credentials + + if block_cache and not isinstance(block_cache, BlockCacheParams): + raise ValueError('block_cache must be an instance of BlockCacheParams.') + self._block_cache = block_cache + + def begin(self): + if self._credentials: + self._credentials_placeholder = array_ops.placeholder(dtypes.string) + self._credentials_ops = gen_gcs_config_ops.gcs_configure_credentials( + self._credentials_placeholder) + if self._block_cache: + self._block_cache_op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=self._block_cache.max_bytes, + block_size=self._block_cache.block_size, + max_staleness=self._block_cache.max_staleness) + + def after_create_session(self, session, coord): + del coord + if self._credentials_op: + session.run( + self._credentials_op, + feed_dict={self._credentials_placeholder: self._credentials}) + if self._block_cache_op: + session.run(self._block_cache_op) + + +def configure_gcs(session, credentials=None, block_cache=None, device=None): + """Configures the GCS file system for a given a session. + + Args: + session: A `tf.Session` session that should be used to configure the GCS + file system. + credentials: [Optional.] A JSON string + block_cache: [Optional.] A BlockCacheParams to configure the block cache . + device: [Optional.] The device to place the configure ops. + """ + + def configure(credentials, block_cache): + """Helper function to actually configure GCS.""" + if credentials: + if isinstance(credentials, dict): + credentials = json.dumps(credentials) + placeholder = array_ops.placeholder(dtypes.string) + op = gen_gcs_config_ops.gcs_configure_credentials(placeholder) + session.run(op, feed_dict={placeholder: credentials}) + if block_cache: + op = gen_gcs_config_ops.gcs_configure_block_cache( + max_cache_size=block_cache.max_bytes, + block_size=block_cache.block_size, + max_staleness=block_cache.max_staleness) + session.run(op) + + if device: + with ops.device(device): + return configure(credentials, block_cache) + return configure(credentials, block_cache) + + +def configure_colab_session(session): + """ConfigureColabSession configures the GCS file system in Colab. + + Args: + session: A `tf.Session` session. + """ + # Read from the application default credentials (adc). + with open('/content/datalab/adc.json') as f: + data = json.load(f) + configure_gcs(session, credentials=data) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index e558691de4..bc753333db 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -113,6 +113,7 @@ GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_stats "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(gcs_config "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/gcs_config_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(reduce_slice_ops "${tensorflow_source_dir}/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8d24a7ae38..61651f3007 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -420,6 +420,8 @@ GENERATE_PYTHON_OP_LIB("contrib_text_skip_gram_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/text/python/ops/gen_skip_gram_ops.py) GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_gcs_config_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_gcs_config_ops.py) GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index 0fc1e4ae45..67651349ea 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -174,6 +174,7 @@ cc_library( "oauth_client.h", ], copts = tf_copts(), + visibility = ["//tensorflow:__subpackages__"], deps = [ ":curl_http_request", ":http_request", diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index dc12c78a4b..632bb32063 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -290,51 +290,24 @@ Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) { /// A GCS-based implementation of a random access file with an LRU block cache. class GcsRandomAccessFile : public RandomAccessFile { public: - using SignatureGenFun = - std::function<Status(const string& filename, int64* file_signature)>; + using ReadFn = + std::function<Status(const string& filename, uint64 offset, size_t n, + StringPiece* result, char* scratch)>; - GcsRandomAccessFile(const string& filename, FileBlockCache* file_block_cache, - const SignatureGenFun& signature_gen_fun) - : filename_(filename), - file_block_cache_(file_block_cache), - signature_gen_fun_(signature_gen_fun) {} + GcsRandomAccessFile(const string& filename, ReadFn read_fn) + : filename_(filename), read_fn_(std::move(read_fn)) {} /// The implementation of reads with an LRU block cache. Thread safe. Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { - if (file_block_cache_->IsCacheEnabled()) { - int64 signature; - TF_RETURN_IF_ERROR(signature_gen_fun_(filename_, &signature)); - if (!file_block_cache_->ValidateAndUpdateFileSignature(filename_, - signature)) { - VLOG(1) << "File " << filename_ - << " signature has been changed. Refreshing the cache."; - } - } - - *result = StringPiece(); - size_t bytes_transferred; - TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch, - &bytes_transferred)); - *result = StringPiece(scratch, bytes_transferred); - - if (bytes_transferred < n) { - // This is not an error per se. The RandomAccessFile interface expects - // that Read returns OutOfRange if fewer bytes were read than requested. - return errors::OutOfRange("EOF reached, ", result->size(), - " bytes were read out of ", n, - " bytes requested."); - } - return Status::OK(); + return read_fn_(filename_, offset, n, result, scratch); } private: /// The filename of this file. const string filename_; - /// The LRU block cache for this file. - mutable FileBlockCache* file_block_cache_; // not owned - - const SignatureGenFun signature_gen_fun_; + /// The implementation of the read operation (provided by the GCSFileSystem). + const ReadFn read_fn_; }; /// \brief GCS-based implementation of a writeable file. @@ -797,21 +770,50 @@ Status GcsFileSystem::NewRandomAccessFile( const string& fname, std::unique_ptr<RandomAccessFile>* result) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); - result->reset(new GcsRandomAccessFile( - fname, file_block_cache_.get(), - [this, bucket, object](const string& fname, int64* signature) { - GcsFileStat stat; - TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute( - fname, &stat, - [this, bucket, object](const string& fname, GcsFileStat* stat) { - return UncachedStatForObject(fname, bucket, object, stat); - })); - *signature = stat.generation_number; - return Status::OK(); - })); + result->reset(new GcsRandomAccessFile(fname, [this, bucket, object]( + const string& fname, + uint64 offset, size_t n, + StringPiece* result, + char* scratch) { + tf_shared_lock l(block_cache_lock_); + if (file_block_cache_->IsCacheEnabled()) { + GcsFileStat stat; + TF_RETURN_IF_ERROR(stat_cache_->LookupOrCompute( + fname, &stat, + [this, bucket, object](const string& fname, GcsFileStat* stat) { + return UncachedStatForObject(fname, bucket, object, stat); + })); + if (!file_block_cache_->ValidateAndUpdateFileSignature( + fname, stat.generation_number)) { + VLOG(1) + << "File signature has been changed. Refreshing the cache. Path: " + << fname; + } + } + *result = StringPiece(); + size_t bytes_transferred; + TF_RETURN_IF_ERROR( + file_block_cache_->Read(fname, offset, n, scratch, &bytes_transferred)); + *result = StringPiece(scratch, bytes_transferred); + if (bytes_transferred < n) { + return errors::OutOfRange("EOF reached, ", result->size(), + " bytes were read out of ", n, + " bytes requested."); + } + return Status::OK(); + })); return Status::OK(); } +void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes, + size_t max_bytes, + uint64 max_staleness_secs) { + mutex_lock l(block_cache_lock_); + file_block_cache_ = + MakeFileBlockCache(block_size_bytes, max_bytes, max_staleness_secs); + stats_->Configure(this, &throttle_, file_block_cache_.get()); +} + // A helper function to build a FileBlockCache for GcsFileSystem. std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache( size_t block_size, size_t max_bytes, uint64 max_staleness) { @@ -880,6 +882,7 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, } void GcsFileSystem::ClearFileCaches(const string& fname) { + tf_shared_lock l(block_cache_lock_); file_block_cache_->RemoveFile(fname); stat_cache_->Delete(fname); // TODO(rxsang): Remove the patterns that matche the file in @@ -1509,6 +1512,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname, // reclaiming memory once filesystem operations are done (e.g. model is loaded), // or for resetting the filesystem to a consistent state. void GcsFileSystem::FlushCaches() { + tf_shared_lock l(block_cache_lock_); file_block_cache_->Flush(); stat_cache_->Clear(); matching_paths_cache_->Clear(); @@ -1517,8 +1521,15 @@ void GcsFileSystem::FlushCaches() { void GcsFileSystem::SetStats(GcsStatsInterface* stats) { CHECK(stats_ == nullptr) << "SetStats() has already been called."; CHECK(stats != nullptr); + mutex_lock l(block_cache_lock_); stats_ = stats; - stats_->Init(this, &throttle_, file_block_cache_.get()); + stats_->Configure(this, &throttle_, file_block_cache_.get()); +} + +void GcsFileSystem::SetAuthProvider( + std::unique_ptr<AuthProvider> auth_provider) { + mutex_lock l(mu_); + auth_provider_ = std::move(auth_provider); } // Creates an HttpRequest and sets several parameters that are common to all @@ -1531,7 +1542,11 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr<HttpRequest>* request) { } string auth_token; - TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + { + tf_shared_lock l(mu_); + TF_RETURN_IF_ERROR( + AuthProvider::GetToken(auth_provider_.get(), &auth_token)); + } new_request->AddAuthBearerHeader(auth_token); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index d543db1577..74768c98b5 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -43,9 +43,12 @@ class GcsFileSystem; /// time. class GcsStatsInterface { public: - /// Init is called by the GcsFileSystem immediately after being registered. - virtual void Init(GcsFileSystem* fs, GcsThrottle* throttle, - const FileBlockCache* block_cache) = 0; + /// Configure is called by the GcsFileSystem to provide instrumentation hooks. + /// + /// Note: Configure can be called multiple times (e.g. if the block cache is + /// re-initialized). + virtual void Configure(GcsFileSystem* fs, GcsThrottle* throttle, + const FileBlockCache* block_cache) = 0; /// RecordBlockLoadRequest is called to record a block load request is about /// to be made. @@ -132,9 +135,18 @@ class GcsFileSystem : public FileSystem { /// These accessors are mainly for testing purposes, to verify that the /// environment variables that control these parameters are handled correctly. - size_t block_size() const { return file_block_cache_->block_size(); } - size_t max_bytes() const { return file_block_cache_->max_bytes(); } - uint64 max_staleness() const { return file_block_cache_->max_staleness(); } + size_t block_size() { + tf_shared_lock l(block_cache_lock_); + return file_block_cache_->block_size(); + } + size_t max_bytes() { + tf_shared_lock l(block_cache_lock_); + return file_block_cache_->max_bytes(); + } + uint64 max_staleness() { + tf_shared_lock l(block_cache_lock_); + return file_block_cache_->max_staleness(); + } TimeoutConfig timeouts() const { return timeouts_; } string additional_header_name() const { return additional_header_ ? additional_header_->first : ""; @@ -190,6 +202,21 @@ class GcsFileSystem : public FileSystem { Status CreateHttpRequest(std::unique_ptr<HttpRequest>* request); + /// \brief Sets a new AuthProvider on the GCS FileSystem. + /// + /// The new auth provider will be used for all subsequent requests. + void SetAuthProvider(std::unique_ptr<AuthProvider> auth_provider); + + /// \brief Resets the block cache and re-instantiates it with the new values. + /// + /// This method can be used to clear the existing block cache and/or to + /// re-configure the block cache for different values. + /// + /// Note: the existing block cache is not cleaned up until all existing files + /// have been closed. + void ResetFileBlockCache(size_t block_size_bytes, size_t max_bytes, + uint64 max_staleness_secs); + private: // GCS file statistics. struct GcsFileStat { @@ -246,9 +273,14 @@ class GcsFileSystem : public FileSystem { // Clear all the caches related to the file with name `filename`. void ClearFileCaches(const string& fname); - std::unique_ptr<AuthProvider> auth_provider_; + mutex mu_; + std::unique_ptr<AuthProvider> auth_provider_ GUARDED_BY(mu_); std::unique_ptr<HttpRequest::Factory> http_request_factory_; - std::unique_ptr<FileBlockCache> file_block_cache_; + // block_cache_lock_ protects the file_block_cache_ pointer (Note that + // FileBlockCache instances are themselves threadsafe). + mutex block_cache_lock_; + std::unique_ptr<FileBlockCache> file_block_cache_ + GUARDED_BY(block_cache_lock_); std::unique_ptr<GcsDnsCache> dns_cache_; GcsThrottle throttle_; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 3f73b238ad..6a28d9162f 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -2946,8 +2946,8 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { class TestGcsStats : public GcsStatsInterface { public: - void Init(GcsFileSystem* fs, GcsThrottle* throttle, - const FileBlockCache* block_cache) override { + void Configure(GcsFileSystem* fs, GcsThrottle* throttle, + const FileBlockCache* block_cache) override { CHECK(fs_ == nullptr); CHECK(throttle_ == nullptr); CHECK(block_cache_ == nullptr); |