diff options
author | Akshay Agrawal <akshayka@google.com> | 2018-08-10 14:44:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-10 14:52:54 -0700 |
commit | 39b6df56193d6fc00b49634ba255ad24e52e9e90 (patch) | |
tree | f6cf78232e7296fc28c9217d67d58dfd0deed5c8 | |
parent | 0b36ff79021b907f5447bfcbaa060dbdc2114c67 (diff) |
Make FunctionLibraryDefinition thread-safe.
The eager runtime mutates the FunctionLibraryRuntime's FunctionLibraryDefinition, which is shared across threads; at the same time, OpKernels might read the FunctionLibraryDefinition. This is not thread-safe unless FunctionLibraryDefinition is thread-safe.
This change makes FunctionLibraryDefinition, which is basically a map from function names to FunctionDefs, thread-safe. This is almost entirely accomplished by guarding the map with a mutex. There is, however, one complication: Find and RemoveFunction cannot be made thread-safe in a straightforward way (Find returns a raw pointer to a FunctionDef while RemoveFunction can delete the corresponding FunctionDef). In light of the fact that clients only ever call RemoveFunction when they in fact want to replace an existing function with a new one, we make the following modifications to FunctionLibraryDefinition's API:
1. A Contains method is added to check for the existence of a function.
2. A ReplaceFunction method is added.
3. RemoveFunction and RemoveGradient are made private.
We also update clients of the FunctionLibraryDefinition to use Contains and ReplaceFunction instead of Find and RemoveFunction.
PiperOrigin-RevId: 208271076
-rw-r--r-- | tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc | 3 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/context.h | 2 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/execute.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/eager/kernel_and_device.h | 7 | ||||
-rw-r--r-- | tensorflow/core/common_runtime/graph_execution_state.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/framework/function.cc | 65 | ||||
-rw-r--r-- | tensorflow/core/framework/function.h | 79 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/meta_optimizer.cc | 3 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 56 |
9 files changed, 169 insertions, 68 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index fdd71c6a58..f150bf1819 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -1161,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef( strings::StrCat("replace_encapsulate_fdef_", name), fdef); } - TF_RETURN_IF_ERROR(library->RemoveFunction(name)); - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef)); return Status::OK(); } diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index ebaf500bb3..21c5bdf8e9 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -134,8 +134,6 @@ class EagerContext { Rendezvous* GetRendezvous() { return rendezvous_; } - mutex* FunctionsMu() { return &functions_mu_; } - const tensorflow::DeviceMgr* local_device_mgr() const { return (local_device_manager_ != nullptr) ? local_device_manager_.get() : local_unowned_device_manager_; diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 8eaa6e4429..46065f399c 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -300,12 +300,6 @@ Status EagerLocalExecute(EagerOperation* op, << device->name(); } kernel = new KernelAndDevice(ctx->GetRendezvous()); - // Knowledge of the implementation of Init (and in-turn - // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def - // will be accessed, so grab on to the lock. - // See WARNING comment in Execute (before kernel->Run) - would be nice to - // rework to avoid this subtlety. 
- tf_shared_lock l(*ctx->FunctionsMu()); auto* flr = ctx->func_lib(device); if (flr == nullptr) { @@ -646,15 +640,8 @@ Status EagerExecute(EagerContext* ctx, Device* device, TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor)); inputs[i] = *input_tensor; } - // WARNING: kernel->Run utilizes the FunctionLibraryRuntime - // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def. - // But knowledge of the implementation - // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by - // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. - // This is quite subtle. Re-work things to make this better? (Would it make - // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats - // for ops which are a part of functions. + // TODO(apassos) figure out how to record stats for ops which are a part of + // functions. // TODO(agarwal): change Run to take vector of handles ? ScopedStepContainer* container = ctx->StepContainer(); if (container == nullptr) { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 751cf687b2..0ef419cbaa 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -49,13 +49,6 @@ class KernelAndDevice { // // The provided FunctionLibraryRuntime MUST outlive all calls to // Run() on the returned KernelAndDevice. - // - // TODO(ashankar): Figure out thread-safety concerns around - // FunctionLibraryRuntime (in particular, how the underlying - // FunctionLibraryDefinition might be mutated by another thread as new - // functions are registered with it). Conservatively, thread-safe usage of - // the FunctionLibraryRuntime is pushed on to the caller (see locking in - // c_api.cc). 
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, std::function<void(std::function<void()>)>* runner, KernelAndDevice* out); diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 9c9eacb5b5..c23b7d3699 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -643,10 +643,9 @@ Status GraphExecutionState::OptimizeGraph( for (const FunctionDef& fdef : new_graph.library().function()) { const string& func_name = fdef.signature().name(); - if ((*optimized_flib)->Find(func_name)) { + if ((*optimized_flib)->Contains(func_name)) { VLOG(3) << "Replace function: name=" << func_name; - TF_RETURN_IF_ERROR((*optimized_flib)->RemoveFunction(func_name)); - TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef)); + TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef)); } else { VLOG(3) << "Add new function: name=" << func_name; TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef)); diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 57bcc0f513..6b92e10d76 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -920,10 +920,12 @@ FunctionLibraryDefinition::FunctionDefAndOpRegistration:: FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionLibraryDefinition& other) - : default_registry_(other.default_registry_), func_grad_(other.func_grad_) { + : default_registry_(other.default_registry_) { + tf_shared_lock l(other.mu_); for (const auto& it : other.function_defs_) { TF_CHECK_OK(AddFunctionDef(it.second->fdef)); } + func_grad_ = other.func_grad_; } FunctionLibraryDefinition::FunctionLibraryDefinition( @@ -943,8 +945,19 @@ FunctionLibraryDefinition::FunctionLibraryDefinition( FunctionLibraryDefinition::~FunctionLibraryDefinition() {} -const FunctionDef* 
FunctionLibraryDefinition::Find(const string& name) const { - auto iter = function_defs_.find(name); +bool FunctionLibraryDefinition::Contains(const string& func) const { + tf_shared_lock l(mu_); + return function_defs_.find(func) != function_defs_.end(); +} + +const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const { + tf_shared_lock l(mu_); + return FindHelper(func); +} + +const FunctionDef* FunctionLibraryDefinition::FindHelper( + const string& func) const { + auto iter = function_defs_.find(func); if (iter == function_defs_.end()) { return nullptr; } else { @@ -953,6 +966,7 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { } Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { + mutex_lock l(mu_); bool added; return AddFunctionDefHelper(fdef, &added); } @@ -984,6 +998,7 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef, } Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { + mutex_lock l(mu_); bool added; return AddGradientDefHelper(grad, &added); } @@ -1009,13 +1024,17 @@ Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, Status FunctionLibraryDefinition::AddLibrary( const FunctionLibraryDefinition& other) { + // Clone `other` to ensure thread-safety (grabbing `other`'s lock for + // the duration of the function could lead to deadlock). + FunctionLibraryDefinition clone(other); + mutex_lock l(mu_); // Remember the funcs and grads that we added successfully so that // we can roll them back on error. 
std::vector<string> funcs; std::vector<string> funcs_with_grads; Status s; bool added; - for (auto iter : other.function_defs_) { + for (auto iter : clone.function_defs_) { s = AddFunctionDefHelper(iter.second->fdef, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); @@ -1025,7 +1044,7 @@ Status FunctionLibraryDefinition::AddLibrary( funcs.push_back(iter.second->fdef.signature().name()); } } - for (auto iter : other.func_grad_) { + for (auto iter : clone.func_grad_) { GradientDef grad; grad.set_function_name(iter.first); grad.set_gradient_func(iter.second); @@ -1045,6 +1064,7 @@ Status FunctionLibraryDefinition::AddLibrary( const FunctionDefLibrary& lib_def) { // Remember the funcs and grads that we added successfully so that // we can roll them back on error. + mutex_lock l(mu_); std::vector<string> funcs; std::vector<string> funcs_with_grads; Status s; @@ -1072,6 +1092,15 @@ Status FunctionLibraryDefinition::AddLibrary( return Status::OK(); } +Status FunctionLibraryDefinition::ReplaceFunction(const string& func, + const FunctionDef& fdef) { + mutex_lock l(mu_); + bool added; + TF_RETURN_IF_ERROR(RemoveFunction(func)); + TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added)); + return Status::OK(); +} + Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); if (i == function_defs_.end()) { @@ -1106,11 +1135,17 @@ void FunctionLibraryDefinition::Remove( } string FunctionLibraryDefinition::FindGradient(const string& func) const { + tf_shared_lock l(mu_); + return gtl::FindWithDefault(func_grad_, func, ""); +} + +string FunctionLibraryDefinition::FindGradientHelper(const string& func) const { return gtl::FindWithDefault(func_grad_, func, ""); } Status FunctionLibraryDefinition::LookUp( const string& op, const OpRegistrationData** op_reg_data) const { + tf_shared_lock l(mu_); auto iter = function_defs_.find(op); if (iter != function_defs_.end()) { *op_reg_data = &iter->second->op_registration_data; @@ 
-1134,18 +1169,22 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( return nullptr; } const string& func_name = forward_func_attrs->name(); - const string& grad_name = FindGradient(func_name); - // If 'func' has a user-defined gradient function, uses the grad - // function's attrs to see if noinline is specified. Otherwise, - // uses func's attrs. - if (!grad_name.empty()) { - return Find(grad_name); - } - return Find(func_name); + { + tf_shared_lock l(mu_); + const string& grad_name = FindGradientHelper(func_name); + // If 'func' has a user-defined gradient function, uses the grad + // function's attrs to see if noinline is specified. Otherwise, + // uses func's attrs. + if (!grad_name.empty()) { + return FindHelper(grad_name); + } + return FindHelper(func_name); + } } FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { FunctionDefLibrary lib; + tf_shared_lock l(mu_); for (const auto& f : function_defs_) { *lib.add_function() = f.second->fdef; } diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 31a816ac5f..c81f4a4450 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -288,8 +289,11 @@ class FunctionCallFrame : public CallFrameInterface { // Helper to maintain a map between function names in a given // FunctionDefLibrary and function definitions. +// +// This class is thread-safe. class FunctionLibraryDefinition : public OpRegistryInterface { public: + // Note: This constructor grabs `lib_def`'s lock in shared mode. 
explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def); FunctionLibraryDefinition(const OpRegistryInterface* default_registry, const FunctionDefLibrary& lib_def); @@ -298,9 +302,15 @@ class FunctionLibraryDefinition : public OpRegistryInterface { FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) = delete; + // Returns True if the library contains `func`, False otherwise. + bool Contains(const string& func) const; + // Returns nullptr if "func" is not defined in "lib_def". Otherwise, // returns its definition proto. - const FunctionDef* Find(const string& func) const; + // + // NB: This function returns a borrowed pointer, which can be invalidated by a + // subsequent call to `ReplaceFunction()` with the given name. + const FunctionDef* Find(const string& func) const LOCKS_EXCLUDED(mu_); // Adds function definition 'fdef' to this function library. // Returns status 'ok' on success, or error otherwise. This is a no-op if @@ -308,45 +318,45 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // If 'fdef' is successfully added to the library, it will be accessible // from 'LookUp' and included in the proto returned by 'ToProto'. // This operation is atomic. - Status AddFunctionDef(const FunctionDef& fdef); + Status AddFunctionDef(const FunctionDef& fdef) LOCKS_EXCLUDED(mu_); // Adds gradient definition 'grad' to this function library. // This is a no-op if 'grad' already exists in this function library. // If 'grad' is successfully added, it will be accessible via 'FindGradient' // and included in the proto returned by 'ToProto'. // This operation is atomic. - Status AddGradientDef(const GradientDef& grad); + Status AddGradientDef(const GradientDef& grad) LOCKS_EXCLUDED(mu_); - // Remove function `func` from the library. Returns non-OK Status unless - // `func` is in the library. - Status RemoveFunction(const string& func); - - // Remove gradient of function `func` from the library. 
Returns non-OK Status - // unless `func` has a gradient. - Status RemoveGradient(const string& func); + // Replaces the function corresponding to `func` with `fdef`. Returns + // a non-OK status if "func" was not found in the library, OK otherwise. + Status ReplaceFunction(const string& func, const FunctionDef& fdef); // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. - Status AddLibrary(const FunctionLibraryDefinition& other); + Status AddLibrary(const FunctionLibraryDefinition& other) LOCKS_EXCLUDED(mu_); // Adds the functions and gradients in 'lib_def' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. - Status AddLibrary(const FunctionDefLibrary& lib_def); + Status AddLibrary(const FunctionDefLibrary& lib_def) LOCKS_EXCLUDED(mu_); // If the gradient function for 'func' is specified explicitly in // the library, returns the gradient function name. Otherwise, // returns an empty string. - string FindGradient(const string& func) const; + string FindGradient(const string& func) const LOCKS_EXCLUDED(mu_); // OpRegistryInterface method. Useful for constructing a Graph. // // If "op" is defined in the library, returns its signature. // Otherwise, assume "op" is a primitive op and returns its op // signature and shape inference function. + // + // NB: This function outputs a borrowed pointer, which can be invalidated by a + // subsequent call to `ReplaceFunction()` with the given name. Status LookUp(const string& op_type_name, - const OpRegistrationData** op_reg_data) const override; + const OpRegistrationData** op_reg_data) const override + LOCKS_EXCLUDED(mu_); // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. 
@@ -370,9 +380,12 @@ class FunctionLibraryDefinition : public OpRegistryInterface { Status GetAttr(const Node& node, const string& attr, T* value) const; // Returns a proto representation of the state of this function library. - FunctionDefLibrary ToProto() const; + FunctionDefLibrary ToProto() const LOCKS_EXCLUDED(mu_); - size_t num_functions() const { return function_defs_.size(); } + size_t num_functions() const { + tf_shared_lock l(mu_); + return function_defs_.size(); + } const OpRegistryInterface* default_registry() const { return default_registry_; @@ -388,24 +401,42 @@ class FunctionLibraryDefinition : public OpRegistryInterface { OpRegistrationData op_registration_data; }; + const FunctionDef* FindHelper(const string& func) const + SHARED_LOCKS_REQUIRED(mu_); + string FindGradientHelper(const string& func) const + SHARED_LOCKS_REQUIRED(mu_); + // Same as AddFunctionDef/AddGradientDef except these methods set // `added` to true if the `fdef`/`grad` were actually added to this. - Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added); - Status AddGradientDefHelper(const GradientDef& grad, bool* added); + Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + Status AddGradientDefHelper(const GradientDef& grad, bool* added) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + mutable mutex mu_; const OpRegistryInterface* const default_registry_; gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>> - function_defs_; - gtl::FlatMap<string, string> func_grad_; + function_defs_ GUARDED_BY(mu_); + gtl::FlatMap<string, string> func_grad_ GUARDED_BY(mu_); // Helper function for GetAttr. Returns the FunctionDef* to get the // attr from. - const FunctionDef* GetAttrImpl(const NodeDef& ndef) const; + const FunctionDef* GetAttrImpl(const NodeDef& ndef) const LOCKS_EXCLUDED(mu_); - // Remove all functions in `funcs` and all gradients of - // functions in `funcs_with_grads` from this library. 
+ // Remove all functions in `funcs` and all gradients of functions in + // `funcs_with_grads` from this library. void Remove(const std::vector<string>& funcs, - const std::vector<string>& funcs_with_grads); + const std::vector<string>& funcs_with_grads) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Remove `func` from the library. Returns non-OK Status unless `func` is in + // the library. This should only be called when there is a guarantee that the + // function being removed hasn't been retrieved with `Find`. + Status RemoveFunction(const string& func) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Remove gradient of function `func` from the library. Returns non-OK Status + // unless `func` has a gradient. + Status RemoveGradient(const string& func) EXCLUSIVE_LOCKS_REQUIRED(mu_); }; // Forward declare. Defined in common_runtime/function.h diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index e42a7807e4..e778b7879d 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -383,8 +383,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func)); // Replace optimized function with a new FunctionDef. - TF_RETURN_IF_ERROR(flib.RemoveFunction(func_name)); - TF_RETURN_IF_ERROR(flib.AddFunctionDef(optimized_func)); + TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func)); } // If optimized at least one function, update the graph library. 
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 7ca1131fda..0488dc9752 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import collections import functools +from multiprocessing.pool import ThreadPool import sys from tensorflow.core.protobuf import config_pb2 @@ -143,6 +144,61 @@ class FunctionTest(test.TestCase): out = sq_op(t) self.assertAllEqual(out, math_ops.matmul(t, t).numpy()) + def testExecutingStatelessDefunConcurrently(self): + + @function.defun + def stateless(x): + return math_ops.multiply(2.0, x) + + pool = ThreadPool() + inputs = [constant_op.constant(1.0 * x) for x in range(100)] + outputs = [float(out) for out in pool.map(stateless, inputs)] + expected = [float(2.0 * x) for x in inputs] + self.assertSequenceEqual(outputs, expected) + + def testExecutingManyStatelessDefunsConcurrently(self): + + @function.defun + def stateless(x): + del x + return math_ops.multiply(2.0, 2.0) + + pool = ThreadPool() + # `pool.map` below instantiates 100 functions, one for each object. + outputs = [ + float(out) + for out in pool.map(stateless, [object() for _ in range(100)]) + ] + expected = [4.0] * 100 + self.assertSequenceEqual(outputs, expected) + + def testExecutingStatefulDefunConcurrently(self): + + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def stateful(x): + v.assign(x) + + pool = ThreadPool() + inputs = [constant_op.constant(0.0)] * 100 + pool.map(stateful, inputs) + self.assertEqual(float(v.read_value()), 0.0) + + def testExecutingManyStatefulDefunsConcurrently(self): + + v = resource_variable_ops.ResourceVariable(1.0) + + @function.defun + def stateful(x): + del x + return v.assign(0.0) + + pool = ThreadPool() + # `pool.map` below instantiates 100 functions, one for each object. 
+ pool.map(stateful, [object() for _ in range(100)]) + self.assertEqual(float(v.read_value()), 0.0) + def disabled_testRandomSeed(self): @function.defun |