diff options
author | 2018-08-01 14:27:54 -0700 | |
---|---|---|
committer | 2018-08-01 14:31:24 -0700 | |
commit | 83d84a3c15f0ab315598ef13468a0ca03ce7d6f8 (patch) | |
tree | 0c7a1cccf4c39cd332dd9a127fb54a6294ac0073 | |
parent | d07564b2a4cb84b81e32bad7088ba5aec4aca630 (diff) |
Simplified the WhileOp kernel impl by removing the handle cache based on
suggestion from apassos@ -- the underlying lib->Instantiate() does the caching.
PiperOrigin-RevId: 206993242
-rw-r--r-- | tensorflow/core/kernels/functional_ops.cc | 35 |
1 files changed, 9 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 36607b4d7b..1c0abf26cd 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -232,30 +232,17 @@ class WhileOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - // TODO(b/37549631): Because this op has `SetIsStateful()` in its - // op registration, this kernel may be shared by multiple - // subgraphs, which have different associated - // `FunctionLibraryRuntime` objects and hence different `FHandle` - // namespaces. We currently work around this by caching the map - // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two - // functions this op uses. + // TODO(b/37549631): Because this op has `SetIsStateful()` in its op + // registration, this kernel may be shared by multiple subgraphs, which have + // different associated `FunctionLibraryRuntime` objects and hence different + // `FHandle` namespaces. So we must call Instantiate() to make sure we get + // the correct function handles with respect to `lib`. Note the underlying + // `lib->Instantiate()` caches the created function handles, so calling + // `Instantiate()` repeatedly on the same `lib` and function is cheap. FHandle cond_handle; FHandle body_handle; - { - mutex_lock l(mu_); - const auto iter = handles_.find(lib); - if (iter == handles_.end()) { - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), - done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), - done); - handles_[lib] = {cond_handle, body_handle}; - } else { - cond_handle = iter->second.first; - body_handle = iter->second.second; - } - } - + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -263,10 +250,6 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; - mutex mu_; - std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> - handles_ GUARDED_BY(mu_); - class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, |