aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-08-10 14:44:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 14:52:54 -0700
commit39b6df56193d6fc00b49634ba255ad24e52e9e90 (patch)
treef6cf78232e7296fc28c9217d67d58dfd0deed5c8
parent0b36ff79021b907f5447bfcbaa060dbdc2114c67 (diff)
Make FunctionLibraryDefinition thread-safe.
The eager runtime mutates the FunctionLibraryRuntime's FunctionLibraryDefinition, which is shared across threads; at the same time, OpKernels might read the FunctionLibraryDefinition. This is not thread-safe unless FunctionLibraryDefinition is thread-safe. This change makes FunctionLibraryDefinition, which is basically a map from function names to FunctionDefs, thread-safe. This is almost entirely accomplished by guarding the map with a mutex. There is however one complication: Find and RemoveFunction cannot be made thread-safe in a straightforward way (Find returns a raw pointer to a FunctionDef while Remove can delete the corresponding FunctionDef). In light of the fact that clients only ever call RemoveFunction when they in fact want to replace an existing function with a new one, we make the following modifications to FunctionLibraryDefinition's API: 1. A Contains method is added to check for the existence of a function. 2. A ReplaceFunction method is added. 3. RemoveFunction and RemoveGradient are made private. We also update clients of the FunctionLibraryDefinition to use Contains & ReplaceFunction instead of Find and RemoveFunction. PiperOrigin-RevId: 208271076
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc3
-rw-r--r--tensorflow/core/common_runtime/eager/context.h2
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc17
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h7
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc5
-rw-r--r--tensorflow/core/framework/function.cc65
-rw-r--r--tensorflow/core/framework/function.h79
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc3
-rw-r--r--tensorflow/python/eager/function_test.py56
9 files changed, 169 insertions, 68 deletions
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index fdd71c6a58..f150bf1819 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -1161,8 +1161,7 @@ Status Encapsulator::Subgraph::ReplaceFunctionDef(
strings::StrCat("replace_encapsulate_fdef_", name), fdef);
}
- TF_RETURN_IF_ERROR(library->RemoveFunction(name));
- TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ TF_RETURN_IF_ERROR(library->ReplaceFunction(name, fdef));
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index ebaf500bb3..21c5bdf8e9 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -134,8 +134,6 @@ class EagerContext {
Rendezvous* GetRendezvous() { return rendezvous_; }
- mutex* FunctionsMu() { return &functions_mu_; }
-
const tensorflow::DeviceMgr* local_device_mgr() const {
return (local_device_manager_ != nullptr) ? local_device_manager_.get()
: local_unowned_device_manager_;
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 8eaa6e4429..46065f399c 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -300,12 +300,6 @@ Status EagerLocalExecute(EagerOperation* op,
<< device->name();
}
kernel = new KernelAndDevice(ctx->GetRendezvous());
- // Knowledge of the implementation of Init (and in-turn
- // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
- // will be accessed, so grab on to the lock.
- // See WARNING comment in Execute (before kernel->Run) - would be nice to
- // rework to avoid this subtlety.
- tf_shared_lock l(*ctx->FunctionsMu());
auto* flr = ctx->func_lib(device);
if (flr == nullptr) {
@@ -646,15 +640,8 @@ Status EagerExecute(EagerContext* ctx, Device* device,
TF_RETURN_IF_ERROR(op_inputs[i]->Tensor(&input_tensor));
inputs[i] = *input_tensor;
}
- // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
- // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def.
- // But knowledge of the implementation
- // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
- // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
- // This is quite subtle. Re-work things to make this better? (Would it make
- // sense for FunctionLibraryRuntime to ensure thread-safe access to
- // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
- // for ops which are a part of functions.
+ // TODO(apassos) figure out how to record stats for ops which are a part of
+ // functions.
// TODO(agarwal): change Run to take vector of handles ?
ScopedStepContainer* container = ctx->StepContainer();
if (container == nullptr) {
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 751cf687b2..0ef419cbaa 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -49,13 +49,6 @@ class KernelAndDevice {
//
// The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice.
- //
- // TODO(ashankar): Figure out thread-safety concerns around
- // FunctionLibraryRuntime (in particular, how the underlying
- // FunctionLibraryDefinition might be mutated by another thread as new
- // functions are registered with it). Conservatively, thread-safe usage of
- // the FunctionLibraryRuntime is pushed on to the caller (see locking in
- // c_api.cc).
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out);
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index 9c9eacb5b5..c23b7d3699 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -643,10 +643,9 @@ Status GraphExecutionState::OptimizeGraph(
for (const FunctionDef& fdef : new_graph.library().function()) {
const string& func_name = fdef.signature().name();
- if ((*optimized_flib)->Find(func_name)) {
+ if ((*optimized_flib)->Contains(func_name)) {
VLOG(3) << "Replace function: name=" << func_name;
- TF_RETURN_IF_ERROR((*optimized_flib)->RemoveFunction(func_name));
- TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
+ TF_RETURN_IF_ERROR((*optimized_flib)->ReplaceFunction(func_name, fdef));
} else {
VLOG(3) << "Add new function: name=" << func_name;
TF_RETURN_IF_ERROR((*optimized_flib)->AddFunctionDef(fdef));
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 57bcc0f513..6b92e10d76 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -920,10 +920,12 @@ FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionLibraryDefinition::FunctionLibraryDefinition(
const FunctionLibraryDefinition& other)
- : default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
+ : default_registry_(other.default_registry_) {
+ tf_shared_lock l(other.mu_);
for (const auto& it : other.function_defs_) {
TF_CHECK_OK(AddFunctionDef(it.second->fdef));
}
+ func_grad_ = other.func_grad_;
}
FunctionLibraryDefinition::FunctionLibraryDefinition(
@@ -943,8 +945,19 @@ FunctionLibraryDefinition::FunctionLibraryDefinition(
FunctionLibraryDefinition::~FunctionLibraryDefinition() {}
-const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
- auto iter = function_defs_.find(name);
+bool FunctionLibraryDefinition::Contains(const string& func) const {
+ tf_shared_lock l(mu_);
+ return function_defs_.find(func) != function_defs_.end();
+}
+
+const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const {
+ tf_shared_lock l(mu_);
+ return FindHelper(func);
+}
+
+const FunctionDef* FunctionLibraryDefinition::FindHelper(
+ const string& func) const {
+ auto iter = function_defs_.find(func);
if (iter == function_defs_.end()) {
return nullptr;
} else {
@@ -953,6 +966,7 @@ const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const {
}
Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
+ mutex_lock l(mu_);
bool added;
return AddFunctionDefHelper(fdef, &added);
}
@@ -984,6 +998,7 @@ Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef,
}
Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) {
+ mutex_lock l(mu_);
bool added;
return AddGradientDefHelper(grad, &added);
}
@@ -1009,13 +1024,17 @@ Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad,
Status FunctionLibraryDefinition::AddLibrary(
const FunctionLibraryDefinition& other) {
+ // Clone `other` to ensure thread-safety (grabbing `other`'s lock for
+ // the duration of the function could lead to deadlock).
+ FunctionLibraryDefinition clone(other);
+ mutex_lock l(mu_);
// Remember the funcs and grads that we added successfully so that
// we can roll them back on error.
std::vector<string> funcs;
std::vector<string> funcs_with_grads;
Status s;
bool added;
- for (auto iter : other.function_defs_) {
+ for (auto iter : clone.function_defs_) {
s = AddFunctionDefHelper(iter.second->fdef, &added);
if (!s.ok()) {
Remove(funcs, funcs_with_grads);
@@ -1025,7 +1044,7 @@ Status FunctionLibraryDefinition::AddLibrary(
funcs.push_back(iter.second->fdef.signature().name());
}
}
- for (auto iter : other.func_grad_) {
+ for (auto iter : clone.func_grad_) {
GradientDef grad;
grad.set_function_name(iter.first);
grad.set_gradient_func(iter.second);
@@ -1045,6 +1064,7 @@ Status FunctionLibraryDefinition::AddLibrary(
const FunctionDefLibrary& lib_def) {
// Remember the funcs and grads that we added successfully so that
// we can roll them back on error.
+ mutex_lock l(mu_);
std::vector<string> funcs;
std::vector<string> funcs_with_grads;
Status s;
@@ -1072,6 +1092,15 @@ Status FunctionLibraryDefinition::AddLibrary(
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
+ const FunctionDef& fdef) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveFunction(func));
+ TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
@@ -1106,11 +1135,17 @@ void FunctionLibraryDefinition::Remove(
}
string FunctionLibraryDefinition::FindGradient(const string& func) const {
+ tf_shared_lock l(mu_);
+ return gtl::FindWithDefault(func_grad_, func, "");
+}
+
+string FunctionLibraryDefinition::FindGradientHelper(const string& func) const {
return gtl::FindWithDefault(func_grad_, func, "");
}
Status FunctionLibraryDefinition::LookUp(
const string& op, const OpRegistrationData** op_reg_data) const {
+ tf_shared_lock l(mu_);
auto iter = function_defs_.find(op);
if (iter != function_defs_.end()) {
*op_reg_data = &iter->second->op_registration_data;
@@ -1134,18 +1169,22 @@ const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
return nullptr;
}
const string& func_name = forward_func_attrs->name();
- const string& grad_name = FindGradient(func_name);
- // If 'func' has a user-defined gradient function, uses the grad
- // function's attrs to see if noinline is specified. Otherwise,
- // uses func's attrs.
- if (!grad_name.empty()) {
- return Find(grad_name);
- }
- return Find(func_name);
+ {
+ tf_shared_lock l(mu_);
+ const string& grad_name = FindGradientHelper(func_name);
+ // If 'func' has a user-defined gradient function, uses the grad
+ // function's attrs to see if noinline is specified. Otherwise,
+ // uses func's attrs.
+ if (!grad_name.empty()) {
+ return FindHelper(grad_name);
+ }
+ return FindHelper(func_name);
+ }
}
FunctionDefLibrary FunctionLibraryDefinition::ToProto() const {
FunctionDefLibrary lib;
+ tf_shared_lock l(mu_);
for (const auto& f : function_defs_) {
*lib.add_function() = f.second->fdef;
}
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index 31a816ac5f..c81f4a4450 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -288,8 +289,11 @@ class FunctionCallFrame : public CallFrameInterface {
// Helper to maintain a map between function names in a given
// FunctionDefLibrary and function definitions.
+//
+// This class is thread-safe.
class FunctionLibraryDefinition : public OpRegistryInterface {
public:
+ // Note: This constructor grabs `lib_def`'s lock in shared mode.
explicit FunctionLibraryDefinition(const FunctionLibraryDefinition& lib_def);
FunctionLibraryDefinition(const OpRegistryInterface* default_registry,
const FunctionDefLibrary& lib_def);
@@ -298,9 +302,15 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
FunctionLibraryDefinition& operator=(const FunctionLibraryDefinition&) =
delete;
+ // Returns True if the library contains `func`, False otherwise.
+ bool Contains(const string& func) const;
+
// Returns nullptr if "func" is not defined in "lib_def". Otherwise,
// returns its definition proto.
- const FunctionDef* Find(const string& func) const;
+ //
+ // NB: This function returns a borrowed pointer, which can be invalidated by a
+ // subsequent call to `ReplaceFunction()` with the given name.
+ const FunctionDef* Find(const string& func) const LOCKS_EXCLUDED(mu_);
// Adds function definition 'fdef' to this function library.
// Returns status 'ok' on success, or error otherwise. This is a no-op if
@@ -308,45 +318,45 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// If 'fdef' is successfully added to the library, it will be accessible
// from 'LookUp' and included in the proto returned by 'ToProto'.
// This operation is atomic.
- Status AddFunctionDef(const FunctionDef& fdef);
+ Status AddFunctionDef(const FunctionDef& fdef) LOCKS_EXCLUDED(mu_);
// Adds gradient definition 'grad' to this function library.
// This is a no-op if 'grad' already exists in this function library.
// If 'grad' is successfully added, it will be accessible via 'FindGradient'
// and included in the proto returned by 'ToProto'.
// This operation is atomic.
- Status AddGradientDef(const GradientDef& grad);
+ Status AddGradientDef(const GradientDef& grad) LOCKS_EXCLUDED(mu_);
- // Remove function `func` from the library. Returns non-OK Status unless
- // `func` is in the library.
- Status RemoveFunction(const string& func);
-
- // Remove gradient of function `func` from the library. Returns non-OK Status
- // unless `func` has a gradient.
- Status RemoveGradient(const string& func);
+ // Replaces the function corresponding to `func` with `fdef`. Returns
+ // a non-OK status if "func" was not found in the library, OK otherwise.
+ Status ReplaceFunction(const string& func, const FunctionDef& fdef);
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
- Status AddLibrary(const FunctionLibraryDefinition& other);
+ Status AddLibrary(const FunctionLibraryDefinition& other) LOCKS_EXCLUDED(mu_);
// Adds the functions and gradients in 'lib_def' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
- Status AddLibrary(const FunctionDefLibrary& lib_def);
+ Status AddLibrary(const FunctionDefLibrary& lib_def) LOCKS_EXCLUDED(mu_);
// If the gradient function for 'func' is specified explicitly in
// the library, returns the gradient function name. Otherwise,
// returns an empty string.
- string FindGradient(const string& func) const;
+ string FindGradient(const string& func) const LOCKS_EXCLUDED(mu_);
// OpRegistryInterface method. Useful for constructing a Graph.
//
// If "op" is defined in the library, returns its signature.
// Otherwise, assume "op" is a primitive op and returns its op
// signature and shape inference function.
+ //
+ // NB: This function outputs a borrowed pointer, which can be invalidated by a
+ // subsequent call to `ReplaceFunction()` with the given name.
Status LookUp(const string& op_type_name,
- const OpRegistrationData** op_reg_data) const override;
+ const OpRegistrationData** op_reg_data) const override
+ LOCKS_EXCLUDED(mu_);
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
@@ -370,9 +380,12 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
Status GetAttr(const Node& node, const string& attr, T* value) const;
// Returns a proto representation of the state of this function library.
- FunctionDefLibrary ToProto() const;
+ FunctionDefLibrary ToProto() const LOCKS_EXCLUDED(mu_);
- size_t num_functions() const { return function_defs_.size(); }
+ size_t num_functions() const {
+ tf_shared_lock l(mu_);
+ return function_defs_.size();
+ }
const OpRegistryInterface* default_registry() const {
return default_registry_;
@@ -388,24 +401,42 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
OpRegistrationData op_registration_data;
};
+ const FunctionDef* FindHelper(const string& func) const
+ SHARED_LOCKS_REQUIRED(mu_);
+ string FindGradientHelper(const string& func) const
+ SHARED_LOCKS_REQUIRED(mu_);
+
// Same as AddFunctionDef/AddGradientDef except these methods set
// `added` to true if the `fdef`/`grad` were actually added to this.
- Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added);
- Status AddGradientDefHelper(const GradientDef& grad, bool* added);
+ Status AddFunctionDefHelper(const FunctionDef& fdef, bool* added)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status AddGradientDefHelper(const GradientDef& grad, bool* added)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ mutable mutex mu_;
const OpRegistryInterface* const default_registry_;
gtl::FlatMap<string, std::unique_ptr<FunctionDefAndOpRegistration>>
- function_defs_;
- gtl::FlatMap<string, string> func_grad_;
+ function_defs_ GUARDED_BY(mu_);
+ gtl::FlatMap<string, string> func_grad_ GUARDED_BY(mu_);
// Helper function for GetAttr. Returns the FunctionDef* to get the
// attr from.
- const FunctionDef* GetAttrImpl(const NodeDef& ndef) const;
+ const FunctionDef* GetAttrImpl(const NodeDef& ndef) const LOCKS_EXCLUDED(mu_);
- // Remove all functions in `funcs` and all gradients of
- // functions in `funcs_with_grads` from this library.
+ // Remove all functions in `funcs` and all gradients of functions in
+ // `funcs_with_grads` from this library.
void Remove(const std::vector<string>& funcs,
- const std::vector<string>& funcs_with_grads);
+ const std::vector<string>& funcs_with_grads)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Remove `func` from the library. Returns non-OK Status unless `func` is in
+ // the library. This should only be called when there is a guarantee that the
+ // function being removed hasn't been retrieved with `Find`.
+ Status RemoveFunction(const string& func) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Remove gradient of function `func` from the library. Returns non-OK Status
+ // unless `func` has a gradient.
+ Status RemoveGradient(const string& func) EXCLUSIVE_LOCKS_REQUIRED(mu_);
};
// Forward declare. Defined in common_runtime/function.h
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e42a7807e4..e778b7879d 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -383,8 +383,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
// Replace optimized function with a new FunctionDef.
- TF_RETURN_IF_ERROR(flib.RemoveFunction(func_name));
- TF_RETURN_IF_ERROR(flib.AddFunctionDef(optimized_func));
+ TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
}
// If optimized at least one function, update the graph library.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 7ca1131fda..0488dc9752 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import collections
import functools
+from multiprocessing.pool import ThreadPool
import sys
from tensorflow.core.protobuf import config_pb2
@@ -143,6 +144,61 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def testExecutingStatelessDefunConcurrently(self):
+
+ @function.defun
+ def stateless(x):
+ return math_ops.multiply(2.0, x)
+
+ pool = ThreadPool()
+ inputs = [constant_op.constant(1.0 * x) for x in range(100)]
+ outputs = [float(out) for out in pool.map(stateless, inputs)]
+ expected = [float(2.0 * x) for x in inputs]
+ self.assertSequenceEqual(outputs, expected)
+
+ def testExecutingManyStatelessDefunsConcurrently(self):
+
+ @function.defun
+ def stateless(x):
+ del x
+ return math_ops.multiply(2.0, 2.0)
+
+ pool = ThreadPool()
+ # `pool.map` below instantiates 100 functions, one for each object.
+ outputs = [
+ float(out)
+ for out in pool.map(stateless, [object() for _ in range(100)])
+ ]
+ expected = [4.0] * 100
+ self.assertSequenceEqual(outputs, expected)
+
+ def testExecutingStatefulDefunConcurrently(self):
+
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def stateful(x):
+ v.assign(x)
+
+ pool = ThreadPool()
+ inputs = [constant_op.constant(0.0)] * 100
+ pool.map(stateful, inputs)
+ self.assertEqual(float(v.read_value()), 0.0)
+
+ def testExecutingManyStatefulDefunsConcurrently(self):
+
+ v = resource_variable_ops.ResourceVariable(1.0)
+
+ @function.defun
+ def stateful(x):
+ del x
+ return v.assign(0.0)
+
+ pool = ThreadPool()
+ # `pool.map` below instantiates 100 functions, one for each object.
+ pool.map(stateful, [object() for _ in range(100)])
+ self.assertEqual(float(v.read_value()), 0.0)
+
def disabled_testRandomSeed(self):
@function.defun