aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mingsheng Hong <hongm@google.com>2018-08-01 14:27:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 14:31:24 -0700
commit83d84a3c15f0ab315598ef13468a0ca03ce7d6f8 (patch)
tree0c7a1cccf4c39cd332dd9a127fb54a6294ac0073
parentd07564b2a4cb84b81e32bad7088ba5aec4aca630 (diff)
Simplified the WhileOp kernel impl by removing the handle cache based on
suggestion from apassos@ -- the underlying lib->Instantiate() does the caching. PiperOrigin-RevId: 206993242
-rw-r--r--tensorflow/core/kernels/functional_ops.cc35
1 files changed, 9 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 36607b4d7b..1c0abf26cd 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -232,30 +232,17 @@ class WhileOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library"), done);
- // TODO(b/37549631): Because this op has `SetIsStateful()` in its
- // op registration, this kernel may be shared by multiple
- // subgraphs, which have different associated
- // `FunctionLibraryRuntime` objects and hence different `FHandle`
- // namespaces. We currently work around this by caching the map
- // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
- // functions this op uses.
+ // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
+ // registration, this kernel may be shared by multiple subgraphs, which have
+ // different associated `FunctionLibraryRuntime` objects and hence different
+ // `FHandle` namespaces. So we must call Instantiate() to make sure we get
+ // the correct function handles with respect to `lib`. Note the underlying
+ // `lib->Instantiate()` caches the created function handles, so calling
+ // `Instantiate()` repeatedly on the same `lib` and function is cheap.
FHandle cond_handle;
FHandle body_handle;
- {
- mutex_lock l(mu_);
- const auto iter = handles_.find(lib);
- if (iter == handles_.end()) {
- OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle),
- done);
- OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle),
- done);
- handles_[lib] = {cond_handle, body_handle};
- } else {
- cond_handle = iter->second.first;
- body_handle = iter->second.second;
- }
- }
-
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
(new State(this, ctx, cond_handle, body_handle, done))->Start();
}
@@ -263,10 +250,6 @@ class WhileOp : public AsyncOpKernel {
NameAttrList cond_func_;
NameAttrList body_func_;
- mutex mu_;
- std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
- handles_ GUARDED_BY(mu_);
-
class State {
public:
State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,