aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-07-26 11:39:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-26 11:43:02 -0700
commit0ca9c29ffef9c7ef6a05bd9d4965a93cded9b1b6 (patch)
tree8ce63130d4993eb98e47ca325fce964c98abd6ee
parent9eb2fe4c3a8cf3e64dce4603b03d91da76d8b27e (diff)
Setup a resource manager and enable the use of multiple threads during folding
PiperOrigin-RevId: 163233960
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc5
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h2
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index 7f845bb9e2..72bbaa78af 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -59,7 +59,7 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
class DeviceSimple : public DeviceBase {
public:
DeviceSimple() : DeviceBase(Env::Default()) {
- eigen_worker_threads_.num_threads = 1;
+ eigen_worker_threads_.num_threads = port::NumSchedulableCPUs();
eigen_worker_threads_.workers = new thread::ThreadPool(
Env::Default(), "constant_folding", eigen_worker_threads_.num_threads);
eigen_threadpool_wrapper_.reset(
@@ -101,6 +101,8 @@ string AsControlDependency(const NodeDef& node) {
} // namespace
ConstantFolding::ConstantFolding() {
+ resource_mgr_.reset(new ResourceMgr());
+
ops_to_preserve_ = std::regex(
"Placeholder.*|Const|.*Save.*|.*Restore.*|.*Reader|"
"Enter|RefEnter|Exit|RefExit|NextIteration|RefNextIteration|"
@@ -346,6 +348,7 @@ Status ConstantFolding::EvaluateNode(const NodeDef& node,
params.frame_iter = FrameAndIter(0, 0);
params.inputs = &inputs;
params.op_kernel = op_kernel.get();
+ params.resource_manager = resource_mgr_.get();
gtl::InlinedVector<AllocatorAttributes, 4> output_attrs;
const int num_outputs = op_kernel->num_outputs();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 1c20233345..88475e4e75 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <regex>
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
@@ -70,6 +71,7 @@ class ConstantFolding : public GraphOptimizer {
Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
std::unique_ptr<DeviceBase> device_;
+ std::unique_ptr<ResourceMgr> resource_mgr_;
GraphDef graph_;
std::unique_ptr<NodeMap> node_map_;
std::set<string> nodes_to_preserve_;