diff options
author | 2017-08-24 16:00:28 -0700 | |
---|---|---|
committer | 2017-08-24 16:04:29 -0700 | |
commit | b2ce451502154bcf13723e719ed2bfaed2814b15 (patch) | |
tree | 4808cc4b0c2a5a8e9812f7f370afe2442c430784 | |
parent | 0a2f40e92d7f9a81b482027fd38c1bc429369abd (diff) |
Make Graph::IsValidNode public
It can be reimplemented with existing public APIs, but instead of doing so,
making this one public seems better.
PiperOrigin-RevId: 166407897
-rw-r--r-- | tensorflow/core/graph/graph.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.h | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_test.cc | 31 |
3 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 241e093eea..7d938365c5 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -503,17 +503,17 @@ string Graph::NewName(StringPiece prefix) { return strings::StrCat(prefix, "/_", name_counter_++); } -Status Graph::IsValidNode(Node* node) const { +Status Graph::IsValidNode(const Node* node) const { if (node == nullptr) { return errors::InvalidArgument("Node is null"); } const int id = node->id(); if (id < 0) { - return errors::InvalidArgument("node id ", id, "is less than zero"); + return errors::InvalidArgument("node id ", id, " is less than zero"); } if (static_cast<size_t>(id) >= nodes_.size()) { return errors::InvalidArgument( - "node id ", id, "is >= than number of nodes in graph ", nodes_.size()); + "node id ", id, " is >= than number of nodes in graph ", nodes_.size()); } if (nodes_[id] != node) { return errors::InvalidArgument("Node with id ", id, diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 06461435bc..bd388d9065 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -516,10 +516,12 @@ class Graph { node->assigned_device_name_index_ = InternDeviceName(device_name); } + // Returns OK if `node` is non-null and belongs to this graph + Status IsValidNode(const Node* node) const; + // TODO(josh11b): uint64 hash() const; private: - Status IsValidNode(Node* node) const; // If cost_node is non-null, then cost accounting (in CostModel) // will be associated with that node rather than the new one being // created. diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 68848ae8c8..ca77f3b44d 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -379,6 +379,37 @@ TEST_F(GraphTest, NewName) { EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1; } +TEST_F(GraphTest, IsValidNode) { + // Add 1 node to graph_ + Node* g1_node1; + TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_node1)); + + // Add 2 nodes to graph2 + Graph graph2(OpRegistry::Global()); + Node* g2_node1; + Node* g2_node2; + TF_CHECK_OK(NodeBuilder("g2_node1", "NoOp").Finalize(&graph2, &g2_node1)); + TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&graph2, &g2_node2)); + + // nullptr + Status s = graph_.IsValidNode(nullptr); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_EQ(string("Node is null"), s.error_message()); + + // node id_ is too high + s = graph_.IsValidNode(g2_node2); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"), + s.error_message()); + + // valid id_ but different ptr + s = graph_.IsValidNode(g2_node1); + EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); + EXPECT_EQ(string("Node with id 2 is different from the passed in node. " + "Does it belong to a different graph?"), + s.error_message()); +} + TEST_F(GraphTest, InputEdges) { Node* a = FromNodeDef("A", "OneOutput", 0); Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2); |