diff options
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/control_flow.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/graph.h | 14 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_constructor_test.cc | 66 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 8 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_test.cc | 10 | ||||
-rw-r--r-- | tensorflow/core/graph/optimizer_cse.cc | 37 | ||||
-rw-r--r-- | tensorflow/core/graph/quantize_training.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/graph/quantize_training_test.cc | 12 | ||||
-rw-r--r-- | tensorflow/core/graph/subgraph.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/graph/subgraph_test.cc | 4 |
12 files changed, 80 insertions, 93 deletions
diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc index 8409fb4cd0..db6683d1e7 100644 --- a/tensorflow/core/graph/control_flow.cc +++ b/tensorflow/core/graph/control_flow.cc @@ -88,7 +88,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector<ControlFlowInfo>* info) { out_info->frame = out; out_info->parent_frame = frame; TF_RETURN_IF_ERROR( - GetNodeAttr(out->def(), "frame_name", &out_info->frame_name)); + GetNodeAttr(out->attrs(), "frame_name", &out_info->frame_name)); if (out_info->frame_name.empty()) { return errors::InvalidArgument("The Enter node ", out->name(), " must have a frame name."); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index d765959ca0..9066de5668 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -78,7 +78,7 @@ string Node::DebugString() const { } else { strings::StrAppend(&ret, " op device:"); strings::StrAppend(&ret, "{", assigned_device_name_, "}"); - strings::StrAppend(&ret, " def:{", SummarizeNodeDef(def()), "}}"); + strings::StrAppend(&ret, " def:{", SummarizeNode(*this), "}}"); } return ret; } @@ -474,7 +474,7 @@ void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const { for (size_t i = 0; i < inputs.size(); ++i) { const Edge* edge = inputs[i]; if (edge == nullptr) { - node_def->add_input(node->def().input(i)); + node_def->add_input(node->requested_inputs()[i]); } else { const Node* src = edge->src(); if (!src->IsOp()) continue; diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index ac22dfc324..8554cb2f4b 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -71,6 +71,7 @@ class Node { int cost_id() const { return cost_id_; } const string& name() const { return props_->node_def_.name(); } const string& type_string() const { return props_->node_def_.op(); } + // def() provides the NodeDef the user supplied, but the specifics // of this Node may have changed due to placement, optimization, etc. // In particular: @@ -80,6 +81,7 @@ class Node { // * def().device() is the "user's requested device" and may not match // the actual assigned device, see assigned_device_name() below; // * def().attr() is authoritative. + // TODO(irving): Replace with NodeInfo. const NodeDef& def() const { return props_->node_def_; } const OpDef& op_def() const { return *props_->op_def_; } @@ -92,6 +94,10 @@ class Node { DataType output_type(int32 o) const { return props_->output_types_[o]; } const DataTypeVector& output_types() const { return props_->output_types_; } + // The device requested by the user. For the actual assigned device, + // use assigned_device_name() below. + const string& requested_device() const { return def().device(); } + // This gives the device the runtime has assigned this node to. If // you want the device the user requested, use def().device() instead. // TODO(josh11b): Validate that the assigned_device, if not empty: @@ -103,6 +109,14 @@ class Node { assigned_device_name_ = device_name; } + // Read only access to attributes + AttrSlice attrs() const { return AttrSlice(def()); } + + // Inputs requested by the NodeDef. For the actual inputs, use in_edges. + const protobuf::RepeatedPtrField<string>& requested_inputs() const { + return def().input(); + } + // Get the neighboring nodes via edges either in or out of this node. gtl::iterator_range<NeighborIter> in_nodes() const; gtl::iterator_range<NeighborIter> out_nodes() const; diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 9d4a0a52f7..70087b8fe1 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -424,7 +424,7 @@ Status GraphConstructor::ValidateShape(Node* node) { // For nodes with the _output_shapes atttribute, override the shape. std::vector<TensorShapeProto> shape_attrs; const char* kAttrName = "_output_shapes"; - if (!GetNodeAttr(node->def(), kAttrName, &shape_attrs).ok()) { + if (!GetNodeAttr(node->attrs(), kAttrName, &shape_attrs).ok()) { // No _output_shapes attribute, the AddNode call above was sufficient. return Status::OK(); } @@ -458,7 +458,7 @@ Status GraphConstructor::ValidateShape(Node* node) { // functions that are not critical to correct execution but // would cause graphs to fail if imported after correcting. // - const string& op = node->def().op(); + const string& op = node->type_string(); const std::vector<string> whitelist = { // To be removed after 2017/03/08. "RandomShuffleQueue", "PaddingFIFOQueue", "FIFOQueue", diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index e3b7f322cb..6013b2ff51 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -146,7 +146,7 @@ class GraphConstructorTest : public ::testing::Test { return ""; } std::vector<string> value; - Status s = GetNodeAttr(n->def(), kColocationAttrName, &value); + Status s = GetNodeAttr(n->attrs(), kColocationAttrName, &value); if (!s.ok()) { return ""; } @@ -997,7 +997,7 @@ TEST_F(GraphConstructorTest, ImportGraphDef_DefaultAttrs) { } ASSERT_TRUE(a != nullptr); int value = 0; - s = GetNodeAttr(a->def(), "default_int", &value); + s = GetNodeAttr(a->attrs(), "default_int", &value); ASSERT_EQ(Status::OK(), s) << s << " -- " << a->def().DebugString(); EXPECT_EQ(31415, value); } @@ -1201,9 +1201,9 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMap) { // Check that t1's NodeDef is consistent with graph Node* t1 = FindNode("t1"); - ASSERT_EQ(t1->def().input_size(), 2); - ASSERT_EQ(t1->def().input(0), "input:1"); - ASSERT_EQ(t1->def().input(1), "input:0"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + ASSERT_EQ(t1->requested_inputs()[0], "input:1"); + ASSERT_EQ(t1->requested_inputs()[1], "input:0"); } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) { @@ -1254,19 +1254,19 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithPrefix) { // Check that NodeDefs are consistent with graph Node* t1 = FindNode("import/t1"); - ASSERT_EQ(t1->def().input_size(), 2); - EXPECT_EQ(t1->def().input(0), "input:0"); - EXPECT_EQ(t1->def().input(1), "input:0"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + EXPECT_EQ(t1->requested_inputs()[0], "input:0"); + EXPECT_EQ(t1->requested_inputs()[1], "input:0"); Node* t2 = FindNode("import/t2"); - ASSERT_EQ(t2->def().input_size(), 2); - EXPECT_EQ(t2->def().input(0), "import/t1:0"); - EXPECT_EQ(t2->def().input(1), "import/t1:0"); + ASSERT_EQ(t2->requested_inputs().size(), 2); + EXPECT_EQ(t2->requested_inputs()[0], "import/t1:0"); + EXPECT_EQ(t2->requested_inputs()[1], "import/t1:0"); Node* t3 = FindNode("import/t3"); - ASSERT_EQ(t3->def().input_size(), 2); - EXPECT_EQ(t3->def().input(0), "import/unmapped_input:0"); - EXPECT_EQ(t3->def().input(1), "import/unmapped_input:1"); + ASSERT_EQ(t3->requested_inputs().size(), 2); + EXPECT_EQ(t3->requested_inputs()[0], "import/unmapped_input:0"); + EXPECT_EQ(t3->requested_inputs()[1], "import/unmapped_input:1"); } TEST_F(GraphConstructorTest, ImportGraphDef_InputMapWithControlEdges) { @@ -1795,24 +1795,24 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDeps) { // Test that node defs are consistent with graph Node* w1 = FindNode("import/W1"); - ASSERT_EQ(w1->def().input_size(), 2); - EXPECT_EQ(w1->def().input(0), "^W1"); - EXPECT_EQ(w1->def().input(1), "^W2"); + ASSERT_EQ(w1->requested_inputs().size(), 2); + EXPECT_EQ(w1->requested_inputs()[0], "^W1"); + EXPECT_EQ(w1->requested_inputs()[1], "^W2"); Node* input = FindNode("import/input"); - ASSERT_EQ(input->def().input_size(), 2); - EXPECT_EQ(input->def().input(0), "^W1"); - EXPECT_EQ(input->def().input(1), "^W2"); + ASSERT_EQ(input->requested_inputs().size(), 2); + EXPECT_EQ(input->requested_inputs()[0], "^W1"); + EXPECT_EQ(input->requested_inputs()[1], "^W2"); Node* input2 = FindNode("import/input2"); - ASSERT_EQ(input2->def().input_size(), 2); - EXPECT_EQ(input2->def().input(0), "^W1"); - EXPECT_EQ(input2->def().input(1), "^W2"); + ASSERT_EQ(input2->requested_inputs().size(), 2); + EXPECT_EQ(input2->requested_inputs()[0], "^W1"); + EXPECT_EQ(input2->requested_inputs()[1], "^W2"); Node* t1 = FindNode("import/t1"); - ASSERT_EQ(t1->def().input_size(), 2); - EXPECT_EQ(t1->def().input(0), "import/input:0"); - EXPECT_EQ(t1->def().input(1), "import/input:1"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + EXPECT_EQ(t1->requested_inputs()[0], "import/input:0"); + EXPECT_EQ(t1->requested_inputs()[1], "import/input:1"); } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { @@ -1856,15 +1856,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsWithCycle) { // Test that node defs are consistent with graph Node* merge = FindNode("merge"); - ASSERT_EQ(merge->def().input_size(), 3); - EXPECT_EQ(merge->def().input(0), "input:0"); - EXPECT_EQ(merge->def().input(1), "t1:0"); - EXPECT_EQ(merge->def().input(2), "^W1"); + ASSERT_EQ(merge->requested_inputs().size(), 3); + EXPECT_EQ(merge->requested_inputs()[0], "input:0"); + EXPECT_EQ(merge->requested_inputs()[1], "t1:0"); + EXPECT_EQ(merge->requested_inputs()[2], "^W1"); Node* t1 = FindNode("t1"); - ASSERT_EQ(t1->def().input_size(), 2); - EXPECT_EQ(t1->def().input(0), "merge:0"); - EXPECT_EQ(t1->def().input(1), "merge:0"); + ASSERT_EQ(t1->requested_inputs().size(), 2); + EXPECT_EQ(t1->requested_inputs()[0], "merge:0"); + EXPECT_EQ(t1->requested_inputs()[1], "merge:0"); } TEST_F(GraphConstructorTest, ImportGraphDef_ControlDepsErrors) { diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index c7ad6a1e77..57a2f399e0 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -356,7 +356,7 @@ string ControlLoopName(const string& name) { } bool IsControlLoop(const Node* node) { - const string& name = node->def().name(); + const string& name = node->name(); return StringPiece(name).starts_with("_cloop"); } @@ -468,7 +468,7 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, const string& device_name = edge->dst()->assigned_device_name(); const string& frame_name = src_info.frame_name; int parallel_iterations; - status = GetNodeAttr(src_info.frame->def(), "parallel_iterations", + status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations", ¶llel_iterations); if (!status.ok()) return status; @@ -903,11 +903,11 @@ Status Partition(const PartitionOptions& opts, Graph* g, send_start_time = opts.start_times[src->id()].value(); recv_start_time = opts.start_times[dst->id()].value(); } else { - status = GetNodeAttr(src->def(), "_start_time", &send_start_time); + status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time); if (!status.ok()) { return status; } - status = GetNodeAttr(dst->def(), "_start_time", &recv_start_time); + status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time); if (!status.ok()) { return status; } diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index 4afc878f76..89784c631f 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -318,21 +318,21 @@ TEST_F(GraphTest, AddAttr) { n1->AddAttr("_a", "new_attr"); string attr; - EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr)); EXPECT_EQ("new_attr", attr); Node* n2 = graph_.CopyNode(n1); n1->AddAttr("_b", "new_attr_2"); - EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_a", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_a", &attr)); EXPECT_EQ("new_attr", attr); - EXPECT_EQ(Status::OK(), GetNodeAttr(n1->def(), "_b", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n1->attrs(), "_b", &attr)); EXPECT_EQ("new_attr_2", attr); - EXPECT_EQ(Status::OK(), GetNodeAttr(n2->def(), "_a", &attr)); + EXPECT_EQ(Status::OK(), GetNodeAttr(n2->attrs(), "_a", &attr)); EXPECT_EQ("new_attr", attr); - EXPECT_NE(Status::OK(), GetNodeAttr(n2->def(), "_b", &attr)); + EXPECT_NE(Status::OK(), GetNodeAttr(n2->attrs(), "_b", &attr)); } // Convert edge iteration results into a sorted string. diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc index a679eac0e7..a22a9b3fa3 100644 --- a/tensorflow/core/graph/optimizer_cse.cc +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -56,11 +56,9 @@ class OptimizerCSE { bool Optimize(const std::function<bool(const Node*)>& consider_fn); private: - struct Scratch; - static size_t NodeHash(const Node* n); - static bool Equivalent(const Node* a, const Node* b, Scratch* s); - static bool EqualAttrs(const Node* a, const Node* b, Scratch* s); + static bool Equivalent(const Node* a, const Node* b, + AttrSlice::Scratch* scratch); Graph* g_; }; @@ -110,7 +108,7 @@ size_t OptimizerCSE::NodeHash(const Node* n) { // Hash the attrs. For example, this makes sure different constants // end up in different hash buckets. string tmp; - for (const auto& attr : n->def().attr()) { + for (const auto& attr : n->attrs()) { tmp = attr.first; attr.second.AppendToString(&tmp); // Add hashes of attrs, so the order of attrs doesn't matter. @@ -122,28 +120,6 @@ size_t OptimizerCSE::NodeHash(const Node* n) { return h; } -struct OptimizerCSE::Scratch { - // For EqualAttrs(): - string a; - string b; -}; - -bool OptimizerCSE::EqualAttrs(const Node* a, const Node* b, Scratch* scratch) { - if (a->def().attr_size() != b->def().attr_size()) return false; - - for (const auto& attr : b->def().attr()) { - auto iter = a->def().attr().find(attr.first); - if (iter == a->def().attr().end()) return false; - // Note: it should be safe to compare proto serializations of the attr - // values since at most one field should be set in each (indeed, it - // should be the same field). - iter->second.SerializeToString(&scratch->a); - attr.second.SerializeToString(&scratch->b); - if (scratch->a != scratch->b) return false; - } - return true; -} - static bool HasRefInput(const Node* n) { for (auto dt : n->input_types()) { if (IsRefType(dt)) return true; @@ -151,7 +127,8 @@ static bool HasRefInput(const Node* n) { return false; } -bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { +bool OptimizerCSE::Equivalent(const Node* a, const Node* b, + AttrSlice::Scratch* scratch) { // Different op names are different if (a->type_string() != b->type_string()) return false; @@ -164,7 +141,7 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, Scratch* scratch) { // Compare attrs. Note that equal attrs implies equal input and // output types. - if (!EqualAttrs(a, b, scratch)) return false; + if (!a->attrs().EqualAttrs(b->attrs(), scratch)) return false; // Compare input sources if (a->num_inputs() != b->num_inputs()) return false; @@ -206,7 +183,7 @@ bool OptimizerCSE::Optimize( // Scratch space for Equivalent calls. Allocated here and passed in to // Equivalent to avoid allocation inside the loop below. bool changed = false; - Scratch scratch; + AttrSlice::Scratch scratch; for (Node* n : order) { if (!n->IsOp()) continue; diff --git a/tensorflow/core/graph/quantize_training.cc b/tensorflow/core/graph/quantize_training.cc index e3ef5e2f0c..4a479d3258 100644 --- a/tensorflow/core/graph/quantize_training.cc +++ b/tensorflow/core/graph/quantize_training.cc @@ -192,9 +192,9 @@ Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op, Tensor tensor_names; Tensor shape_and_slices; TF_RETURN_IF_ERROR( - GetNodeAttr(AttrSlice(tensor_names_op->def()), "value", &tensor_names)); - TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(shape_and_slices_op->def()), "value", - &shape_and_slices)); + GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names)); + TF_RETURN_IF_ERROR( + GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices)); int tn_size = tensor_names.NumElements(); int var_size = added_variables.size(); diff --git a/tensorflow/core/graph/quantize_training_test.cc b/tensorflow/core/graph/quantize_training_test.cc index 9cbb928c11..d817d980de 100644 --- a/tensorflow/core/graph/quantize_training_test.cc +++ b/tensorflow/core/graph/quantize_training_test.cc @@ -112,17 +112,15 @@ TEST_F(QuantizeTrainingTest, SignedInput) { TF_ASSERT_OK( FindNode(g, strings::StrCat(identity->name(), "/QuantizeAndDequantizeV2"), &identity_q_node)); - NodeDef identity_q = identity_q_node->def(); ASSERT_EQ("true", - SummarizeAttrValue(identity_q.attr().find("signed_input")->second)); + SummarizeAttrValue(*identity_q_node->attrs().Find("signed_input"))); // Quantize_and_dequantize node for relu should have signed_input==false. Node* relu_q_node; TF_ASSERT_OK( FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &relu_q_node)); - NodeDef relu_q = relu_q_node->def(); ASSERT_EQ("false", - SummarizeAttrValue(relu_q.attr().find("signed_input")->second)); + SummarizeAttrValue(*relu_q_node->attrs().Find("signed_input"))); } TEST_F(QuantizeTrainingTest, RangeGivenTrue) { @@ -165,17 +163,15 @@ TEST_F(QuantizeTrainingTest, RangeGivenTrue) { TF_ASSERT_OK( FindNode(g, strings::StrCat(relu6->name(), "/QuantizeAndDequantizeV2"), &relu6_q_node)); - NodeDef identity_q = relu6_q_node->def(); ASSERT_EQ("true", - SummarizeAttrValue(identity_q.attr().find("range_given")->second)); + SummarizeAttrValue(*relu6_q_node->attrs().Find("range_given"))); // Quantize_and_dequantize node for relu should have range_given==true. Node* relu_q_node; TF_ASSERT_OK( FindNode(g, strings::StrCat(relu->name(), "/QuantizeAndDequantizeV2"), &relu_q_node)); - NodeDef relu_q = relu_q_node->def(); ASSERT_EQ("true", - SummarizeAttrValue(relu_q.attr().find("range_given")->second)); + SummarizeAttrValue(*relu_q_node->attrs().Find("range_given"))); } TEST_F(QuantizeTrainingTest, WithBackwardNodes_QuantizeAndDequantize) { diff --git a/tensorflow/core/graph/subgraph.cc b/tensorflow/core/graph/subgraph.cc index 9849d9a159..e10b692889 100644 --- a/tensorflow/core/graph/subgraph.cc +++ b/tensorflow/core/graph/subgraph.cc @@ -106,7 +106,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, // Copy the _output_shapes from the original node to the feed node, // if any. std::vector<PartialTensorShape> output_shapes; - if (GetNodeAttr(n->def(), "_output_shapes", &output_shapes).ok()) { + if (GetNodeAttr(n->attrs(), "_output_shapes", &output_shapes).ok()) { if (n->num_outputs() != output_shapes.size()) { return errors::InvalidArgument( "FeedInputs: ", t, @@ -129,8 +129,8 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info, if (e->src_output() == id.second) { to_remove.emplace_back(e); } else if (e->src_output() == Graph::kControlSlot && - (n->def().op() == "Placeholder" || - n->def().op() == "PlaceholderV2")) { + (n->type_string() == "Placeholder" || + n->type_string() == "PlaceholderV2")) { // When feeding a Placeholder node, any outgoing control edges // will be replaced with a control edge from the replacement // recv_node. diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index 3dc11b7a16..93dcfd5e33 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -81,7 +81,7 @@ class SubgraphTest : public ::testing::Test { for (const string& s : expected_nodes) { Node* n = FindNode(s); EXPECT_TRUE(n != nullptr) << s; - if (n->def().op() == "_Send" || n->def().op() == "_Recv") { + if (n->type_string() == "_Send" || n->type_string() == "_Recv") { EXPECT_EQ(device_info_.name(), n->assigned_device_name()) << s; } } @@ -367,7 +367,7 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) { for (Node* node : graph()->nodes()) { if (node->name() == "_recv_input_1") { std::vector<PartialTensorShape> shapes; - TF_ASSERT_OK(GetNodeAttr(node->def(), "_output_shapes", &shapes)); + TF_ASSERT_OK(GetNodeAttr(node->attrs(), "_output_shapes", &shapes)); ASSERT_EQ(1, shapes.size()); EXPECT_TRUE(PartialTensorShape({23}).IsIdenticalTo(shapes[0])); break; |