author     Geoffrey Irving <geoffreyi@google.com>           2017-05-16 16:08:20 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-05-16 16:12:05 -0700
commit     749e5cc18381f7a5ec174673f76e20aead8529c6
tree       4b92d36c9e1d8e59e34fd8d08e7f11fbda1315d9 /tensorflow/core/graph
parent     ed5d05d8b53425ef98aad129a60143a5011a4288
Reduce direct references to NodeDef in favor of Node and AttrSlice
This is one step towards replacing in-memory use of NodeDef with a customized
NodeInfo class. There are still quite a few Node::def() references, but far
fewer than before. The remaining references require more work, either because
they are part of kernel registration (which spans a number of functions) or
because they copy and modify the NodeDef. Follow-on CLs will remove more.

RELNOTES: n/a
PiperOrigin-RevId: 156244933
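The recurring pattern in this CL is to route attribute lookups through the new
Node::attrs() accessor instead of passing the full NodeDef. A minimal sketch of
the before/after, assuming a valid Node* obtained from an existing Graph
(ReadFrameName is a hypothetical helper, not code from this CL):

    #include "tensorflow/core/framework/node_def_util.h"  // GetNodeAttr, AttrSlice
    #include "tensorflow/core/graph/graph.h"              // Node

    tensorflow::Status ReadFrameName(const tensorflow::Node* node,
                                     std::string* frame_name) {
      // Old style: pass the whole NodeDef to GetNodeAttr.
      //   return GetNodeAttr(node->def(), "frame_name", frame_name);
      // New style: AttrSlice is a cheap, read-only view over the attrs.
      return tensorflow::GetNodeAttr(node->attrs(), "frame_name", frame_name);
    }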
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--  tensorflow/core/graph/control_flow.cc           |  2
-rw-r--r--  tensorflow/core/graph/graph.cc                  |  4
-rw-r--r--  tensorflow/core/graph/graph.h                   | 14
-rw-r--r--  tensorflow/core/graph/graph_constructor.cc      |  4
-rw-r--r--  tensorflow/core/graph/graph_constructor_test.cc | 66
-rw-r--r--  tensorflow/core/graph/graph_partition.cc        |  8
-rw-r--r--  tensorflow/core/graph/graph_test.cc             | 10
-rw-r--r--  tensorflow/core/graph/optimizer_cse.cc          | 37
-rw-r--r--  tensorflow/core/graph/quantize_training.cc      |  6
-rw-r--r--  tensorflow/core/graph/quantize_training_test.cc | 12
-rw-r--r--  tensorflow/core/graph/subgraph.cc               |  6
-rw-r--r--  tensorflow/core/graph/subgraph_test.cc          |  4
12 files changed, 80 insertions, 93 deletions
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc
index 8409fb4cd0..db6683d1e7 100644
--- a/tensorflow/core/graph/control_flow.cc
+++ b/tensorflow/core/graph/control_flow.cc
@@ -88,7 +88,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) {
out_info->frame = out;
out_info->parent_frame = frame;
TF_RETURN_IF_ERROR(
- GetNodeAttr(out->def(), "frame_name", &out_info->frame_name));
+ GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name));
if (out_info->frame_name.empty()) {
return errors::InvalidArgument("The Enter node ", out->name(),
" must have a frame name.");
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index d765959ca0..9066de5668 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -78,7 +78,7 @@ string Node::DebugString() const {
} else {
strings::StrAppend(&ret, " op device:");
strings::StrAppend(&ret, "{", assigned_device_name_, "}");
- strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}");
+ strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}");
}
return ret;
}
@@ -474,7 +474,7 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
for (size_t i = 0; i < inputs.size(); ++i) {
const Edge* edge = inputs[i];
if (edge == nullptr) {
- node_def->add_input(node->def().input(i));
+ node_def->add_input(node->requested_inputs()[i]);
} else {
const Node* src = edge->src();
if (!src->IsOp()) continue;
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index ac22dfc324..8554cb2f4b 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -71,6 +71,7 @@ class Node {
int cost_id() const { return cost_id_; }
const string& name() const { return props_->node_def_.name(); }
const string& type_string() const { return props_->node_def_.op(); }
+
// def() provides the NodeDef the user supplied, but the specifics
// of this Node may have changed due to placement, optimization, etc.
// In particular:
@@ -80,6 +81,7 @@ class Node {
// * def().device() is the "user's requested device" and may not match
// the actual assigned device, see assigned_device_name() below;
// * def().attr() is authoritative.
+ // TODO(irving): Replace with NodeInfo.
const NodeDef& def() const { return props_->node_def_; }
const OpDef& op_def() const { return *props_->op_def_; }
@@ -92,6 +94,10 @@ class Node {
DataType output_type(int32 o) const { return props_->output_types_[o]; }
const DataTypeVector& output_types() const { return props_->output_types_; }
+ // The device requested by the user. For the actual assigned device,
+ // use assigned_device_name() below.
+ const string& requested_device() const { return def().device(); }
+
// This gives the device the runtime has assigned this node to. If
// you want the device the user requested, use def().device() instead.
// TODO(josh11b): Validate that the assigned_device, if not empty:
@@ -103,6 +109,14 @@ class Node {
assigned_device_name_ = device_name;
}
+ // Read only access to attributes
+ AttrSlice attrs() const { return AttrSlice(def()); }
+
+ // Inputs requested by the NodeDef. For the actual inputs, use in_edges.
+ const protobuf::RepeatedPtrField<string>& requested_inputs() const {
+ return def().input();
+ }
+
// Get the neighboring nodes via edges either in or out of this node.
gtl::iterator_range<NeighborIter> in_nodes() const;
gtl::iterator_range<NeighborIter> out_nodes() const;
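For reference, a short sketch of the accessors this hunk adds, assuming n comes
from an existing Graph (DescribeNode is a hypothetical illustration, not code
from this CL):

    #include "tensorflow/core/framework/node_def_util.h"  // AttrSlice
    #include "tensorflow/core/graph/graph.h"              // Node
    #include "tensorflow/core/platform/logging.h"

    void DescribeNode(const tensorflow::Node* n) {
      // The device the user asked for, vs. the one the runtime assigned.
      LOG(INFO) << n->name() << ": requested device '" << n->requested_device()
                << "', assigned device '" << n->assigned_device_name() << "'";

      // Read-only attribute access; Find() returns nullptr when absent.
      if (n->attrs().Find("T") != nullptr) {
        LOG(INFO) << n->name() << " carries a \"T\" type attr";
      }

      // Inputs as written in the NodeDef; actual connectivity is in_edges().
      for (const std::string& input : n->requested_inputs()) {
        LOG(INFO) << n->name() << " requested input " << input;
      }
    }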
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 9d4a0a52f7..70087b8fe1 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -424,7 +424,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
// For nodes with the _output_shapes atttribute, override the shape.
std::vector<TensorShapeProto> shape_attrs;
const char* kAttrName = "_output_shapes";
- if (!GetNodeAttr(node->def(), kAttrName, &shape_attrs).ok()) {
+ if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) {
// No _output_shapes attribute, the AddNode call above was sufficient.
return Status::OK();
}
@@ -458,7 +458,7 @@ Status GraphConstructor::ValidateShape(Node* node) {
// functions that are not critical to correct execution but
// would cause graphs to fail if imported after correcting.
//
- const string& op = node->def().op();
+ const string& op = node->type_string();
const std::vector<string> whitelist = {
// To be removed after 2017/03/08.
"RandomShuffleQueue", "PaddingFIFOQueue", "FIFOQueue",
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index e3b7f322cb..6013b2ff51 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -146,7 +146,7 @@ class GraphConstructorTest : public ::testing::Test {
return "";
}
std::vector<string> value;
- Status s = GetNodeAttr(n->def(), kColocationAttrName, &value);
+ Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value);
if (!s.ok()) {
return "";
}
@@ -997,7 +997,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) {
}
ASSERT_TRUE(a != nullptr);
int value = 0;
- s = GetNodeAttr(a->def(), "default_int", &value);
+ s = GetNodeAttr(a->attrs(), "default_int", &value);
ASSERT_EQ(Status::OK(), s) << s << " -- " << a->def().DebugString();
EXPECT_EQ(31415, value);
}
@@ -1201,9 +1201,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) {
// Check that t1's NodeDef is consistent with graph
Node* t1 = FindNode("t1");
- ASSERT_EQ(t1->def().input_size(), 2);
- ASSERT_EQ(t1->def().input(0), "input:1");
- ASSERT_EQ(t1->def().input(1), "input:0");
+ ASSERT_EQ(t1->requested_inputs().size(), 2);
+ ASSERT_EQ(t1->requested_inputs()[0], "input:1");
+ ASSERT_EQ(t1->requested_inputs()[1], "input:0");
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
@@ -1254,19 +1254,19 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) {
// Check that NodeDefs are consistent with graph
Node* t1 = FindNode("import/t1");
- ASSERT_EQ(t1->def().input_size(), 2);
- EXPECT_EQ(t1->def().input(0), "input:0");
- EXPECT_EQ(t1->def().input(1), "input:0");
+ ASSERT_EQ(t1->requested_inputs().size(), 2);
+ EXPECT_EQ(t1->requested_inputs()[0], "input:0");
+ EXPECT_EQ(t1->requested_inputs()[1], "input:0");
Node* t2 = FindNode("import/t2");
- ASSERT_EQ(t2->def().input_size(), 2);
- EXPECT_EQ(t2->def().input(0), "import/t1:0");
- EXPECT_EQ(t2->def().input(1), "import/t1:0");
+ ASSERT_EQ(t2->requested_inputs().size(), 2);
+ EXPECT_EQ(t2->requested_inputs()[0], "import/t1:0");
+ EXPECT_EQ(t2->requested_inputs()[1], "import/t1:0");
Node* t3 = FindNode("import/t3");
- ASSERT_EQ(t3->def().input_size(), 2);
- EXPECT_EQ(t3->def().input(0), "import/unmapped_input:0");
- EXPECT_EQ(t3->def().input(1), "import/unmapped_input:1");
+ ASSERT_EQ(t3->requested_inputs().size(), 2);
+ EXPECT_EQ(t3->requested_inputs()[0], "import/unmapped_input:0");
+ EXPECT_EQ(t3->requested_inputs()[1], "import/unmapped_input:1");
}
TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) {
@@ -1795,24 +1795,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) {
// Test that node defs are consistent with graph
Node* w1 = FindNode("import/W1");
- ASSERT_EQ(w1->def().input_size(), 2);
- EXPECT_EQ(w1->def().input(0), "^W1");
- EXPECT_EQ(w1->def().input(1), "^W2");
+ ASSERT_EQ(w1->requested_inputs().size(), 2);
+ EXPECT_EQ(w1->requested_inputs()[0], "^W1");
+ EXPECT_EQ(w1->requested_inputs()[1], "^W2");
Node* input = FindNode("import/input");
- ASSERT_EQ(input->def().input_size(), 2);
- EXPECT_EQ(input->def().input(0), "^W1");
- EXPECT_EQ(input->def().input(1), "^W2");
+ ASSERT_EQ(input->requested_inputs().size(), 2);
+ EXPECT_EQ(input->requested_inputs()[0], "^W1");
+ EXPECT_EQ(input->requested_inputs()[1], "^W2");
Node* input2 = FindNode("import/input2");
- ASSERT_EQ(input2->def().input_size(), 2);
- EXPECT_EQ(input2->def().input(0), "^W1");
- EXPECT_EQ(input2->def().input(1), "^W2");
+ ASSERT_EQ(input2->requested_inputs().size(), 2);
+ EXPECT_EQ(input2->requested_inputs()[0], "^W1");
+ EXPECT_EQ(input2->requested_inputs()[1], "^W2");
Node* t1 = FindNode("import/t1");
- ASSERT_EQ(t1->def().input_size(), 2);
- EXPECT_EQ(t1->def().input(0), "import/input:0");
- EXPECT_EQ(t1->def().input(1), "import/input:1");
+ ASSERT_EQ(t1->requested_inputs().size(), 2);
+ EXPECT_EQ(t1->requested_inputs()[0], "import/input:0");
+ EXPECT_EQ(t1->requested_inputs()[1], "import/input:1");
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
@@ -1856,15 +1856,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) {
// Test that node defs are consistent with graph
Node* merge = FindNode("merge");
- ASSERT_EQ(merge->def().input_size(), 3);
- EXPECT_EQ(merge->def().input(0), "input:0");
- EXPECT_EQ(merge->def().input(1), "t1:0");
- EXPECT_EQ(merge->def().input(2), "^W1");
+ ASSERT_EQ(merge->requested_inputs().size(), 3);
+ EXPECT_EQ(merge->requested_inputs()[0], "input:0");
+ EXPECT_EQ(merge->requested_inputs()[1], "t1:0");
+ EXPECT_EQ(merge->requested_inputs()[2], "^W1");
Node* t1 = FindNode("t1");
- ASSERT_EQ(t1->def().input_size(), 2);
- EXPECT_EQ(t1->def().input(0), "merge:0");
- EXPECT_EQ(t1->def().input(1), "merge:0");
+ ASSERT_EQ(t1->requested_inputs().size(), 2);
+ EXPECT_EQ(t1->requested_inputs()[0], "merge:0");
+ EXPECT_EQ(t1->requested_inputs()[1], "merge:0");
}
TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) {
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index c7ad6a1e77..57a2f399e0 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -356,7 +356,7 @@ string ControlLoopName(const string& name) {
}
bool IsControlLoop(const Node* node) {
- const string& name = node->def().name();
+ const string& name = node->name();
return StringPiece(name).starts_with("_cloop");
}
@@ -468,7 +468,7 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
const string& device_name = edge->dst()->assigned_device_name();
const string& frame_name = src_info.frame_name;
int parallel_iterations;
- status = GetNodeAttr(src_info.frame->def(), "parallel_iterations",
+ status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
&parallel_iterations);
if (!status.ok()) return status;
@@ -903,11 +903,11 @@ Status Partition(const PartitionOptions& opts, Graph* g,
send_start_time = opts.start_times[src->id()].value();
recv_start_time = opts.start_times[dst->id()].value();
} else {
- status = GetNodeAttr(src->def(), "_start_time", &send_start_time);
+ status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
if (!status.ok()) {
return status;
}
- status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time);
+ status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
if (!status.ok()) {
return status;
}
diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc
index 4afc878f76..89784c631f 100644
--- a/tensorflow/core/graph/graph_test.cc
+++ b/tensorflow/core/graph/graph_test.cc
@@ -318,21 +318,21 @@ TEST_F(GraphTest, AddAttr) {
n1->AddAttr("_a", "new_attr");
string attr;
- EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr));
+ EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr));
EXPECT_EQ("new_attr", attr);
Node* n2 = graph_.CopyNode(n1);
n1->AddAttr("_b", "new_attr_2");
- EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr));
+ EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr));
EXPECT_EQ("new_attr", attr);
- EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_b", &attr));
+ EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_b", &attr));
EXPECT_EQ("new_attr_2", attr);
- EXPECT_EQ(Status::OK(), GetNodeAttr(n2->def(), "_a", &attr));
+ EXPECT_EQ(Status::OK(), GetNodeAttr(n2->attrs(), "_a", &attr));
EXPECT_EQ("new_attr", attr);
- EXPECT_NE(Status::OK(), GetNodeAttr(n2->def(), "_b", &attr));
+ EXPECT_NE(Status::OK(), GetNodeAttr(n2->attrs(), "_b", &attr));
}
// Convert edge iteration results into a sorted string.
diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc
index a679eac0e7..a22a9b3fa3 100644
--- a/tensorflow/core/graph/optimizer_cse.cc
+++ b/tensorflow/core/graph/optimizer_cse.cc
@@ -56,11 +56,9 @@ class OptimizerCSE {
bool Optimize(const std::function<bool(const Node*)>& consider_fn);
private:
- struct Scratch;
-
static size_t NodeHash(const Node* n);
- static bool Equivalent(const Node* a, const Node* b, Scratch* s);
- static bool EqualAttrs(const Node* a, const Node* b, Scratch* s);
+ static bool Equivalent(const Node* a, const Node* b,
+ AttrSlice::Scratch* scratch);
Graph* g_;
};
@@ -110,7 +108,7 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
// Hash the attrs. For example, this makes sure different constants
// end up in different hash buckets.
string tmp;
- for (const auto& attr : n->def().attr()) {
+ for (const auto& attr : n->attrs()) {
tmp = attr.first;
attr.second.AppendToString(&tmp);
// Add hashes of attrs, so the order of attrs doesn't matter.
@@ -122,28 +120,6 @@ size_t OptimizerCSE::NodeHash(const Node* n) {
return h;
}
-struct OptimizerCSE::Scratch {
- // For EqualAttrs():
- string a;
- string b;
-};
-
-bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) {
- if (a->def().attr_size() != b->def().attr_size()) return false;
-
- for (const auto& attr : b->def().attr()) {
- auto iter = a->def().attr().find(attr.first);
- if (iter == a->def().attr().end()) return false;
- // Note: it should be safe to compare proto serializations of the attr
- // values since at most one field should be set in each (indeed, it
- // should be the same field).
- iter->second.SerializeToString(&scratch->a);
- attr.second.SerializeToString(&scratch->b);
- if (scratch->a != scratch->b) return false;
- }
- return true;
-}
-
static bool HasRefInput(const Node* n) {
for (auto dt : n->input_types()) {
if (IsRefType(dt)) return true;
@@ -151,7 +127,8 @@ static bool HasRefInput(const Node* n) {
return false;
}
-bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
+bool OptimizerCSE::Equivalent(const Node* a, const Node* b,
+ AttrSlice::Scratch* scratch) {
// Different op names are different
if (a->type_string() != b->type_string()) return false;
@@ -164,7 +141,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) {
// Compare attrs. Note that equal attrs implies equal input and
// output types.
- if (!EqualAttrs(a, b, scratch)) return false;
+ if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false;
// Compare input sources
if (a->num_inputs() != b->num_inputs()) return false;
@@ -206,7 +183,7 @@ bool OptimizerCSE::Optimize(
// Scratch space for Equivalent calls. Allocated here and passed in to
// Equivalent to avoid allocation inside the loop below.
bool changed = false;
- Scratch scratch;
+ AttrSlice::Scratch scratch;
for (Node* n : order) {
if (!n->IsOp()) continue;
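The private Scratch struct and EqualAttrs removed above are subsumed by
AttrSlice::EqualAttrs, which keeps the same serialized-comparison semantics. A
minimal sketch of calling the shared helper (SameAttrs is hypothetical):

    #include "tensorflow/core/framework/node_def_util.h"  // AttrSlice
    #include "tensorflow/core/graph/graph.h"

    bool SameAttrs(const tensorflow::Node* a, const tensorflow::Node* b) {
      // The scratch strings are reused across attr comparisons; Optimize()
      // hoists one instance out of its node loop for the same reason.
      tensorflow::AttrSlice::Scratch scratch;
      return a->attrs().EqualAttrs(b->attrs(), &scratch);
    }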
diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc
index e3ef5e2f0c..4a479d3258 100644
--- a/tensorflow/core/graph/quantize_training.cc
+++ b/tensorflow/core/graph/quantize_training.cc
@@ -192,9 +192,9 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
Tensor tensor_names;
Tensor shape_and_slices;
TF_RETURN_IF_ERROR(
- GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names));
- TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value",
- &shape_and_slices));
+ GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
int tn_size = tensor_names.NumElements();
int var_size = added_variables.size();
diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc
index 9cbb928c11..d817d980de 100644
--- a/tensorflow/core/graph/quantize_training_test.cc
+++ b/tensorflow/core/graph/quantize_training_test.cc
@@ -112,17 +112,15 @@ TEST_F(QuantizeTrainingTest, SignedInput) {
TF_ASSERT_OK(
FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"),
&identity_q_node));
- NodeDef identity_q = identity_q_node->def();
ASSERT_EQ("true",
- SummarizeAttrValue(identity_q.attr().find("signed_input")->second));
+ SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input")));
// Quantize_and_dequantize node for relu should have signed_input==false.
Node* relu_q_node;
TF_ASSERT_OK(
FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
&relu_q_node));
- NodeDef relu_q = relu_q_node->def();
ASSERT_EQ("false",
- SummarizeAttrValue(relu_q.attr().find("signed_input")->second));
+ SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input")));
}
TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
@@ -165,17 +163,15 @@ TEST_F(QuantizeTrainingTest, RangeGivenTrue) {
TF_ASSERT_OK(
FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"),
&relu6_q_node));
- NodeDef identity_q = relu6_q_node->def();
ASSERT_EQ("true",
- SummarizeAttrValue(identity_q.attr().find("range_given")->second));
+ SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given")));
// Quantize_and_dequantize node for relu should have range_given==true.
Node* relu_q_node;
TF_ASSERT_OK(
FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"),
&relu_q_node));
- NodeDef relu_q = relu_q_node->def();
ASSERT_EQ("true",
- SummarizeAttrValue(relu_q.attr().find("range_given")->second));
+ SummarizeAttrValue(*relu_q_node->attrs().Find("range_given")));
}
TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) {
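The test rewrites above rely on AttrSlice::Find, which returns a pointer
(nullptr when the attr is absent) and so avoids copying the NodeDef just to
inspect one attribute. A sketch of the pattern, assuming a valid Node*
(AttrAsString is a hypothetical helper):

    #include "tensorflow/core/framework/attr_value_util.h"  // SummarizeAttrValue
    #include "tensorflow/core/graph/graph.h"                // Node, AttrSlice

    std::string AttrAsString(const tensorflow::Node* node, const char* name) {
      const tensorflow::AttrValue* value = node->attrs().Find(name);
      return value == nullptr ? "<missing>"
                              : tensorflow::SummarizeAttrValue(*value);
    }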
diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc
index 9849d9a159..e10b692889 100644
--- a/tensorflow/core/graph/subgraph.cc
+++ b/tensorflow/core/graph/subgraph.cc
@@ -106,7 +106,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
// Copy the _output_shapes from the original node to the feed node,
// if any.
std::vector<PartialTensorShape> output_shapes;
- if (GetNodeAttr(n->def(), "_output_shapes", &output_shapes).ok()) {
+ if (GetNodeAttr(n->attrs(), "_output_shapes", &output_shapes).ok()) {
if (n->num_outputs() != output_shapes.size()) {
return errors::InvalidArgument(
"FeedInputs: ", t,
@@ -129,8 +129,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
if (e->src_output() == id.second) {
to_remove.emplace_back(e);
} else if (e->src_output() == Graph::kControlSlot &&
- (n->def().op() == "Placeholder" ||
- n->def().op() == "PlaceholderV2")) {
+ (n->type_string() == "Placeholder" ||
+ n->type_string() == "PlaceholderV2")) {
// When feeding a Placeholder node, any outgoing control edges
// will be replaced with a control edge from the replacement
// recv_node.
diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc
index 3dc11b7a16..93dcfd5e33 100644
--- a/tensorflow/core/graph/subgraph_test.cc
+++ b/tensorflow/core/graph/subgraph_test.cc
@@ -81,7 +81,7 @@ class SubgraphTest : public ::testing::Test {
for (const string& s : expected_nodes) {
Node* n = FindNode(s);
EXPECT_TRUE(n != nullptr) << s;
- if (n->def().op() == "_Send" || n->def().op() == "_Recv") {
+ if (n->type_string() == "_Send" || n->type_string() == "_Recv") {
EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s;
}
}
@@ -367,7 +367,7 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
for (Node* node : graph()->nodes()) {
if (node->name() == "_recv_input_1") {
std::vector<PartialTensorShape> shapes;
- TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes));
+ TF_ASSERT_OK(GetNodeAttr(node->attrs(), "_output_shapes", &shapes));
ASSERT_EQ(1, shapes.size());
EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0]));
break;