diff options
Diffstat (limited to 'tensorflow/core/graph/graph_partition_test.cc')
-rw-r--r-- | tensorflow/core/graph/graph_partition_test.cc | 316 |
1 files changed, 316 insertions, 0 deletions
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc new file mode 100644 index 0000000000..d912c94025 --- /dev/null +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -0,0 +1,316 @@ +#include "tensorflow/core/graph/graph_partition.h" + +#include <unordered_map> + +#include <gtest/gtest.h> +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/control_flow_ops.h" +#include "tensorflow/cc/ops/random_ops.h" +#include "tensorflow/cc/ops/sendrecv_ops.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/equal_graph_def.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace { + +const char gpu_device[] = "/job:a/replica:0/task:0/gpu:0"; + +string SplitByDevice(const Node* node) { return node->assigned_device_name(); } + +string DeviceName(const Node* node) { + char first = node->name()[0]; + if (first == 'G') { + return gpu_device; + } else { + const string cpu_prefix = "/job:a/replica:0/task:0/cpu:"; + int index = first - 'A'; + return strings::StrCat(cpu_prefix, index); + } +} + +void Partition(const GraphDef& graph_def, + std::unordered_map<string, GraphDef>* partitions) { + Graph g(OpRegistry::Global()); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &g)); + + // Assigns devices to each node. Uses 1st letter of the node name as + // the device index. + for (Node* node : g.nodes()) { + node->set_assigned_device_name(DeviceName(node)); + } + + PartitionOptions popts; + popts.node_to_loc = SplitByDevice; + popts.new_name = [&g](const string& prefix) { return g.NewName(prefix); }; + popts.get_incarnation = [](const string& name) { + return (name[0] - 'A') + 100; + }; + popts.control_flow_added = false; + Status s = Partition(popts, &g, partitions); + CHECK(s.ok()) << s; +} + +void CheckLoopConstruction(const GraphDef& graph_def) { + std::unordered_map<string, GraphDef> partitions; + Partition(graph_def, &partitions); + GraphConstructorOptions opts; + for (const auto& kv : partitions) { + const GraphDef& gdef = kv.second; + bool has_control_enter = false; + bool has_control_merge = false; + bool has_control_switch = false; + bool has_control_next = false; + for (const NodeDef& ndef : gdef.node()) { + // _recvs must have a control input + if (ndef.op() == "_Recv") { + bool has_control = false; + for (const string& input_name : ndef.input()) { + if (StringPiece(input_name).starts_with("^")) { + has_control = true; + break; + } + } + EXPECT_TRUE(has_control); + } + // Must have a control loop + if (StringPiece(ndef.name()).starts_with("_cloop")) { + if (ndef.op() == "Enter") { + has_control_enter = true; + } + if (ndef.op() == "Merge") { + has_control_merge = true; + } + if (ndef.op() == "Switch") { + has_control_switch = true; + } + if (ndef.op() == "NextIteration") { + has_control_next = true; + } + } + } + EXPECT_TRUE(has_control_enter); + EXPECT_TRUE(has_control_merge); + EXPECT_TRUE(has_control_switch); + EXPECT_TRUE(has_control_next); + } +} + +REGISTER_OP("Input").Output("o: float"); +REGISTER_OP("BoolInput").Output("o: bool"); +REGISTER_OP("Cross").Input("a: float").Input("b: float").Output("o: float"); + +Node* Input(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("Input", opts); +} + +Node* BoolInput(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("BoolInput", opts); +} + +Node* Cross(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("Cross", a, b, opts); +} + +class GraphPartitionTest : public ::testing::Test { + protected: + GraphPartitionTest() + : in_(GraphDefBuilder::kFailImmediately), + builder_a_(GraphDefBuilder::kFailImmediately), + builder_b_(GraphDefBuilder::kFailImmediately), + a_opts_(builder_a_.opts().WithDevice("/job:a/replica:0/task:0/cpu:0")), + b_opts_(builder_b_.opts().WithDevice("/job:a/replica:0/task:0/cpu:1")) { + RequireDefaultOps(); + } + + const GraphDef& ToGraphDef() { + in_.ToGraphDef(&in_graph_def_); + return in_graph_def_; + } + + void ExpectMatchA() { + GraphDef graph_def; + builder_a_.ToGraphDef(&graph_def); + string a = "/job:a/replica:0/task:0/cpu:0"; + TF_EXPECT_GRAPH_EQ(graph_def, partitions_[a]); + } + + void ExpectMatchB() { + GraphDef graph_def; + builder_b_.ToGraphDef(&graph_def); + string b = "/job:a/replica:0/task:0/cpu:1"; + TF_EXPECT_GRAPH_EQ(graph_def, partitions_[b]); + } + + GraphDefBuilder in_; + GraphDef in_graph_def_; + GraphDefBuilder builder_a_; + GraphDefBuilder builder_b_; + GraphDefBuilder::Options a_opts_; + GraphDefBuilder::Options b_opts_; + std::unordered_map<string, GraphDef> partitions_; +}; + +TEST_F(GraphPartitionTest, SingleDevice) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Cross(a1, a1, in_.opts().WithName("A2")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(1, partitions_.size()); + + a1 = Input(a_opts_.WithName("A1")); + Cross(a1, a1, a_opts_.WithName("A2")); + ExpectMatchA(); +} + +TEST_F(GraphPartitionTest, CrossDeviceData) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0")); + ExpectMatchA(); + + b1 = Input(b_opts_.WithName("B1")); + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1")); + Cross(recv, b1, b_opts_.WithName("B2")); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1)); + _Send(c, "edge_3_A1", a, 82, b, a_opts_.WithName("A1/_1")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_3_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id = Identity(recv, b_opts_.WithName("A1/_3")); + b1 = Input(b_opts_.WithName("B1")); + Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceData_MultiUse) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + Cross(a1, a1, in_.opts().WithName("B3")); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + _Send(a1, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_0")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_1")); + b1 = Input(b_opts_.WithName("B1")); + Cross(recv, b1, b_opts_.WithName("B2")); + Cross(recv, recv, b_opts_.WithName("B3")); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceControl_MultiUse) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(b1, b1, in_.opts().WithName("B2").WithControlInput(a1)); + Input(in_.opts().WithName("B3").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1)); + _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1")); + ExpectMatchA(); + + Node* recv = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id = Identity(recv, b_opts_.WithName("A1/_3")); + b1 = Input(b_opts_.WithName("B1")); + Cross(b1, b1, b_opts_.WithName("B2").WithControlInput(id)); + Input(b_opts_.WithName("B3").WithControlInput(id)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDevice_DataControl) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = Input(in_.opts().WithName("A1")); + Node* b1 = Input(in_.opts().WithName("B1")); + Cross(a1, b1, in_.opts().WithName("B2")); + Input(in_.opts().WithName("B3").WithControlInput(a1)); + + Partition(ToGraphDef(), &partitions_); + EXPECT_EQ(2, partitions_.size()); + + string a = "/job:a/replica:0/task:0/cpu:0"; + string b = "/job:a/replica:0/task:0/cpu:1"; + a1 = Input(a_opts_.WithName("A1")); + Node* c = EmptyConst<float>(a_opts_.WithName("A1/_0").WithControlInput(a1)); + // NOTE: Send 0 A1/_1 -> A1/_2 is not necessarily needed. We could + // use A1/_0 -> A1/_4 as the control as a minor optimization. + _Send(c, "edge_1_A1", a, 82, b, a_opts_.WithName("A1/_1")); + _Send(a1, "edge_2_A1", a, 82, b, a_opts_.WithName("A1/_4")); + ExpectMatchA(); + + Node* recv1 = + _Recv(DT_FLOAT, "edge_1_A1", a, 82, b, b_opts_.WithName("A1/_2")); + Node* id1 = Identity(recv1, b_opts_.WithName("A1/_3")); + Node* recv2 = + _Recv(DT_FLOAT, "edge_2_A1", a, 82, b, b_opts_.WithName("A1/_5")); + b1 = Input(b_opts_.WithName("B1")); + Cross(recv2, b1, b_opts_.WithName("B2")); + Input(b_opts_.WithName("B3").WithControlInput(id1)); + ExpectMatchB(); +} + +TEST_F(GraphPartitionTest, CrossDeviceLoop) { + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + Node* a1 = BoolInput(in_.opts().WithName("A1")); + Node* a2 = Enter(a1, "foo", in_.opts().WithName("A2")); + Node* a3 = Merge({a2, {"A5", 0, DT_BOOL}}, in_.opts().WithName("A3")); + LoopCond(a3, in_.opts().WithName("A4")); + Node* b1 = Identity(a3, in_.opts().WithName("B1")); + NextIteration(b1, in_.opts().WithName("A5")); + + CheckLoopConstruction(ToGraphDef()); +} + +} // namespace +} // namespace tensorflow |