diff options
-rw-r--r-- | tensorflow/core/kernels/functional_ops.cc | 35 |
1 files changed, 9 insertions, 26 deletions
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 36607b4d7b..1c0abf26cd 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -232,30 +232,17 @@ class WhileOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - // TODO(b/37549631): Because this op has `SetIsStateful()` in its - // op registration, this kernel may be shared by multiple - // subgraphs, which have different associated - // `FunctionLibraryRuntime` objects and hence different `FHandle` - // namespaces. We currently work around this by caching the map - // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two - // functions this op uses. + // TODO(b/37549631): Because this op has `SetIsStateful()` in its op + // registration, this kernel may be shared by multiple subgraphs, which have + // different associated `FunctionLibraryRuntime` objects and hence different + // `FHandle` namespaces. So we must call Instantiate() to make sure we get + // the correct function handles with respect to `lib`. Note the underlying + // `lib->Instantiate()` caches the created function handles, so calling + // `Instantiate()` repeatedly on the same `lib` and function is cheap. FHandle cond_handle; FHandle body_handle; - { - mutex_lock l(mu_); - const auto iter = handles_.find(lib); - if (iter == handles_.end()) { - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), - done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), - done); - handles_[lib] = {cond_handle, body_handle}; - } else { - cond_handle = iter->second.first; - body_handle = iter->second.second; - } - } - + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -263,10 +250,6 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; - mutex mu_; - std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> - handles_ GUARDED_BY(mu_); - class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, |