diff options
author | 2018-04-09 14:26:55 -0700 | |
---|---|---|
committer | 2018-04-09 14:29:17 -0700 | |
commit | 9b18bd70b5739d646b21b7d45de0e5c96b8cc2a1 (patch) | |
tree | 92b7fc07a773dd20448f936261f1995c208a5ef9 /tensorflow/core/common_runtime/graph_runner.cc | |
parent | e60c87c978f7fbb848bc66ca3caa90ccdab8a9b9 (diff) |
Don't initialize global threadpool in GraphRunner.
TF_Graph creates a ShapeRefiner, which in
turn creates a GraphRunner, which prior to this change would eventually create a
LocalDevice that initialized the global eigen threadpool. This prevents
users from specifying a custom number of threads for the pool via a
ConfigProto.
This change introduces a new device class, SingleThreadedCpuDevice, that can
be used for light-weight computations without initializing the threadpool.
Addresses #18300.
PiperOrigin-RevId: 192188031
Diffstat (limited to 'tensorflow/core/common_runtime/graph_runner.cc')
-rw-r--r-- | tensorflow/core/common_runtime/graph_runner.cc | 21 |
1 files changed, 8 insertions, 13 deletions
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 1125d2a34a..790f2eaa1e 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// TODO(skyewm): this is necessary to make the single_threaded_cpu_device.h +// include work. Some other include must be including eigen without defining +// this. Consider defining in this in a BUILD rule. +#define EIGEN_USE_THREADS + #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -20,6 +25,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_util.h" @@ -36,18 +42,6 @@ namespace tensorflow { namespace { -std::unique_ptr<Device> GetCPUDevice(Env* env) { - std::vector<Device*> devices; - SessionOptions session_options; - session_options.env = env; - Status s = DeviceFactory::GetFactory(DEVICE_CPU) - ->CreateDevices(session_options, "", &devices); - if (s.ok() && !devices.empty()) { - return std::unique_ptr<Device>(devices[0]); - } - return nullptr; -} - // A simple rendezvous class. // Assumes a single sender and a single receiver, no duplicate sends, and no // sends of dead tensors. @@ -98,7 +92,8 @@ class SimpleRendezvous : public Rendezvous { } // namespace GraphRunner::GraphRunner(Env* env) - : device_deleter_(GetCPUDevice(env)), device_(device_deleter_.get()) {} + : device_deleter_(new SingleThreadedCpuDevice(env)), + device_(device_deleter_.get()) {} GraphRunner::GraphRunner(Device* device) : device_(device) {} GraphRunner::~GraphRunner() {} |