aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/eager/kernel_and_device.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/eager/kernel_and_device.cc')
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index b410ea175b..dae5d1983f 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -41,17 +41,22 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
out->device_ = device;
out->kernel_.reset(k);
out->flib_ = nullptr;
+ out->runner_ = nullptr;
+ out->default_runner_ = [](std::function<void()> f) { f(); };
return s;
}
// static
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+ std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device();
out->kernel_.reset(k);
out->flib_ = flib;
+ out->runner_ = runner;
+ out->default_runner_ = [](std::function<void()> f) { f(); };
return s;
}
@@ -83,10 +88,11 @@ Status KernelAndDevice::Run(std::vector<Tensor>* input_tensors,
if (stats != nullptr) {
params.track_allocations = true;
}
- // TODO(apassos): use a thread pool.
- std::function<void(std::function<void()>)> runner =
- [](std::function<void()> f) { f(); };
- params.runner = &runner;
+ if (runner_ == nullptr) {
+ params.runner = &default_runner_;
+ } else {
+ params.runner = runner_;
+ }
ScopedStepContainer step_container(0, [this](const string& name) {
device_->resource_manager()->Cleanup(name).IgnoreError();