aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-05-24 19:12:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-24 19:15:01 -0700
commitb59833c3fd91511b33255369016868e4ae6cda2e (patch)
treeecbd70cfd3abb5d934f6eb4b7280a35e8589f5cf /tensorflow/contrib/tensorrt/segment
parent2b99d9cbc7166efedaff9eee11744348da30fc8a (diff)
Merge changes from github.
Revert #18413. Too many internal test failures due to the name scope change caused by this change. Revert #18192. Cannot use re2::StringPiece internally. Need alternative for set call. Will pull and clean this up in a separate change. PiperOrigin-RevId: 197991247
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc379
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.h18
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc16
3 files changed, 347 insertions, 66 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 8fc4697c51..cc42913eca 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -25,18 +25,239 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace tensorrt {
namespace segment {
+using ::tensorflow::strings::StrAppend;
+// A simple graph representation to mirror tensorflow::Graph. This structure
+// helps saving memory since segmenter modifies the graph in place, preventing
+// the need to create a copy of the graph. It is composed of edges and nodes.
+// Nodes keep pointers to original TF nodes.
+class SimpleNode;
+class SimpleGraph;
+class SimpleEdge {
+ public:
+ SimpleEdge(int id, SimpleNode* src, int src_port, SimpleNode* dst,
+ int dst_port, bool is_control = false)
+ : id_(id),
+ src_(src),
+ src_port_(src_port),
+ dst_(dst),
+ dst_port_(dst_port),
+ control_(is_control) {}
+ ~SimpleEdge() {}
+
+ SimpleNode* src() const { return src_; }
+ SimpleNode* dst() const { return dst_; }
+ int src_output() const { return src_port_; }
+ int dst_input() const { return dst_port_; }
+ int id() const { return id_; }
+ bool IsControlEdge() const { return control_; }
+
+ private:
+ int id_;
+ SimpleNode* src_;
+ int src_port_;
+ SimpleNode* dst_;
+ int dst_port_;
+ bool control_;
+};
+
+class SimpleNode {
+ public:
+ SimpleNode(const tensorflow::Node* node, const int id);
+
+ const std::vector<SimpleEdge*>& in_edges() const { return in_edges_; }
+ const std::vector<SimpleEdge*>& out_edges() const { return out_edges_; }
+ std::vector<SimpleNode*> in_nodes() const {
+ std::vector<SimpleNode*> res;
+ res.reserve(in_edges_.size());
+ for (const auto e : in_edges_) {
+ if (e) res.push_back(e->src());
+ }
+ return res;
+ }
+ const string& name() const { return node_->name(); }
+ const tensorflow::Node* tf_node() const { return node_; }
+ int id() const { return id_; }
+
+ private:
+ const tensorflow::Node* node_;
+ std::vector<SimpleEdge*> in_edges_;
+ std::vector<SimpleEdge*> out_edges_;
+ int id_;
+
+ friend class SimpleGraph;
+};
+
+class SimpleGraph {
+ public:
+ explicit SimpleGraph(const tensorflow::Graph* g);
+ ~SimpleGraph();
+
+ void AddControlEdge(SimpleNode* src, SimpleNode* dst);
+ void AddEdge(SimpleNode* src, int out_port, SimpleNode* dst, int in_port);
+ void RemoveEdge(const SimpleEdge*);
+ SimpleNode* FindNodeId(int node_id) {
+ if (node_id < 0 || node_id > static_cast<int>(nodes_.size())) {
+ return nullptr;
+ }
+ return nodes_[node_id];
+ }
+ int num_node_ids() const { return nodes_.size(); }
+ const SimpleNode* source_node() const {
+ return nodes_[tensorflow::Graph::kSourceId];
+ }
+ const SimpleNode* sink_node() const {
+ return nodes_[tensorflow::Graph::kSinkId];
+ }
+
+ private:
+ const tensorflow::Graph* g_;
+ std::vector<SimpleNode*> nodes_;
+ std::vector<SimpleEdge*> edges_;
+ // free_edge_ids_ and free_node_ids_ contain freed indices.
+ std::set<int> free_edge_ids_;
+ std::set<int> free_node_ids_;
+};
+
+SimpleNode::SimpleNode(const tensorflow::Node* node, const int id)
+ : node_(node), id_(id) {
+ if (node_) {
+ in_edges_.reserve(node_->in_edges().size());
+ out_edges_.reserve(node_->out_edges().size());
+ }
+}
+
+SimpleGraph::SimpleGraph(const tensorflow::Graph* g) : g_(g) {
+ int n_nodes = g_->num_node_ids();
+ nodes_.resize(n_nodes, nullptr);
+ nodes_[g->kSourceId] = new SimpleNode(g->source_node(), g->kSourceId);
+ nodes_[g->kSinkId] = new SimpleNode(g->sink_node(), g->kSinkId);
+ int n_edges = g->num_edge_ids();
+ edges_.resize(n_edges, nullptr);
+ for (int i = 2; i < n_nodes; i++) {
+ const auto n = g->FindNodeId(i);
+ if (n) {
+ nodes_[i] = new SimpleNode(n, i);
+ } else {
+ free_node_ids_.insert(i);
+ }
+ }
+ for (int i = 0; i < n_edges; i++) {
+ const auto e = g->FindEdgeId(i);
+ if (e) {
+ const auto tfsrc = e->src();
+ const auto tfdst = e->dst();
+ bool is_control = e->IsControlEdge();
+ auto src = nodes_[tfsrc->id()];
+ auto dst = nodes_[tfdst->id()];
+ auto edge = new SimpleEdge(i, src, e->src_output(), dst, e->dst_input(),
+ is_control);
+ edges_[i] = edge;
+ src->out_edges_.push_back(edge);
+ dst->in_edges_.push_back(edge);
+ } else {
+ free_edge_ids_.insert(i);
+ }
+ }
+}
+
+void SimpleGraph::AddEdge(SimpleNode* src, int out_port, SimpleNode* dst,
+ int in_port) {
+ int i = edges_.size();
+ if (!free_edge_ids_.empty()) {
+ auto it = free_edge_ids_.begin();
+ i = *it;
+ free_edge_ids_.erase(it);
+ } else {
+ edges_.push_back(nullptr);
+ }
+ bool is_control = (out_port == tensorflow::Graph::kControlSlot);
+ is_control |= (in_port == tensorflow::Graph::kControlSlot);
+ auto edge = new SimpleEdge(i, src, out_port, dst, in_port, is_control);
+ edges_[i] = edge;
+ src->out_edges_.push_back(edge);
+ dst->in_edges_.push_back(edge);
+}
+
+void SimpleGraph::AddControlEdge(SimpleNode* src, SimpleNode* dst) {
+ AddEdge(src, tensorflow::Graph::kControlSlot, dst,
+ tensorflow::Graph::kControlSlot);
+}
+
+void SimpleGraph::RemoveEdge(const SimpleEdge* edge) {
+ auto src = edge->src();
+ auto dst = edge->dst();
+ for (auto it = src->out_edges_.begin(); it != src->out_edges_.end(); ++it) {
+ if (*it == edge) {
+ src->out_edges_.erase(it);
+ break;
+ }
+ }
+ for (auto it = dst->in_edges_.begin(); it != dst->in_edges_.end(); ++it) {
+ if (*it == edge) {
+ dst->in_edges_.erase(it);
+ break;
+ }
+ }
+}
+
+SimpleGraph::~SimpleGraph() {
+ for (auto x : nodes_) delete x;
+ for (auto x : edges_) delete x;
+}
namespace {
-bool CanContractEdge(const tensorflow::Edge* edge,
- const tensorflow::Graph& graph) {
- const tensorflow::Node* src = edge->src();
- const tensorflow::Node* dst = edge->dst();
+bool CheckCycles(const std::unique_ptr<SimpleGraph>& g, const SimpleNode* src,
+ const std::vector<SimpleNode*>& start) {
+ // copied from TF ReverseDFS.
+ struct Work {
+ SimpleNode* node;
+ bool leave; // Are we entering or leaving n?
+ };
+
+ std::vector<Work> stack(start.size());
+ for (int i = 0; i < start.size(); ++i) {
+ stack[i] = Work{start[i], false};
+ }
+
+ std::vector<bool> visited(g->num_node_ids(), false);
+ while (!stack.empty()) {
+ Work w = stack.back();
+ stack.pop_back();
+
+ auto n = w.node;
+ if (w.leave) {
+ if (n == src) {
+ return true;
+ }
+ continue;
+ }
+
+ if (visited[n->id()]) continue;
+ visited[n->id()] = true;
+ // Arrange to call leave(n) when all done with descendants.
+ stack.push_back(Work{n, true});
+
+ auto nodes = n->in_nodes();
+ for (const auto node : nodes) {
+ if (!visited[node->id()]) {
+ stack.push_back(Work{node, false});
+ }
+ }
+ }
+ return false;
+}
+
+bool CanContractEdge(const SimpleEdge* edge,
+ const std::unique_ptr<SimpleGraph>& graph) {
+ const auto src = edge->src();
+ const auto dst = edge->dst();
// Can't contract edge if doing so would cause a cycle in the
// graph. So, if there is a directed path from 'src' to 'dst', other
@@ -48,46 +269,38 @@ bool CanContractEdge(const tensorflow::Edge* edge,
// 1. Get all nodes incoming to 'dst', excluding 'src'
// 2. Reverse DFS from those nodes
// 3. If reverse DFS reaches 'src' then we have a cycle
- std::vector<tensorflow::Node*> dfs_start_nodes;
- for (tensorflow::Node* node : dst->in_nodes()) {
+ std::vector<SimpleNode*> dfs_start_nodes;
+ for (SimpleNode* node : dst->in_nodes()) {
if (node != src) {
dfs_start_nodes.push_back(node);
}
}
- bool is_cycle = false;
- if (!dfs_start_nodes.empty()) {
- tensorflow::ReverseDFSFrom(graph, dfs_start_nodes, {},
- [&is_cycle, src](tensorflow::Node* node) {
- if (node == src) {
- is_cycle = true;
- }
- });
- }
-
+ bool is_cycle = CheckCycles(graph, src, dfs_start_nodes);
return !is_cycle;
}
+} // namespace
-void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
- std::vector<const tensorflow::Edge*>* remove_edges) {
+void ContractEdge(SimpleEdge* edge, SimpleGraph* graph,
+ std::vector<const SimpleEdge*>* remove_edges) {
// Transfer all inputs and outputs of 'dst' to 'src' except edges
// connecting the two.
- tensorflow::Node* src = edge->src();
- tensorflow::Node* dst = edge->dst();
+ auto src = edge->src();
+ auto dst = edge->dst();
// We can use '0' for input/output index because we don't need them
// to be accurate for the way we are using the graph.
- std::vector<const tensorflow::Edge*> in_edges(dst->in_edges().begin(),
- dst->in_edges().end());
- for (const tensorflow::Edge* in_edge : in_edges) {
+ std::vector<const SimpleEdge*> in_edges(dst->in_edges().begin(),
+ dst->in_edges().end());
+ for (const SimpleEdge* in_edge : in_edges) {
if (in_edge->IsControlEdge()) {
if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
graph->AddControlEdge(e->src(), src);
}
} else {
if (in_edge->src() != src) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(in_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(in_edge);
if (e->src() == graph->source_node()) {
graph->AddEdge(e->src(), e->src_output(), src,
tensorflow::Graph::kControlSlot);
@@ -98,14 +311,14 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
}
}
- std::vector<const tensorflow::Edge*> out_edges(dst->out_edges().begin(),
- dst->out_edges().end());
- for (const tensorflow::Edge* out_edge : out_edges) {
+ std::vector<const SimpleEdge*> out_edges(dst->out_edges().begin(),
+ dst->out_edges().end());
+ for (const SimpleEdge* out_edge : out_edges) {
if (out_edge->IsControlEdge()) {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
graph->AddControlEdge(src, e->dst());
} else {
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(out_edge);
+ SimpleEdge* e = const_cast<SimpleEdge*>(out_edge);
if (e->dst() == graph->sink_node()) {
VLOG(1) << " edge to sink node " << src->name() << " -> "
<< e->dst()->name();
@@ -128,8 +341,6 @@ void ContractEdge(tensorflow::Edge* edge, tensorflow::Graph* graph,
}
}
-} // namespace
-
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
@@ -140,17 +351,22 @@ tensorflow::Status SegmentGraph(
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), gdef, &graph));
+ return SegmentGraph(&graph, candidate_fn, options, segments);
+}
- // tensorflow::DumpGraph("Pre-Segment", &graph);
-
+tensorflow::Status SegmentGraph(
+ tensorflow::Graph* tf_graph,
+ const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments) {
+ auto graph = std::unique_ptr<SimpleGraph>(new SimpleGraph(tf_graph));
// Use a union-find to collect the nodes that belong to the same
- // segment. A node value of nullptr indicates that the node is not a
- // candidate for TRT.
- std::vector<UnionFind<tensorflow::Node*>> node_segments;
- for (int i = 0; i < graph.num_node_ids(); ++i) {
- tensorflow::Node* node = graph.FindNodeId(i);
+ // segment. A node value of nullptr indicates that the node is not a candidate
+ // for TRT.
+ std::vector<UnionFind<SimpleNode*>> node_segments;
+ for (int i = 0; i < graph->num_node_ids(); ++i) {
+ SimpleNode* node = graph->FindNodeId(i);
if (options.exclude_node_list.count(node->name()) != 0 ||
- !candidate_fn(node)) {
+ !candidate_fn(node->tf_node())) {
node = nullptr;
}
node_segments.emplace_back(node);
@@ -164,10 +380,16 @@ tensorflow::Status SegmentGraph(
// a measure of how beneficial it is to include a given node in a
// TRT subgraph then we can revisit this algorithm to take advantage
// of that information.
- std::vector<tensorflow::Node*> order;
- tensorflow::GetPostOrder(graph, &order);
-
- for (const tensorflow::Node* node : order) {
+ std::vector<tensorflow::Node*> tforder;
+ tensorflow::GetPostOrder(*tf_graph, &tforder);
+ // use postorder implementation from tensorflow and construct mirror in
+ // internal format
+ std::vector<SimpleNode*> order;
+ order.reserve(tforder.size());
+ for (const auto tfnode : tforder) {
+ order.push_back(graph->FindNodeId(tfnode->id()));
+ }
+ for (const SimpleNode* node : order) {
// All output nodes of 'node' have been visited...
VLOG(2) << "Trying node " << node->name() << " id=" << node->id();
@@ -181,8 +403,8 @@ tensorflow::Status SegmentGraph(
// nodes. Iterate since combining two nodes may unblock other
// combining.
while (true) {
- std::set<const tensorflow::Edge*> contract_edges;
- for (const tensorflow::Edge* out_edge : node->out_edges()) {
+ std::set<const SimpleEdge*> contract_edges;
+ for (const SimpleEdge* out_edge : node->out_edges()) {
VLOG(2) << "... out node " << out_edge->dst()->name() << " ( "
<< out_edge->dst()->id() << " <- " << node->id() << " )";
if (out_edge->IsControlEdge()) {
@@ -210,9 +432,9 @@ tensorflow::Status SegmentGraph(
// Contract edges and collect the adjacent nodes into the same
// segment/subgraph.
while (!contract_edges.empty()) {
- const tensorflow::Edge* contract_edge = *contract_edges.begin();
- const tensorflow::Node* src = contract_edge->src();
- const tensorflow::Node* dst = contract_edge->dst();
+ const SimpleEdge* contract_edge = *contract_edges.begin();
+ const SimpleNode* src = contract_edge->src();
+ const SimpleNode* dst = contract_edge->dst();
VLOG(2) << "Merge " << src->name() << " <- " << dst->name() << " ("
<< src->id() << " <- " << dst->id();
@@ -221,13 +443,13 @@ tensorflow::Status SegmentGraph(
// Contracting the edge leaves disconnected graph edges.
// Remove these from the graph and from 'contract_edges' so we
// don't visit them again.
- tensorflow::Edge* e = const_cast<tensorflow::Edge*>(contract_edge);
- std::vector<const tensorflow::Edge*> remove_edges;
- ContractEdge(e, &graph, &remove_edges);
+ SimpleEdge* e = const_cast<SimpleEdge*>(contract_edge);
+ std::vector<const SimpleEdge*> remove_edges;
+ ContractEdge(e, graph.get(), &remove_edges);
- for (const tensorflow::Edge* r : remove_edges) {
+ for (const SimpleEdge* r : remove_edges) {
contract_edges.erase(r);
- graph.RemoveEdge(r);
+ graph->RemoveEdge(r);
}
}
}
@@ -236,9 +458,27 @@ tensorflow::Status SegmentGraph(
// Collect the segments/subgraphs. Each subgraph is represented by a
// set of the names of the nodes in that subgraph.
std::unordered_map<string, std::set<string>> sg_map;
+ std::unordered_map<string, std::set<string>> device_maps;
for (auto& u : node_segments) {
if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) {
sg_map[u.ParentValue()->name()].insert(u.Value()->name());
+ auto tf_node = u.Value()->tf_node();
+ // has_assigned_device_name() is expected to return true
+ // when called from optimization pass. However, since graph
+ // is converted back and forth between graph and graphdef,
+ // assigned devices demoted to requested devices. If the graph
+ // is passed directly to this module, assigned devices will be set.
+ if (tf_node->has_assigned_device_name()) {
+ device_maps[u.ParentValue()->name()].insert(
+ tf_node->assigned_device_name());
+ } else if (!tf_node->requested_device().empty()) {
+ device_maps[u.ParentValue()->name()].insert(
+ tf_node->requested_device());
+ } else {
+ VLOG(1) << "Node " << tf_node->name()
+ << " has no device assigned requested device is: "
+ << tf_node->requested_device();
+ }
}
}
@@ -260,10 +500,35 @@ tensorflow::Status SegmentGraph(
<< segment_node_names.size() << " nodes, dropping";
continue;
}
-
- segments->emplace_back(segment_node_names);
+ // TODO(sami): Make segmenter placement aware once trtscopes are in place
+ const auto& dev_itr = device_maps.find(itr.first);
+ if (dev_itr == device_maps.end() || dev_itr->second.empty()) {
+ VLOG(1) << "No device assigned to segment " << segments->size();
+ segments->emplace_back(std::make_pair(segment_node_names, string()));
+ } else if (dev_itr->second.size() > 1) {
+ string s("Segment ");
+ StrAppend(&s, segments->size(), " has multiple devices attached: ");
+ for (const auto& dev : dev_itr->second) {
+ StrAppend(&s, dev, ", ");
+ }
+ LOG(WARNING) << s << " choosing " << *(dev_itr->second.begin());
+ segments->emplace_back(
+ std::make_pair(segment_node_names, *(dev_itr->second.begin())));
+ } else {
+ segments->emplace_back(
+ std::make_pair(segment_node_names, *(dev_itr->second.begin())));
+ }
+ }
+ if (VLOG_IS_ON(1)) {
+ for (const auto& d : device_maps) {
+ string s("Segment ");
+ StrAppend(&s, ": '", d.first, "' ");
+ for (const auto& dd : d.second) {
+ StrAppend(&s, dd, ", ");
+ }
+ VLOG(1) << "Devices " << s;
+ }
}
-
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 7e8685f44a..1568dd9153 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -29,7 +29,9 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
-using SegmentNodesVector = std::vector<std::set<string>>;
+// vector of segments, each entry contains a device name and a set of nodes in
+// segment
+using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions {
// Segment must contain at least this many nodes.
@@ -51,6 +53,20 @@ tensorflow::Status SegmentGraph(
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
const SegmentOptions& options, SegmentNodesVector* segments);
+// Get the subgraphs of a graph that can be handled by TensorRT.
+//
+// @param graph tensorflow::Graph of the network
+// @param candidate_fn A function that returns true for a Node* if
+// that node can be handled by TensorRT.
+// @param segments Returns the TensorRT segments/subgraphs. Each entry
+// in the vector describes a subgraph by giving a set of the names of
+// all the NodeDefs in that subgraph.
+// @return the status.
+tensorflow::Status SegmentGraph(
+ tensorflow::Graph* tf_graph,
+ const std::function<bool(const tensorflow::Node*)>& candidate_fn,
+ const SegmentOptions& options, SegmentNodesVector* segments);
+
} // namespace segment
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index 6f7655fcab..2de3923b06 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -34,7 +34,7 @@ class SegmentTest : public ::testing::Test {
TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_Status* s, const char* name);
- std::function<bool(const Node*)> MakeCandidateFn(
+ std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
const std::set<string>& node_names);
protected:
@@ -59,9 +59,9 @@ bool SegmentTest::GetGraphDef(TF_Graph* graph,
return ret;
}
-std::function<bool(const Node*)> SegmentTest::MakeCandidateFn(
+std::function<bool(const tensorflow::Node*)> SegmentTest::MakeCandidateFn(
const std::set<string>& node_names) {
- return [node_names](const Node* node) -> bool {
+ return [node_names](const tensorflow::Node* node) -> bool {
return node_names.find(node->name()) != node_names.end();
};
}
@@ -164,7 +164,7 @@ TEST_F(SegmentTest, Simple) {
ASSERT_EQ(segments.size(), 1);
std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
for (const auto& ex : expected) {
- EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
TF_DeleteGraph(graph);
@@ -277,13 +277,13 @@ TEST_F(SegmentTest, Multiple) {
std::vector<string> expected0{"add0", "add1", "add2", "add3"};
for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
std::vector<string> expected1{"add6", "add8"};
for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
<< "Missing expected node " << ex;
}
TF_DeleteGraph(graph);
@@ -347,13 +347,13 @@ TEST_F(SegmentTest, BigIfElse) {
std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].find(ex) != segments[0].end())
+ EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
<< "Missing expected node " << ex;
}
std::vector<string> expected1{"add0", "add1"};
for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].find(ex) != segments[1].end())
+ EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
<< "Missing expected node " << ex;
}
TF_DeleteGraph(graph);