aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/graph_runner.cc
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-04-09 14:26:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-09 14:29:17 -0700
commit9b18bd70b5739d646b21b7d45de0e5c96b8cc2a1 (patch)
tree92b7fc07a773dd20448f936261f1995c208a5ef9 /tensorflow/core/common_runtime/graph_runner.cc
parente60c87c978f7fbb848bc66ca3caa90ccdab8a9b9 (diff)
Don't initialize global threadpool in GraphRunner.
TF_Graph creates a ShapeRefiner, which in turn creates a GraphRunner, which prior to this change would eventually create a LocalDevice that initialized the global eigen threadpool. This prevents users from specifying a custom number of threads for the pool via a ConfigProto. This change introduces a new device class, SingleThreadedCpuDevice, that can be used for light-weight computations without initializing the threadpool. Addresses #18300. PiperOrigin-RevId: 192188031
Diffstat (limited to 'tensorflow/core/common_runtime/graph_runner.cc')
-rw-r--r--tensorflow/core/common_runtime/graph_runner.cc21
1 files changed, 8 insertions, 13 deletions
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 1125d2a34a..790f2eaa1e 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// TODO(skyewm): this is necessary to make the single_threaded_cpu_device.h
+// include work. Some other include must be including eigen without defining
+// this. Consider defining in this in a BUILD rule.
+#define EIGEN_USE_THREADS
+
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@@ -20,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/common_runtime/single_threaded_cpu_device.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_util.h"
@@ -36,18 +42,6 @@ namespace tensorflow {
namespace {
-std::unique_ptr<Device> GetCPUDevice(Env* env) {
- std::vector<Device*> devices;
- SessionOptions session_options;
- session_options.env = env;
- Status s = DeviceFactory::GetFactory(DEVICE_CPU)
- ->CreateDevices(session_options, "", &devices);
- if (s.ok() && !devices.empty()) {
- return std::unique_ptr<Device>(devices[0]);
- }
- return nullptr;
-}
-
// A simple rendezvous class.
// Assumes a single sender and a single receiver, no duplicate sends, and no
// sends of dead tensors.
@@ -98,7 +92,8 @@ class SimpleRendezvous : public Rendezvous {
} // namespace
GraphRunner::GraphRunner(Env* env)
- : device_deleter_(GetCPUDevice(env)), device_(device_deleter_.get()) {}
+ : device_deleter_(new SingleThreadedCpuDevice(env)),
+ device_(device_deleter_.get()) {}
GraphRunner::GraphRunner(Device* device) : device_(device) {}
GraphRunner::~GraphRunner() {}