#include "tensorflow/core/graph/subgraph.h"

// NOTE(review): several angle-bracketed tokens were stripped from this file
// (bare "#include" directives and template arguments). They are restored
// below based on usage: std::vector holds node-name strings everywhere, and
// g_ owns a heap-allocated Graph. <algorithm>/<memory> added for std::sort /
// std::unique_ptr (include-what-you-use).
#include <algorithm>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/status.h"
#include <gtest/gtest.h>

// TODO(josh11b): Test setting the "device" field of a NodeDef.
// TODO(josh11b): Test that feeding won't prune targets.

namespace tensorflow {
namespace {

// Test fixture for subgraph::RewriteGraphForExecution. Owns a Graph that is
// built from a text-format GraphDef (ExpectOK) and then replaced in place by
// the rewritten subgraph (Subgraph), so subsequent Expect*/Has* queries
// inspect the rewrite result.
class SubgraphTest : public ::testing::Test {
 protected:
  SubgraphTest() : g_(new Graph(OpRegistry::Global())) {
    RequireDefaultOps();
    // Fake device used for the _Send/_Recv nodes the rewrite inserts.
    device_info_.set_name("/job:a/replica:0/task:0/cpu:0");
    device_info_.set_device_type(DeviceType(DEVICE_CPU).type());
    device_info_.set_incarnation(0);
  }

  ~SubgraphTest() override {}

  // Parses `gdef_ascii` as a text-format GraphDef and converts it into g_,
  // CHECK-failing on any error.
  void ExpectOK(const string& gdef_ascii) {
    CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
    GraphConstructorOptions opts;
    TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get()));
  }

  // Returns the node named `name`, or nullptr if absent.
  Node* FindNode(const string& name) {
    for (Node* n : g_->nodes()) {
      if (n->name() == name) return n;
    }
    return nullptr;
  }

  bool HasNode(const string& name) { return FindNode(name) != nullptr; }

  // Expects the graph's op nodes to be exactly the comma-separated list in
  // `nodes` (order-insensitive). Also checks that any _Send/_Recv node was
  // assigned to the fixture's fake device.
  void ExpectNodes(const string& nodes) {
    int count = 0;
    std::vector<string> actual_nodes;
    for (Node* n : g_->nodes()) {
      if (n->IsOp()) {
        count++;
        actual_nodes.push_back(n->name());
      }
    }
    std::sort(actual_nodes.begin(), actual_nodes.end());

    LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " ");

    std::vector<string> expected_nodes = str_util::Split(nodes, ',');
    std::sort(expected_nodes.begin(), expected_nodes.end());
    for (const string& s : expected_nodes) {
      Node* n = FindNode(s);
      EXPECT_TRUE(n != nullptr) << s;
      if (n->def().op() == "_Send" || n->def().op() == "_Recv") {
        EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
      }
    }

    EXPECT_TRUE(actual_nodes.size() == expected_nodes.size())
        << "\nActual: " << str_util::Join(actual_nodes, ",")
        << "\nExpected: " << str_util::Join(expected_nodes, ",");
  }

  // True iff an edge src:src_out -> dst:dst_in exists in the graph.
  bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) {
    for (const Edge* e : g_->edges()) {
      if (e->src()->name() == src && e->src_output() == src_out &&
          e->dst()->name() == dst && e->dst_input() == dst_in)
        return true;
    }
    return false;
  }

  bool HasControlEdge(const string& src, const string& dst) {
    return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot);
  }

  // Runs RewriteGraphForExecution on a copy of g_ with the comma-separated
  // feed/fetch/target lists. On success replaces g_ with the rewritten graph
  // and returns "OK"; on failure returns the error status string (so tests
  // can assert on error messages).
  string Subgraph(const string& fed_str, const string& fetch_str,
                  const string& targets_str) {
    Graph* subgraph = new Graph(OpRegistry::Global());
    CopyGraph(*g_, subgraph);
    std::vector<string> fed =
        str_util::Split(fed_str, ',', str_util::SkipEmpty());
    std::vector<string> fetch =
        str_util::Split(fetch_str, ',', str_util::SkipEmpty());
    std::vector<string> targets =
        str_util::Split(targets_str, ',', str_util::SkipEmpty());
    Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
                                                  targets, device_info_);
    if (!s.ok()) {
      delete subgraph;
      return s.ToString();
    }

    // Replace the graph with the subgraph for the rest of the display program
    g_.reset(subgraph);

    return "OK";
  }

  Graph* graph() { return g_.get(); }

 private:
  GraphDef gdef_;
  std::unique_ptr<Graph> g_;  // restored template arg: holds new Graph(...)
  DeviceAttributes device_info_;
};

REGISTER_OP("TestParams").Output("o: float");
REGISTER_OP("TestInput").Output("a: float").Output("b: float");
REGISTER_OP("TestRelu").Input("i: float").Output("o: float");
REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float");

TEST_F(SubgraphTest, Targets1) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 'input' op: 'TestInput' }"
      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
  EXPECT_EQ("OK", Subgraph("", "", "t1"));
  ExpectNodes("W1,input,t1");
}

TEST_F(SubgraphTest, Targets2) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 'input' op: 'TestInput' }"
      "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }"
      "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }"
      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
  EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a"));
  ExpectNodes("W1,W2,input,t1,t2,t3_a");
}

TEST_F(SubgraphTest, FedOutputs1) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 'input' op: 'TestInput' }"
      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
  EXPECT_EQ("OK", Subgraph("input:1", "", "t2"));
  ExpectNodes("W1,W2,_recv_input_1,t1,t2");
}

TEST_F(SubgraphTest, FedRefNode) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
  EXPECT_EQ("OK", Subgraph("W1:0", "", "t1"));
  ExpectNodes("_recv_W1_0,W2,t1");
  // The recv node that replaces a fed output must produce a non-ref type.
  Node* n = FindNode("_recv_W1_0");
  EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
}

TEST_F(SubgraphTest, FedOutputs2) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 'input' op: 'TestInput' }"
      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
  // We feed input:1, but nothing connects to it, so the _recv(input:1)
  // node also disappears.
  EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
  ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
}

TEST_F(SubgraphTest, FetchOutputs1) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 'input' op: 'TestInput' }"
      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
  EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2"));
  ExpectNodes(
      "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
}

TEST_F(SubgraphTest, FetchOutputs2) {
  ExpectOK(
      "node { name: 'W1' op: 'TestParams' }"
      "node { name: 'W2' op: 'TestParams' }"
      "node { name: 'input' op: 'TestInput' }"
      "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
      "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
      "node { name: 't3_a' op: 'TestRelu' input: 't2' }"
      "node { name: 't3_b' op: 'TestRelu' input: 't2' }");
  EXPECT_EQ("OK", Subgraph("", "t3_a", "t2"));
  ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
}

TEST_F(SubgraphTest, ChainOfFools) {
  ExpectOK(
      "node { name: 'a' op: 'TestParams' }"
      "node { name: 'b' op: 'TestRelu' input: 'a'}"
      "node { name: 'c' op: 'TestRelu' input: 'b'}"
      "node { name: 'd' op: 'TestRelu' input: 'c'}"
      "node { name: 'e' op: 'TestRelu' input: 'd'}"
      "node { name: 'f' op: 'TestRelu' input: 'e'}");
  EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", ""));
  ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0");
  EXPECT_TRUE(HasEdge("a", 0, "b", 0));
  EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0));
  EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0));
  EXPECT_TRUE(HasEdge("d", 0, "e", 0));
  EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0));
}

// EXPECTs that `base` contains `substr`, and returns whether it does.
static bool HasSubstr(const string& base, const string& substr) {
  bool ok = StringPiece(base).contains(substr);
  EXPECT_TRUE(ok) << base << ", expected substring " << substr;
  return ok;
}

TEST_F(SubgraphTest, Errors) {
  ExpectOK(
      "node { name: 'a' op: 'TestParams' }"
      "node { name: 'b' op: 'TestRelu' input: 'a'}"
      "node { name: 'c' op: 'TestRelu' input: 'b'}"
      "node { name: 'd' op: 'TestRelu' input: 'c'}"
      "node { name: 'e' op: 'TestRelu' input: 'd'}"
      "node { name: 'f' op: 'TestRelu' input: 'e'}");
  // Duplicated feed and fetch
  EXPECT_TRUE(
      HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched"));
  // Feed not found.
  EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "", ""), "unable to find"));
  // Fetch not found.
  EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found"));
  // Target not found.
  EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found"));
}

REGISTER_OP("In").Output("o: float");
REGISTER_OP("Op").Input("i: float").Output("o: float");

// Benchmarks RewriteGraphForExecution on a linear chain of `num_nodes` ops,
// feeding a node 1000 back from the end (when the chain is long enough) and
// targeting the last node.
static void BM_Subgraph(int iters, int num_nodes) {
  DeviceAttributes device_info;
  device_info.set_name("/job:a/replica:0/task:0/cpu:0");
  device_info.set_device_type(DeviceType(DEVICE_CPU).type());
  device_info.set_incarnation(0);

  testing::StopTiming();
  Graph g(OpRegistry::Global());
  {  // Scope for temporary variables used to construct g.
    GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
    Node* last_node = nullptr;
    for (int i = 0; i < num_nodes; i++) {
      string name = strings::StrCat("N", i);
      if (i > 0) {
        last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name));
      } else {
        last_node = ops::SourceOp("In", b.opts().WithName(name));
      }
    }
    TF_CHECK_OK(b.ToGraph(&g));
  }

  std::vector<string> fed;
  if (num_nodes > 1000) {
    fed.push_back(strings::StrCat("N", num_nodes - 1000));
  }
  std::vector<string> fetch;
  std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)};
  testing::StartTiming();
  while (--iters > 0) {
    Graph* subgraph = new Graph(OpRegistry::Global());
    CopyGraph(g, subgraph);
    TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
                                                   targets, device_info));
    delete subgraph;
  }
}
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);

}  // namespace
}  // namespace tensorflow