diff options
Diffstat (limited to 'tensorflow/core/kernels/functional_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/functional_ops.cc | 80 |
1 files changed, 44 insertions, 36 deletions
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index cb285bf732..1529d2e336 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -127,31 +127,47 @@ class IfOp : public AsyncOpKernel { explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { auto lib = ctx->function_library(); OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library")); - const NameAttrList* func; - OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func)); - OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func)); - OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_)); } ~IfOp() override {} void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + auto lib = ctx->function_library(); + OP_REQUIRES_ASYNC(ctx, lib != nullptr, + errors::Internal("No function library"), done); + + // TODO(b/37549631): Because this op has `SetIsStateful()` in its op + // registration, this kernel may be shared by multiple subgraphs, which have + // different associated `FunctionLibraryRuntime` objects and hence different + // `FHandle` namespaces. So we must call Instantiate() to make sure we get + // the correct function handles with respect to `lib`. Note the underlying + // `lib->Instantiate()` caches the created function handles, so calling + // `Instantiate()` repeatedly on the same `lib` and function is cheap. + FHandle then_handle; + FHandle else_handle; + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done); + bool cond; OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond)); - (new State(this, ctx, cond, done))->Start(); + (new State(this, ctx, cond, then_handle, else_handle, done))->Start(); } private: - FHandle then_handle_; - FHandle else_handle_; + NameAttrList then_func_; + NameAttrList else_func_; class State { public: - State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done) + State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle, + FHandle else_handle, DoneCallback done) : kernel_(kernel), ctx_(ctx), cond_(cond), + then_handle_(then_handle), + else_handle_(else_handle), done_(std::move(done)), lib_(CHECK_NOTNULL(ctx_->function_library())) { SetRunOptions(ctx_, &opts_, true /* always_collect_stats */); @@ -163,7 +179,7 @@ class IfOp : public AsyncOpKernel { ~State() {} void Start() { - FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_; + FHandle handle = cond_ ? then_handle_ : else_handle_; rets_.clear(); lib_->Run( // Evaluate one of the branch. @@ -184,6 +200,8 @@ class IfOp : public AsyncOpKernel { IfOp* const kernel_; OpKernelContext* const ctx_; const bool cond_; + FHandle then_handle_; + FHandle else_handle_; DoneCallback done_; FunctionLibraryRuntime* const lib_; FunctionLibraryRuntime::Options opts_; @@ -200,6 +218,10 @@ REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"), REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp); REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp); +REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp); +REGISTER_KERNEL_BUILDER( + Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp); + class WhileOp : public AsyncOpKernel { public: explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { @@ -214,30 +236,17 @@ class WhileOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, lib != nullptr, errors::Internal("No function library"), done); - // TODO(b/37549631): Because this op has `SetIsStateful()` in its - // op registration, this kernel may be shared by multiple - // subgraphs, which have different associated - // `FunctionLibraryRuntime` objects and hence different `FHandle` - // namespaces. We currently work around this by caching the map - // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two - // functions this op uses. + // TODO(b/37549631): Because this op has `SetIsStateful()` in its op + // registration, this kernel may be shared by multiple subgraphs, which have + // different associated `FunctionLibraryRuntime` objects and hence different + // `FHandle` namespaces. So we must call Instantiate() to make sure we get + // the correct function handles with respect to `lib`. Note the underlying + // `lib->Instantiate()` caches the created function handles, so calling + // `Instantiate()` repeatedly on the same `lib` and function is cheap. FHandle cond_handle; FHandle body_handle; - { - mutex_lock l(mu_); - const auto iter = handles_.find(lib); - if (iter == handles_.end()) { - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), - done); - OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), - done); - handles_[lib] = {cond_handle, body_handle}; - } else { - cond_handle = iter->second.first; - body_handle = iter->second.second; - } - } - + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done); + OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done); (new State(this, ctx, cond_handle, body_handle, done))->Start(); } @@ -245,10 +254,6 @@ class WhileOp : public AsyncOpKernel { NameAttrList cond_func_; NameAttrList body_func_; - mutex mu_; - std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>> - handles_ GUARDED_BY(mu_); - class State { public: State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle, @@ -378,6 +383,9 @@ REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp); REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp); REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp); +REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp); + Status GetScalar(OpKernelContext* ctx, int index, int32* value, const char* label) { Tensor t = ctx->input(index); |