diff options
Diffstat (limited to 'tensorflow/core/graph/subgraph_test.cc')
-rw-r--r-- | tensorflow/core/graph/subgraph_test.cc | 305 |
1 files changed, 305 insertions, 0 deletions
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc new file mode 100644 index 0000000000..ffb3e6e403 --- /dev/null +++ b/tensorflow/core/graph/subgraph_test.cc @@ -0,0 +1,305 @@ +#include "tensorflow/core/graph/subgraph.h" + +#include <string> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/status.h" +#include <gtest/gtest.h> + +// TODO(josh11b): Test setting the "device" field of a NodeDef. +// TODO(josh11b): Test that feeding won't prune targets. + +namespace tensorflow { +namespace { + +class SubgraphTest : public ::testing::Test { + protected: + SubgraphTest() : g_(new Graph(OpRegistry::Global())) { + RequireDefaultOps(); + device_info_.set_name("/job:a/replica:0/task:0/cpu:0"); + device_info_.set_device_type(DeviceType(DEVICE_CPU).type()); + device_info_.set_incarnation(0); + } + + ~SubgraphTest() override {} + + void ExpectOK(const string& gdef_ascii) { + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_)); + GraphConstructorOptions opts; + TF_CHECK_OK(ConvertGraphDefToGraph(opts, gdef_, g_.get())); + } + + Node* FindNode(const string& name) { + for (Node* n : g_->nodes()) { + if (n->name() == name) return n; + } + return nullptr; + } + + bool HasNode(const string& name) { return FindNode(name) != nullptr; } + + void ExpectNodes(const string& nodes) { + int count = 0; + std::vector<string> actual_nodes; + for (Node* n : g_->nodes()) { + if (n->IsOp()) { + count++; + actual_nodes.push_back(n->name()); + } + } + std::sort(actual_nodes.begin(), actual_nodes.end()); + + LOG(INFO) << "Nodes present: " << str_util::Join(actual_nodes, " "); + + std::vector<string> expected_nodes = str_util::Split(nodes, ','); + std::sort(expected_nodes.begin(), expected_nodes.end()); + for (const string& s : expected_nodes) { + Node* n = FindNode(s); + EXPECT_TRUE(n != nullptr) << s; + if (n->def().op() == "_Send" || n->def().op() == "_Recv") { + EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; + } + } + + EXPECT_TRUE(actual_nodes.size() == expected_nodes.size()) + << "\nActual: " << str_util::Join(actual_nodes, ",") + << "\nExpected: " << str_util::Join(expected_nodes, ","); + } + + bool HasEdge(const string& src, int src_out, const string& dst, int dst_in) { + for (const Edge* e : g_->edges()) { + if (e->src()->name() == src && e->src_output() == src_out && + e->dst()->name() == dst && e->dst_input() == dst_in) + return true; + } + return false; + } + bool HasControlEdge(const string& src, const string& dst) { + return HasEdge(src, Graph::kControlSlot, dst, Graph::kControlSlot); + } + + string Subgraph(const string& fed_str, const string& fetch_str, + const string& targets_str) { + Graph* subgraph = new Graph(OpRegistry::Global()); + CopyGraph(*g_, subgraph); + std::vector<string> fed = + str_util::Split(fed_str, ',', str_util::SkipEmpty()); + std::vector<string> fetch = + str_util::Split(fetch_str, ',', str_util::SkipEmpty()); + std::vector<string> targets = + str_util::Split(targets_str, ',', str_util::SkipEmpty()); + + Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, + targets, device_info_); + if (!s.ok()) { + delete subgraph; + return s.ToString(); + } + + // Replace the graph with the subgraph for the rest of the display program + g_.reset(subgraph); + return "OK"; + } + + Graph* graph() { return g_.get(); } + + private: + GraphDef gdef_; + std::unique_ptr<Graph> g_; + DeviceAttributes device_info_; +}; + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestRelu").Input("i: float").Output("o: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); + +TEST_F(SubgraphTest, Targets1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "", "t1")); + ExpectNodes("W1,input,t1"); +} + +TEST_F(SubgraphTest, Targets2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: 'W1' input: 'input:1' }" + "node { name: 't2' op: 'TestMul' input: 'W2' input: 't1' }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "", "t2,t3_a")); + ExpectNodes("W1,W2,input,t1,t2,t3_a"); +} + +TEST_F(SubgraphTest, FedOutputs1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("input:1", "", "t2")); + ExpectNodes("W1,W2,_recv_input_1,t1,t2"); +} + +TEST_F(SubgraphTest, FedRefNode) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }"); + EXPECT_EQ("OK", Subgraph("W1:0", "", "t1")); + ExpectNodes("_recv_W1_0,W2,t1"); + Node* n = FindNode("_recv_W1_0"); + EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0))); +} + +TEST_F(SubgraphTest, FedOutputs2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + // We feed input:1, but nothing connects to it, so the _recv(input:1) + // node also disappears. + EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2")); + ExpectNodes("_recv_t1_0,_recv_W2_0,t2"); +} + +TEST_F(SubgraphTest, FetchOutputs1) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2")); + ExpectNodes( + "W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0"); +} + +TEST_F(SubgraphTest, FetchOutputs2) { + ExpectOK( + "node { name: 'W1' op: 'TestParams' }" + "node { name: 'W2' op: 'TestParams' }" + "node { name: 'input' op: 'TestInput' }" + "node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }" + "node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }" + "node { name: 't3_a' op: 'TestRelu' input: 't2' }" + "node { name: 't3_b' op: 'TestRelu' input: 't2' }"); + EXPECT_EQ("OK", Subgraph("", "t3_a", "t2")); + ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0"); +} + +TEST_F(SubgraphTest, ChainOfFools) { + ExpectOK( + "node { name: 'a' op: 'TestParams' }" + "node { name: 'b' op: 'TestRelu' input: 'a'}" + "node { name: 'c' op: 'TestRelu' input: 'b'}" + "node { name: 'd' op: 'TestRelu' input: 'c'}" + "node { name: 'e' op: 'TestRelu' input: 'd'}" + "node { name: 'f' op: 'TestRelu' input: 'e'}"); + EXPECT_EQ("OK", Subgraph("c:0", "b:0,e:0", "")); + ExpectNodes("a,b,_send_b_0,_recv_c_0,d,e,_send_e_0"); + EXPECT_TRUE(HasEdge("a", 0, "b", 0)); + EXPECT_TRUE(HasEdge("b", 0, "_send_b_0", 0)); + EXPECT_TRUE(HasEdge("_recv_c_0", 0, "d", 0)); + EXPECT_TRUE(HasEdge("d", 0, "e", 0)); + EXPECT_TRUE(HasEdge("e", 0, "_send_e_0", 0)); +} + +static bool HasSubstr(const string& base, const string& substr) { + bool ok = StringPiece(base).contains(substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +TEST_F(SubgraphTest, Errors) { + ExpectOK( + "node { name: 'a' op: 'TestParams' }" + "node { name: 'b' op: 'TestRelu' input: 'a'}" + "node { name: 'c' op: 'TestRelu' input: 'b'}" + "node { name: 'd' op: 'TestRelu' input: 'c'}" + "node { name: 'e' op: 'TestRelu' input: 'd'}" + "node { name: 'f' op: 'TestRelu' input: 'e'}"); + // Duplicated feed and fetch + EXPECT_TRUE( + HasSubstr(Subgraph("c:0", "b:0,c:0", ""), "both fed and fetched")); + // Feed not found. + EXPECT_TRUE(HasSubstr(Subgraph("foo:0", "", ""), "unable to find")); + // Fetch not found. + EXPECT_TRUE(HasSubstr(Subgraph("", "foo:0", ""), "not found")); + // Target not found. + EXPECT_TRUE(HasSubstr(Subgraph("", "", "foo"), "not found")); +} + +REGISTER_OP("In").Output("o: float"); +REGISTER_OP("Op").Input("i: float").Output("o: float"); + +static void BM_Subgraph(int iters, int num_nodes) { + DeviceAttributes device_info; + device_info.set_name("/job:a/replica:0/task:0/cpu:0"); + device_info.set_device_type(DeviceType(DEVICE_CPU).type()); + device_info.set_incarnation(0); + + testing::StopTiming(); + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* last_node = nullptr; + for (int i = 0; i < num_nodes; i++) { + string name = strings::StrCat("N", i); + if (i > 0) { + last_node = ops::UnaryOp("Op", last_node, b.opts().WithName(name)); + } else { + last_node = ops::SourceOp("In", b.opts().WithName(name)); + } + } + TF_CHECK_OK(b.ToGraph(&g)); + } + + std::vector<string> fed; + if (num_nodes > 1000) { + fed.push_back(strings::StrCat("N", num_nodes - 1000)); + } + std::vector<string> fetch; + std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)}; + testing::StartTiming(); + while (--iters > 0) { + Graph* subgraph = new Graph(OpRegistry::Global()); + CopyGraph(g, subgraph); + TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch, + targets, device_info)); + delete subgraph; + } +} +BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); + +} // namespace +} // namespace tensorflow |