aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2017-08-24 16:00:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-24 16:04:29 -0700
commitb2ce451502154bcf13723e719ed2bfaed2814b15 (patch)
tree4808cc4b0c2a5a8e9812f7f370afe2442c430784
parent0a2f40e92d7f9a81b482027fd38c1bc429369abd (diff)
Make Graph::IsValidNode public
It can be reimplemented with existing public APIs, but instead of doing so, making this one public seems better. PiperOrigin-RevId: 166407897
-rw-r--r--tensorflow/core/graph/graph.cc6
-rw-r--r--tensorflow/core/graph/graph.h4
-rw-r--r--tensorflow/core/graph/graph_test.cc31
3 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 241e093eea..7d938365c5 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -503,17 +503,17 @@ string Graph::NewName(StringPiece prefix) {
return strings::StrCat(prefix, "/_", name_counter_++);
}
-Status Graph::IsValidNode(Node* node) const {
+Status Graph::IsValidNode(const Node* node) const {
if (node == nullptr) {
return errors::InvalidArgument("Node is null");
}
const int id = node->id();
if (id < 0) {
- return errors::InvalidArgument("node id ", id, "is less than zero");
+ return errors::InvalidArgument("node id ", id, " is less than zero");
}
if (static_cast<size_t>(id) >= nodes_.size()) {
return errors::InvalidArgument(
- "node id ", id, "is >= than number of nodes in graph ", nodes_.size());
+ "node id ", id, " is >= than number of nodes in graph ", nodes_.size());
}
if (nodes_[id] != node) {
return errors::InvalidArgument("Node with id ", id,
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 06461435bc..bd388d9065 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -516,10 +516,12 @@ class Graph {
node->assigned_device_name_index_ = InternDeviceName(device_name);
}
+ // Returns OK if `node` is non-null and belongs to this graph
+ Status IsValidNode(const Node* node) const;
+
// TODO(josh11b): uint64 hash() const;
private:
- Status IsValidNode(Node* node) const;
// If cost_node is non-null, then cost accounting (in CostModel)
// will be associated with that node rather than the new one being
// created.
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
index 68848ae8c8..ca77f3b44d 100644
--- a/tensorflow/core/graph/graph_test.cc
+++ b/tensorflow/core/graph/graph_test.cc
@@ -379,6 +379,37 @@ TEST_F(GraphTest, NewName) {
EXPECT_TRUE(StringPiece(a1).starts_with("A")) << a1;
}
+TEST_F(GraphTest, IsValidNode) {
+ // Add 1 node to graph_
+ Node* g1_node1;
+ TF_CHECK_OK(NodeBuilder("g1_node1", "NoOp").Finalize(&graph_, &g1_node1));
+
+ // Add 2 nodes to graph2
+ Graph graph2(OpRegistry::Global());
+ Node* g2_node1;
+ Node* g2_node2;
+ TF_CHECK_OK(NodeBuilder("g2_node1", "NoOp").Finalize(&graph2, &g2_node1));
+ TF_CHECK_OK(NodeBuilder("g2_node2", "NoOp").Finalize(&graph2, &g2_node2));
+
+ // nullptr
+ Status s = graph_.IsValidNode(nullptr);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_EQ(string("Node is null"), s.error_message());
+
+ // node id_ is too high
+ s = graph_.IsValidNode(g2_node2);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_EQ(string("node id 3 is >= than number of nodes in graph 3"),
+ s.error_message());
+
+ // valid id_ but different ptr
+ s = graph_.IsValidNode(g2_node1);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_EQ(string("Node with id 2 is different from the passed in node. "
+ "Does it belong to a different graph?"),
+ s.error_message());
+}
+
TEST_F(GraphTest, InputEdges) {
Node* a = FromNodeDef("A", "OneOutput", 0);
Node* b = FromNodeDef("B", "TwoInputsOneOutput", 2);