diff options
Diffstat (limited to 'tensorflow/core/common_runtime/eager/kernel_and_device.cc')
-rw-r--r-- | tensorflow/core/common_runtime/eager/kernel_and_device.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index b410ea175b..dae5d1983f 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -41,17 +41,22 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, out->device_ = device; out->kernel_.reset(k); out->flib_ = nullptr; + out->runner_ = nullptr; + out->default_runner_ = [](std::function<void()> f) { f(); }; return s; } // static Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, + std::function<void(std::function<void()>)>* runner, KernelAndDevice* out) { OpKernel* k = nullptr; Status s = flib->CreateKernel(ndef, &k); out->device_ = flib->device(); out->kernel_.reset(k); out->flib_ = flib; + out->runner_ = runner; + out->default_runner_ = [](std::function<void()> f) { f(); }; return s; } @@ -83,10 +88,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors, if (stats != nullptr) { params.track_allocations = true; } - // TODO(apassos): use a thread pool. - std::function<void(std::function<void()>)> runner = - [](std::function<void()> f) { f(); }; - params.runner = &runner; + if (runner_ == nullptr) { + params.runner = &default_runner_; + } else { + params.runner = runner_; + } ScopedStepContainer step_container(0, [this](const string& name) { device_->resource_manager()->Cleanup(name).IgnoreError(); |