diff options
author | 2017-07-26 11:39:27 -0700 | |
---|---|---|
committer | 2017-07-26 11:43:02 -0700 | |
commit | 0ca9c29ffef9c7ef6a05bd9d4965a93cded9b1b6 (patch) | |
tree | 8ce63130d4993eb98e47ca325fce964c98abd6ee | |
parent | 9eb2fe4c3a8cf3e64dce4603b03d91da76d8b27e (diff) |
Setup a resource manager and enable the use of multiple threads during folding
PiperOrigin-RevId: 163233960
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/constant_folding.h | 2 |
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 7f845bb9e2..72bbaa78af 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -59,7 +59,7 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { class DeviceSimple : public DeviceBase { public: DeviceSimple() : DeviceBase(Env::Default()) { - eigen_worker_threads_.num_threads = 1; + eigen_worker_threads_.num_threads = port::NumSchedulableCPUs(); eigen_worker_threads_.workers = new thread::ThreadPool( Env::Default(), "constant_folding", eigen_worker_threads_.num_threads); eigen_threadpool_wrapper_.reset( @@ -101,6 +101,8 @@ string AsControlDependency(const NodeDef& node) { } // namespace ConstantFolding::ConstantFolding() { + resource_mgr_.reset(new ResourceMgr()); + ops_to_preserve_ = std::regex( "Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|" "Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|" @@ -346,6 +348,7 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node, params.frame_iter = FrameAndIter(0, 0); params.inputs = &inputs; params.op_kernel = op_kernel.get(); + params.resource_manager = resource_mgr_.get(); gtl::InlinedVector<AllocatorAttributes, 4> output_attrs; const int num_outputs = op_kernel->num_outputs(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 1c20233345..88475e4e75 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -19,6 +19,7 @@ limitations under the License. #include <regex> #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/utils.h" @@ -70,6 +71,7 @@ class ConstantFolding : public GraphOptimizer { Status SimplifyGraph(GraphDef* output, const GraphProperties& properties); std::unique_ptr<DeviceBase> device_; + std::unique_ptr<ResourceMgr> resource_mgr_; GraphDef graph_; std::unique_ptr<NodeMap> node_map_; std::set<string> nodes_to_preserve_; |