diff options
author | 2017-04-05 09:47:31 -0800 | |
---|---|---|
committer | 2017-04-05 11:07:09 -0700 | |
commit | 72ae81b7a3e7ae940ba9ed34dab5082529afa0dd (patch) | |
tree | 9c460ed324903d115d974a1f39f1d0bd87894861 | |
parent | 97c447c0fc0c0c85473287b2828460cbab3f8128 (diff) |
Add an auto parallelization grappler optimization pass.
Change: 152276787
-rw-r--r-- | tensorflow/core/grappler/optimizers/BUILD | 34 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/auto_parallel.cc | 260 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/auto_parallel.h | 63 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/auto_parallel_test.cc | 125 |
4 files changed, 482 insertions(+), 0 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index bd96e2b33c..5c119d77e3 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -26,6 +26,40 @@ filegroup( ) cc_library( + name = "auto_parallel", + srcs = ["auto_parallel.cc"], + hdrs = [ + "auto_parallel.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + ], +) + +cc_test( + name = "auto_parallel_test", + srcs = ["auto_parallel_test.cc"], + deps = [ + ":auto_parallel", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + +cc_library( name = "constant_folding", srcs = ["constant_folding.cc"], hdrs = [ diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc new file mode 100644 index 0000000000..77ab178653 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -0,0 +1,260 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/auto_parallel.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/devices.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace grappler { +const char kAutoParallelPrefix[] = "AutoParallel"; + +NodeDef* AutoParallel::AddNodeDivConst() { + NodeDef* node = graph_.add_node(); + node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const")); + node->set_op("Const"); + + AttrValue attr_data_type; + attr_data_type.set_type(DT_FLOAT); + node->mutable_attr()->insert({"dtype", attr_data_type}); + + AttrValue attr_tensor; + auto tensor = attr_tensor.mutable_tensor(); + tensor->add_float_val(static_cast<float>(num_replicas_)); + tensor->set_dtype(DT_FLOAT); + node->mutable_attr()->insert({"value", attr_tensor}); + return node; +} + +NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a, + const string& input_b) { + NodeDef* node = graph_.add_node(); + node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name)); + node->set_op("RealDiv"); + node->add_input(input_a); + node->add_input(input_b); + AttrValue attr_type; + attr_type.set_type(DT_FLOAT); + node->mutable_attr()->insert({"T", attr_type}); + return node; +} + +NodeDef* AutoParallel::AddNodeControl(const string& name, + const std::set<string>& deps, + GraphDef* graph) { + NodeDef* node = graph->add_node(); + node->set_name(name); + node->set_op("NoOp"); + for (const auto& dep : deps) { + node->add_input(strings::StrCat("^", dep)); + } + return node; +} + +Status 
AutoParallel::Initialize(const GrapplerItem& item) { + num_gpus_ = GetNumAvailableGPUs(); + LOG(INFO) << "Number of GPUs: " << num_gpus_; + item_ = &item; + graph_ = item.graph; + LOG(INFO) << "Original graph size: " << graph_.node_size(); + if (item.fetch.empty()) { + return Status(error::INVALID_ARGUMENT, "No fetch nodes provided."); + } + + if (item.MainVariables().empty()) { + return Status(error::INVALID_ARGUMENT, "No variables provided."); + } + + for (const auto& init : item.init_ops) { + VLOG(1) << "Init node: " << init; + } + + for (const auto& fetch : item.fetch) { + VLOG(1) << "Fetch node: " << fetch; + } + + for (const auto& var : item.MainVariables()) { + VLOG(2) << "Variable: " << var->name(); + } + + std::set<string> apply_gradients_ops = {"ApplyGradientDescent", + "ApplyProximalGradientDescent", + "ApplyAdadelta", + "ApplyAdagrad", + "ApplyProximalAdagrad", + "ApplyAdagradDA", + "ApplyFtrl", + "ApplyMomentum", + "ApplyAdam", + "ApplyRMSProp", + "ApplyCenteredRMSProp"}; + const NodeDef* dequeue_node = nullptr; + for (int i = 0; i < graph_.node_size(); i++) { + all_nodes_.insert( + std::make_pair(graph_.node(i).name(), graph_.mutable_node(i))); + if (graph_.node(i).op() == "QueueDequeueManyV2") { + dequeue_node = graph_.mutable_node(i); + } + if (apply_gradients_ops.find(graph_.node(i).op()) != + apply_gradients_ops.end()) { + apply_gradients_nodes_.insert(graph_.node(i).name()); + VLOG(2) << "Apply gradients node: " << graph_.node(i).name(); + } + } + + auto div_const_node = AddNodeDivConst(); + all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node)); + std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2}, + {"ApplyProximalGradientDescent", 4}, + {"ApplyAdadelta", 6}, + {"ApplyAdagrad", 3}, + {"ApplyProximalAdagrad", 5}, + {"ApplyAdagradDA", 3}, + {"ApplyFtrl", 3}, + {"ApplyMomentum", 3}, + {"ApplyAdam", 9}, + {"ApplyRMSProp", 7}, + {"ApplyCenteredRMSProp", 8}}; + for (const auto& apply_gradient_node_name : 
apply_gradients_nodes_) { + auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op(); + auto apply_gradients_node = all_nodes_[apply_gradient_node_name]; + + auto div_node = AddNodeDiv( + apply_gradient_node_name, + apply_gradients_node->input(gradient_pos[apply_gradients_op]), + div_const_node->name()); + all_nodes_.insert(std::make_pair(div_node->name(), div_node)); + *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) = + div_node->name(); + } + LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size(); + + auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch); + LOG(INFO) << "Number of training nodes: " << train_nodes.size(); + + std::vector<const NodeDef*> input_nodes; + if (dequeue_node) { + LOG(INFO) << "Dequeue node: " << dequeue_node->name(); + input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()}); + } + LOG(INFO) << "Number of input nodes: " << input_nodes.size(); + + std::set<string> dont_replicate_nodes; + for (const auto& variable : item.MainVariables()) { + dont_replicate_nodes.insert(variable->name()); + } + // Don't replicate all input nodes, except the dequeue node. 
+ for (const auto& input_node : input_nodes) { + if (input_node->name() != dequeue_node->name()) { + dont_replicate_nodes.insert(input_node->name()); + } + } + + for (const auto& node : train_nodes) { + if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) { + replica_nodes_.insert(node->name()); + } + } + LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size(); + + for (const auto& node : all_nodes_) { + if (replica_nodes_.find(node.first) == replica_nodes_.end()) { + shared_nodes_.insert(node.first); + } + } + LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size(); + return Status::OK(); +} + +bool AutoParallel::NotSharedNode(const string& name) { + return shared_nodes_.find(name) == shared_nodes_.end(); +} + +void AutoParallel::AddSharedNodes(GraphDef* graph) { + string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0); + for (const auto& node : shared_nodes_) { + auto new_node = graph->add_node(); + *new_node = *all_nodes_[node]; + for (int i = 0; i < new_node->input_size(); i++) { + if (NotSharedNode(NodeName(new_node->input(i)))) { + string new_name = AddPrefixToNodeName(new_node->input(i), prefix); + *new_node->mutable_input(i) = new_name; + } + } + } +} + +void AutoParallel::AddOneReplica(GraphDef* graph, int number) { + string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number); + for (const auto& node : replica_nodes_) { + auto new_node = graph->add_node(); + *new_node = *all_nodes_[node]; + if (NotSharedNode(new_node->name())) { + new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix)); + if (num_gpus_ > 0) { + new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_)); + } + for (int i = 0; i < new_node->input_size(); i++) { + if (NotSharedNode(NodeName(new_node->input(i)))) { + string new_name = AddPrefixToNodeName(new_node->input(i), prefix); + *new_node->mutable_input(i) = new_name; + } + } + } + } +} + +void AutoParallel::BuildGraph(GraphDef* graph) { + 
AddSharedNodes(graph); + for (int i = 0; i < num_replicas_; i++) { + AddOneReplica(graph, i); + } + std::set<string> fetches; + for (int i = 0; i < item_->fetch.size(); i++) { + for (int j = 0; j < num_replicas_; j++) { + string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j); + string fetch = AddPrefixToNodeName(item_->fetch[i], prefix); + fetches.insert(fetch); + } + } + string name_control = + strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch"); + auto control = AddNodeControl(name_control, fetches, graph); + + for (const auto& fetch : item_->fetch) { + AddNodeControl(fetch, {control->name()}, graph); + } + LOG(INFO) << "Parallelized graph size: " << graph->node_size(); +} + +Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + TF_RETURN_IF_ERROR(Initialize(item)); + BuildGraph(output); + return Status::OK(); +} + +void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) { + // TODO(yaozhang): Add feedback. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h new file mode 100644 index 0000000000..cac0db2c23 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/auto_parallel.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ + +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Automatically parallelize a graph by splitting in the batch dimension. +class AutoParallel : public GraphOptimizer { + public: + AutoParallel(int num_replicas) : num_replicas_(num_replicas) {} + ~AutoParallel() override {} + + string name() const override { return "autoparallel"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; + + private: + GraphDef graph_; + std::map<string, NodeDef*> all_nodes_; + std::set<string> apply_gradients_nodes_; + std::set<string> replica_nodes_; + std::set<string> shared_nodes_; + const GrapplerItem* item_; + int num_replicas_; + int num_gpus_; + Status Initialize(const GrapplerItem& item); + NodeDef* AddNodeDivConst(); + NodeDef* AddNodeDiv(const string& name, const string& input_a, + const string& input_b); + NodeDef* AddNodeControl(const string& name, const std::set<string>& deps, + GraphDef* graph); + bool NotSharedNode(const string& name); + void AddSharedNodes(GraphDef* graph); + void AddOneReplica(GraphDef* graph, int number); + void BuildGraph(GraphDef* graph); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_ diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc new file mode 100644 index 0000000000..b7786ccd14 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc @@ -0,0 +1,125 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/auto_parallel.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class AutoParallelTest : public ::testing::Test {}; + +TEST_F(AutoParallelTest, SimpleParallel) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1}); + Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); + Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); + Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a}); + Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT}); + auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue}, + {constant_b}, {DT_FLOAT}); + Output add = ops::AddN(s.WithOpName("add"), {constant_a, dequeue[0]}); + Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1}); + Output apply_gradient = ops::ApplyGradientDescent( + s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add}); + + GrapplerItem item; + 
item.init_ops.push_back("assign"); + item.fetch.push_back("apply_gradient"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + AutoParallel parallel(2); + GraphDef output; + Status status = parallel.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_EQ(20, output.node_size()); + + const NodeDef& node_assign = output.node(0); + EXPECT_EQ("assign", node_assign.name()); + EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_assign.input(1)); + + const NodeDef& node_constant_b = output.node(1); + EXPECT_EQ("constant_b", node_constant_b.name()); + + const NodeDef& node_fifo_queue = output.node(2); + EXPECT_EQ("fifo_queue", node_fifo_queue.name()); + + const NodeDef& node_var = output.node(3); + EXPECT_EQ("var", node_var.name()); + + const NodeDef& node_div_const0 = output.node(4); + EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-Const", + node_div_const0.name()); + + const NodeDef& node_div0 = output.node(5); + EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-apply_gradient", + node_div0.name()); + const NodeDef& node_add0 = output.node(6); + EXPECT_EQ("AutoParallel-Replica-0-add", node_add0.name()); + + const NodeDef& node_gradient0 = output.node(7); + EXPECT_EQ("AutoParallel-Replica-0-apply_gradient", node_gradient0.name()); + + const NodeDef& node_constant_a0 = output.node(8); + EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_constant_a0.name()); + + const NodeDef& node_dequeue0 = output.node(9); + EXPECT_EQ("AutoParallel-Replica-0-dequeue", node_dequeue0.name()); + + const NodeDef& node_learning_rate0 = output.node(10); + EXPECT_EQ("AutoParallel-Replica-0-learning_rate", node_learning_rate0.name()); + + const NodeDef& node_div_const1 = output.node(11); + EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-Const", + node_div_const1.name()); + + const NodeDef& node_div1 = output.node(12); + EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-apply_gradient", + node_div1.name()); + + const NodeDef& node_add1 = output.node(13); + 
EXPECT_EQ("AutoParallel-Replica-1-add", node_add1.name()); + + const NodeDef& node_gradient1 = output.node(14); + EXPECT_EQ("AutoParallel-Replica-1-apply_gradient", node_gradient1.name()); + + const NodeDef& node_constant_a1 = output.node(15); + EXPECT_EQ("AutoParallel-Replica-1-constant_a", node_constant_a1.name()); + + const NodeDef& node_dequeue1 = output.node(16); + EXPECT_EQ("AutoParallel-Replica-1-dequeue", node_dequeue1.name()); + + const NodeDef& node_learning_rate1 = output.node(17); + EXPECT_EQ("AutoParallel-Replica-1-learning_rate", node_learning_rate1.name()); + + const NodeDef& node_fetch = output.node(18); + EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name()); + EXPECT_EQ("^AutoParallel-Replica-0-apply_gradient", node_fetch.input(0)); + EXPECT_EQ("^AutoParallel-Replica-1-apply_gradient", node_fetch.input(1)); + + const NodeDef& node_gradient = output.node(19); + EXPECT_EQ("apply_gradient", node_gradient.name()); + EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow |