author    Yao Zhang <yaozhang@google.com>  2017-04-05 09:47:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-05 11:07:09 -0700
commit    72ae81b7a3e7ae940ba9ed34dab5082529afa0dd (patch)
tree      9c460ed324903d115d974a1f39f1d0bd87894861
parent    97c447c0fc0c0c85473287b2828460cbab3f8128 (diff)
Add an auto parallelization grappler optimization pass.
Change: 152276787
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                 |  34
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel.cc      | 260
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel.h       |  63
-rw-r--r--  tensorflow/core/grappler/optimizers/auto_parallel_test.cc | 125
4 files changed, 482 insertions, 0 deletions
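
In brief, the pass copies the training sub-graph once per replica, keeps the variables and the input pipeline (up to the dequeue op) shared, and rescales the gradient fed into each Apply* op so that the combined update averages the per-replica gradients. A minimal sketch of how a caller might drive the pass, mirroring the unit test at the bottom of this patch (ParallelizeTwoWays is a hypothetical helper name; the GrapplerItem is assumed to arrive with its graph, init_ops, and fetch already populated):

#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"

namespace tensorflow {
namespace grappler {

// Replicate the training portion of `item` twice. The pass does not use
// the Cluster argument, which is why the unit test passes nullptr for it.
Status ParallelizeTwoWays(const GrapplerItem& item, GraphDef* output) {
  AutoParallel parallel(/*num_replicas=*/2);
  return parallel.Optimize(/*cluster=*/nullptr, item, output);
}

}  // namespace grappler
}  // namespace tensorflow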
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index bd96e2b33c..5c119d77e3 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -26,6 +26,40 @@ filegroup(
 )
 
 cc_library(
+    name = "auto_parallel",
+    srcs = ["auto_parallel.cc"],
+    hdrs = [
+        "auto_parallel.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:devices",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/clusters:cluster",
+    ],
+)
+
+cc_test(
+    name = "auto_parallel_test",
+    srcs = ["auto_parallel_test.cc"],
+    deps = [
+        ":auto_parallel",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+    ],
+)
+
+cc_library(
     name = "constant_folding",
     srcs = ["constant_folding.cc"],
     hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc
new file mode 100644
index 0000000000..77ab178653
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc
@@ -0,0 +1,260 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace grappler {
+const char kAutoParallelPrefix[] = "AutoParallel";
+
+NodeDef* AutoParallel::AddNodeDivConst() {
+  NodeDef* node = graph_.add_node();
+  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-Const"));
+  node->set_op("Const");
+
+  AttrValue attr_data_type;
+  attr_data_type.set_type(DT_FLOAT);
+  node->mutable_attr()->insert({"dtype", attr_data_type});
+
+  AttrValue attr_tensor;
+  auto tensor = attr_tensor.mutable_tensor();
+  tensor->add_float_val(static_cast<float>(num_replicas_));
+  tensor->set_dtype(DT_FLOAT);
+  node->mutable_attr()->insert({"value", attr_tensor});
+  return node;
+}
+
+NodeDef* AutoParallel::AddNodeDiv(const string& name, const string& input_a,
+                                  const string& input_b) {
+  NodeDef* node = graph_.add_node();
+  node->set_name(strings::StrCat(kAutoParallelPrefix, "-Div-", name));
+  node->set_op("RealDiv");
+  node->add_input(input_a);
+  node->add_input(input_b);
+  AttrValue attr_type;
+  attr_type.set_type(DT_FLOAT);
+  node->mutable_attr()->insert({"T", attr_type});
+  return node;
+}
+
+NodeDef* AutoParallel::AddNodeControl(const string& name,
+                                      const std::set<string>& deps,
+                                      GraphDef* graph) {
+  NodeDef* node = graph->add_node();
+  node->set_name(name);
+  node->set_op("NoOp");
+  for (const auto& dep : deps) {
+    node->add_input(strings::StrCat("^", dep));
+  }
+  return node;
+}
+
+Status AutoParallel::Initialize(const GrapplerItem& item) {
+  num_gpus_ = GetNumAvailableGPUs();
+  LOG(INFO) << "Number of GPUs: " << num_gpus_;
+  item_ = &item;
+  graph_ = item.graph;
+  LOG(INFO) << "Original graph size: " << graph_.node_size();
+  if (item.fetch.empty()) {
+    return Status(error::INVALID_ARGUMENT, "No fetch nodes provided.");
+  }
+
+  if (item.MainVariables().empty()) {
+    return Status(error::INVALID_ARGUMENT, "No variables provided.");
+  }
+
+  for (const auto& init : item.init_ops) {
+    VLOG(1) << "Init node: " << init;
+  }
+
+  for (const auto& fetch : item.fetch) {
+    VLOG(1) << "Fetch node: " << fetch;
+  }
+
+  for (const auto& var : item.MainVariables()) {
+    VLOG(2) << "Variable: " << var->name();
+  }
+
+  std::set<string> apply_gradients_ops = {"ApplyGradientDescent",
+                                          "ApplyProximalGradientDescent",
+                                          "ApplyAdadelta",
+                                          "ApplyAdagrad",
+                                          "ApplyProximalAdagrad",
+                                          "ApplyAdagradDA",
+                                          "ApplyFtrl",
+                                          "ApplyMomentum",
+                                          "ApplyAdam",
+                                          "ApplyRMSProp",
+                                          "ApplyCenteredRMSProp"};
+  const NodeDef* dequeue_node = nullptr;
+  for (int i = 0; i < graph_.node_size(); i++) {
+    all_nodes_.insert(
+        std::make_pair(graph_.node(i).name(), graph_.mutable_node(i)));
+    if (graph_.node(i).op() == "QueueDequeueManyV2") {
+      dequeue_node = graph_.mutable_node(i);
+    }
+    if (apply_gradients_ops.find(graph_.node(i).op()) !=
+        apply_gradients_ops.end()) {
+      apply_gradients_nodes_.insert(graph_.node(i).name());
+      VLOG(2) << "Apply gradients node: " << graph_.node(i).name();
+    }
+  }
+
+  auto div_const_node = AddNodeDivConst();
+  all_nodes_.insert(std::make_pair(div_const_node->name(), div_const_node));
+  std::map<string, int> gradient_pos = {{"ApplyGradientDescent", 2},
+                                        {"ApplyProximalGradientDescent", 4},
+                                        {"ApplyAdadelta", 6},
+                                        {"ApplyAdagrad", 3},
+                                        {"ApplyProximalAdagrad", 5},
+                                        {"ApplyAdagradDA", 3},
+                                        {"ApplyFtrl", 3},
+                                        {"ApplyMomentum", 3},
+                                        {"ApplyAdam", 9},
+                                        {"ApplyRMSProp", 7},
+                                        {"ApplyCenteredRMSProp", 8}};
+  for (const auto& apply_gradient_node_name : apply_gradients_nodes_) {
+    auto apply_gradients_op = all_nodes_[apply_gradient_node_name]->op();
+    auto apply_gradients_node = all_nodes_[apply_gradient_node_name];
+
+    auto div_node = AddNodeDiv(
+        apply_gradient_node_name,
+        apply_gradients_node->input(gradient_pos[apply_gradients_op]),
+        div_const_node->name());
+    all_nodes_.insert(std::make_pair(div_node->name(), div_node));
+    *apply_gradients_node->mutable_input(gradient_pos[apply_gradients_op]) =
+        div_node->name();
+  }
+  LOG(INFO) << "Graph size after adding div nodes: " << all_nodes_.size();
+
+  auto train_nodes = ComputeTransitiveFanin(graph_, item.fetch);
+  LOG(INFO) << "Number of training nodes: " << train_nodes.size();
+
+  std::vector<const NodeDef*> input_nodes;
+  if (dequeue_node) {
+    LOG(INFO) << "Dequeue node: " << dequeue_node->name();
+    input_nodes = ComputeTransitiveFanin(graph_, {dequeue_node->name()});
+  }
+  LOG(INFO) << "Number of input nodes: " << input_nodes.size();
+
+  std::set<string> dont_replicate_nodes;
+  for (const auto& variable : item.MainVariables()) {
+    dont_replicate_nodes.insert(variable->name());
+  }
+  // Don't replicate input nodes, except for the dequeue node.
+  for (const auto& input_node : input_nodes) {
+    if (input_node->name() != dequeue_node->name()) {
+      dont_replicate_nodes.insert(input_node->name());
+    }
+  }
+
+  for (const auto& node : train_nodes) {
+    if (dont_replicate_nodes.find(node->name()) == dont_replicate_nodes.end()) {
+      replica_nodes_.insert(node->name());
+    }
+  }
+  LOG(INFO) << "Number of replica nodes: " << replica_nodes_.size();
+
+  for (const auto& node : all_nodes_) {
+    if (replica_nodes_.find(node.first) == replica_nodes_.end()) {
+      shared_nodes_.insert(node.first);
+    }
+  }
+  LOG(INFO) << "Number of shared nodes: " << shared_nodes_.size();
+  return Status::OK();
+}
+
+bool AutoParallel::NotSharedNode(const string& name) {
+  return shared_nodes_.find(name) == shared_nodes_.end();
+}
+
+void AutoParallel::AddSharedNodes(GraphDef* graph) {
+  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", 0);
+  for (const auto& node : shared_nodes_) {
+    auto new_node = graph->add_node();
+    *new_node = *all_nodes_[node];
+    for (int i = 0; i < new_node->input_size(); i++) {
+      if (NotSharedNode(NodeName(new_node->input(i)))) {
+        string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
+        *new_node->mutable_input(i) = new_name;
+      }
+    }
+  }
+}
+
+void AutoParallel::AddOneReplica(GraphDef* graph, int number) {
+  string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", number);
+  for (const auto& node : replica_nodes_) {
+    auto new_node = graph->add_node();
+    *new_node = *all_nodes_[node];
+    if (NotSharedNode(new_node->name())) {
+      new_node->set_name(AddPrefixToNodeName(new_node->name(), prefix));
+      if (num_gpus_ > 0) {
+        new_node->set_device(strings::StrCat("/gpu:", number % num_gpus_));
+      }
+      for (int i = 0; i < new_node->input_size(); i++) {
+        if (NotSharedNode(NodeName(new_node->input(i)))) {
+          string new_name = AddPrefixToNodeName(new_node->input(i), prefix);
+          *new_node->mutable_input(i) = new_name;
+        }
+      }
+    }
+  }
+}
+
+void AutoParallel::BuildGraph(GraphDef* graph) {
+  AddSharedNodes(graph);
+  for (int i = 0; i < num_replicas_; i++) {
+    AddOneReplica(graph, i);
+  }
+  std::set<string> fetches;
+  for (int i = 0; i < item_->fetch.size(); i++) {
+    for (int j = 0; j < num_replicas_; j++) {
+      string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
+      string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
+      fetches.insert(fetch);
+    }
+  }
+  string name_control =
+      strings::StrCat(kAutoParallelPrefix, "-Control-", "Fetch");
+  auto control = AddNodeControl(name_control, fetches, graph);
+
+  for (const auto& fetch : item_->fetch) {
+    AddNodeControl(fetch, {control->name()}, graph);
+  }
+  LOG(INFO) << "Parallelized graph size: " << graph->node_size();
+}
+
+Status AutoParallel::Optimize(Cluster* cluster, const GrapplerItem& item,
+                              GraphDef* output) {
+  TF_RETURN_IF_ERROR(Initialize(item));
+  BuildGraph(output);
+  return Status::OK();
+}
+
+void AutoParallel::Feedback(Cluster* cluster, const GrapplerItem& item,
+                            const GraphDef& optimize_output, double result) {
+  // TODO(yaozhang): Add feedback.
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
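
A note on the Div rewiring in Initialize() above: for plain gradient descent (ApplyGradientDescent, whose gradient sits at input index 2 per the gradient_pos map), each of the N replicas feeds g_i / N into its own Apply node, so one pass over all replicas updates the shared variable as

    w \leftarrow w - \alpha \sum_{i=1}^{N} \frac{g_i}{N}
      = w - \frac{\alpha}{N} \sum_{i=1}^{N} g_i = w - \alpha\,\overline{g},

i.e. a single step on the gradient averaged across the replicas' batches. Note also that the queue op itself stays shared while the QueueDequeueManyV2 op is replicated, so each replica pulls a different batch from the same input pipeline, and that BuildGraph() re-creates every original fetch name as a NoOp hanging off AutoParallel-Control-Fetch, so callers can keep fetching the names they already use.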
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.h b/tensorflow/core/grappler/optimizers/auto_parallel.h
new file mode 100644
index 0000000000..cac0db2c23
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/auto_parallel.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Automatically parallelize a graph by splitting in the batch dimension.
+class AutoParallel : public GraphOptimizer {
+ public:
+  AutoParallel(int num_replicas) : num_replicas_(num_replicas) {}
+  ~AutoParallel() override {}
+
+  string name() const override { return "autoparallel"; }
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+
+ private:
+  GraphDef graph_;
+  std::map<string, NodeDef*> all_nodes_;
+  std::set<string> apply_gradients_nodes_;
+  std::set<string> replica_nodes_;
+  std::set<string> shared_nodes_;
+  const GrapplerItem* item_;
+  int num_replicas_;
+  int num_gpus_;
+  Status Initialize(const GrapplerItem& item);
+  NodeDef* AddNodeDivConst();
+  NodeDef* AddNodeDiv(const string& name, const string& input_a,
+                      const string& input_b);
+  NodeDef* AddNodeControl(const string& name, const std::set<string>& deps,
+                          GraphDef* graph);
+  bool NotSharedNode(const string& name);
+  void AddSharedNodes(GraphDef* graph);
+  void AddOneReplica(GraphDef* graph, int number);
+  void BuildGraph(GraphDef* graph);
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_AUTO_PARALLEL_H_
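
The replica node names checked in the test below follow from kAutoParallelPrefix and AddPrefixToNodeName. A small sketch of the convention (an assumption worth flagging: judging only by the test expectations, this version of AddPrefixToNodeName joins prefix and name with '-'; ReplicaName is a hypothetical helper, not part of the patch):

#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace grappler {

// Build the name a replicated node receives in replica `number`, e.g.
// ReplicaName("add", 0) -> "AutoParallel-Replica-0-add".
string ReplicaName(const string& node_name, int number) {
  string prefix = strings::StrCat("AutoParallel", "-Replica-", number);
  return AddPrefixToNodeName(node_name, prefix);
}

}  // namespace grappler
}  // namespace tensorflow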
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
new file mode 100644
index 0000000000..b7786ccd14
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class AutoParallelTest : public ::testing::Test {};
+
+TEST_F(AutoParallelTest, SimpleParallel) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output constant_a = ops::Const(s.WithOpName("constant_a"), 1.0f, {1});
+  Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1});
+  Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT);
+  Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a});
+  Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT});
+  auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue},
+                                       {constant_b}, {DT_FLOAT});
+  Output add = ops::AddN(s.WithOpName("add"), {constant_a, dequeue[0]});
+  Output learning_rate = ops::Const(s.WithOpName("learning_rate"), 0.01f, {1});
+  Output apply_gradient = ops::ApplyGradientDescent(
+      s.WithOpName("apply_gradient"), {var}, {learning_rate}, {add});
+
+  GrapplerItem item;
+  item.init_ops.push_back("assign");
+  item.fetch.push_back("apply_gradient");
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  AutoParallel parallel(2);
+  GraphDef output;
+  Status status = parallel.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  EXPECT_EQ(20, output.node_size());
+
+  const NodeDef& node_assign = output.node(0);
+  EXPECT_EQ("assign", node_assign.name());
+  EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_assign.input(1));
+
+  const NodeDef& node_constant_b = output.node(1);
+  EXPECT_EQ("constant_b", node_constant_b.name());
+
+  const NodeDef& node_fifo_queue = output.node(2);
+  EXPECT_EQ("fifo_queue", node_fifo_queue.name());
+
+  const NodeDef& node_var = output.node(3);
+  EXPECT_EQ("var", node_var.name());
+
+  const NodeDef& node_div_const0 = output.node(4);
+  EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-Const",
+            node_div_const0.name());
+
+  const NodeDef& node_div0 = output.node(5);
+  EXPECT_EQ("AutoParallel-Replica-0-AutoParallel-Div-apply_gradient",
+            node_div0.name());
+  const NodeDef& node_add0 = output.node(6);
+  EXPECT_EQ("AutoParallel-Replica-0-add", node_add0.name());
+
+  const NodeDef& node_gradient0 = output.node(7);
+  EXPECT_EQ("AutoParallel-Replica-0-apply_gradient", node_gradient0.name());
+
+  const NodeDef& node_constant_a0 = output.node(8);
+  EXPECT_EQ("AutoParallel-Replica-0-constant_a", node_constant_a0.name());
+
+  const NodeDef& node_dequeue0 = output.node(9);
+  EXPECT_EQ("AutoParallel-Replica-0-dequeue", node_dequeue0.name());
+
+  const NodeDef& node_learning_rate0 = output.node(10);
+  EXPECT_EQ("AutoParallel-Replica-0-learning_rate", node_learning_rate0.name());
+
+  const NodeDef& node_div_const1 = output.node(11);
+  EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-Const",
+            node_div_const1.name());
+
+  const NodeDef& node_div1 = output.node(12);
+  EXPECT_EQ("AutoParallel-Replica-1-AutoParallel-Div-apply_gradient",
+            node_div1.name());
+
+  const NodeDef& node_add1 = output.node(13);
+  EXPECT_EQ("AutoParallel-Replica-1-add", node_add1.name());
+
+  const NodeDef& node_gradient1 = output.node(14);
+  EXPECT_EQ("AutoParallel-Replica-1-apply_gradient", node_gradient1.name());
+
+  const NodeDef& node_constant_a1 = output.node(15);
+  EXPECT_EQ("AutoParallel-Replica-1-constant_a", node_constant_a1.name());
+
+  const NodeDef& node_dequeue1 = output.node(16);
+  EXPECT_EQ("AutoParallel-Replica-1-dequeue", node_dequeue1.name());
+
+  const NodeDef& node_learning_rate1 = output.node(17);
+  EXPECT_EQ("AutoParallel-Replica-1-learning_rate", node_learning_rate1.name());
+
+  const NodeDef& node_fetch = output.node(18);
+  EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name());
+  EXPECT_EQ("^AutoParallel-Replica-0-apply_gradient", node_fetch.input(0));
+  EXPECT_EQ("^AutoParallel-Replica-1-apply_gradient", node_fetch.input(1));
+
+  const NodeDef& node_gradient = output.node(19);
+  EXPECT_EQ("apply_gradient", node_gradient.name());
+  EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0));
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
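
Where the expected count of 20 nodes comes from, reconstructed from the pass: 4 shared nodes (assign, constant_b, fifo_queue, var) + 2 replicas x 7 replicated nodes each (the Div-Const, the Div, add, apply_gradient, constant_a, dequeue, learning_rate) + the AutoParallel-Control-Fetch NoOp + the NoOp that re-creates the original fetch name apply_gradient: 4 + 14 + 1 + 1 = 20.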