From 0ca9c29ffef9c7ef6a05bd9d4965a93cded9b1b6 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 26 Jul 2017 11:39:27 -0700 Subject: Setup a resource manager and enable the use of multiple threads during folding PiperOrigin-RevId: 163233960 --- tensorflow/core/grappler/optimizers/constant_folding.cc | 5 ++++- tensorflow/core/grappler/optimizers/constant_folding.h | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7f845bb9e2..72bbaa78af 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -59,7 +59,7 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { class DeviceSimple : public DeviceBase { public: DeviceSimple() : DeviceBase(Env::Default()) { - eigen_worker_threads_.num_threads = 1; + eigen_worker_threads_.num_threads = port::NumSchedulableCPUs(); eigen_worker_threads_.workers = new thread::ThreadPool( Env::Default(), "constant_folding", eigen_worker_threads_.num_threads); eigen_threadpool_wrapper_.reset( @@ -101,6 +101,8 @@ string AsControlDependency(const NodeDef& node) { } // namespace ConstantFolding::ConstantFolding() { + resource_mgr_.reset(new ResourceMgr()); + ops_to_preserve_ = std::regex( "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|" "Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|" @@ -346,6 +348,7 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node, params.frame_iter = FrameAndIter(0, 0); params.inputs = &inputs; params.op_kernel = op_kernel.get(); + params.resource_manager = resource_mgr_.get(); gtl::InlinedVector output_attrs; const int num_outputs = op_kernel->num_outputs(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 1c20233345..88475e4e75 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" @@ -70,6 +71,7 @@ class ConstantFolding : public GraphOptimizer { Status SimplifyGraph(GraphDef* output, const GraphProperties& properties); std::unique_ptr device_; + std::unique_ptr resource_mgr_; GraphDef graph_; std::unique_ptr node_map_; std::set nodes_to_preserve_; -- cgit v1.2.3