aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/segment/segment_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/segment/segment_test.cc')
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc473
1 files changed, 184 insertions, 289 deletions
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index f5b2d258d7..432e7b1c04 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -14,350 +14,245 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/segment/segment.h"
-#include "tensorflow/c/c_api.h"
-#include "tensorflow/core/framework/graph.pb.h"
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
namespace tensorflow {
namespace tensorrt {
namespace segment {
namespace test {
+namespace ops = ::tensorflow::ops;
class SegmentTest : public ::testing::Test {
- public:
- bool GetGraphDef(TF_Graph* graph, tensorflow::GraphDef* graph_def);
-
- TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s, const char* name);
- TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name);
-
- std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
- const std::set<string>& node_names);
-
protected:
- void PlaceholderHelper(TF_Graph* graph, TF_Status* s, const char* name,
- TF_Operation** op);
- void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name, TF_Operation** op, bool check);
-
- SegmentOptions default_options_;
-};
-
-bool SegmentTest::GetGraphDef(TF_Graph* graph,
- tensorflow::GraphDef* graph_def) {
- TF_Status* s = TF_NewStatus();
- TF_Buffer* buffer = TF_NewBuffer();
- TF_GraphToGraphDef(graph, buffer, s);
- bool ret = TF_GetCode(s) == TF_OK;
- EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- if (ret) ret = graph_def->ParseFromArray(buffer->data, buffer->length);
- TF_DeleteBuffer(buffer);
- TF_DeleteStatus(s);
- return ret;
-}
+ std::function<bool(const tensorflow::Node*)> MakeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Node* node) -> bool {
+ return node_names.find(node->name()) != node_names.end();
+ };
+ }
-std::function<bool(const tensorflow::Node*)> SegmentTest::MakeCandidateFn(
- const std::set<string>& node_names) {
- return [node_names](const tensorflow::Node* node) -> bool {
- return node_names.find(node->name()) != node_names.end();
- };
-}
+ std::function<bool(const tensorflow::Edge*)> MakeInputEdgeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Edge* in_edge) -> bool {
+ return node_names.find(in_edge->dst()->name()) != node_names.end();
+ };
+ }
-void SegmentTest::PlaceholderHelper(TF_Graph* graph, TF_Status* s,
- const char* name, TF_Operation** op) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
- TF_SetAttrType(desc, "dtype", TF_INT32);
- *op = TF_FinishOperation(desc, s);
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- ASSERT_NE(*op, nullptr);
-}
+ std::function<bool(const tensorflow::Edge*)> MakeOutputEdgeCandidateFn(
+ const std::set<string>& node_names) {
+ return [node_names](const tensorflow::Edge* out_edge) -> bool {
+ return node_names.find(out_edge->src()->name()) != node_names.end();
+ };
+ }
-TF_Operation* SegmentTest::Placeholder(TF_Graph* graph, TF_Status* s,
- const char* name) {
- TF_Operation* op;
- PlaceholderHelper(graph, s, name, &op);
- return op;
-}
+ void RunTest(const tensorflow::Graph* graph,
+ const std::set<string>& candidates,
+ const std::set<string>& input_candidates,
+ const std::set<string>& output_candidates,
+ const std::vector<std::set<string>>& expected_segments) {
+ SegmentNodesVector segments;
+ TF_EXPECT_OK(SegmentGraph(graph, MakeCandidateFn(candidates),
+ MakeInputEdgeCandidateFn(input_candidates),
+ MakeOutputEdgeCandidateFn(output_candidates),
+ default_options_, &segments));
+ ValidateSegment(segments, expected_segments);
+ }
-void SegmentTest::AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
- TF_Status* s, const char* name, TF_Operation** op,
- bool check) {
- TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name);
- TF_Output add_inputs[2] = {{l, 0}, {r, 0}};
- TF_AddInputList(desc, add_inputs, 2);
- *op = TF_FinishOperation(desc, s);
- if (check) {
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- ASSERT_NE(*op, nullptr);
+ void ValidateSegment(const SegmentNodesVector& segments,
+ const std::vector<std::set<string>>& expected_segments) {
+ EXPECT_EQ(expected_segments.size(), segments.size());
+ for (int i = 0; i < segments.size(); ++i) {
+ const auto& segment_node_names = segments[i].first;
+ const auto& expected = expected_segments[i];
+ for (const auto& name : expected) {
+ EXPECT_TRUE(segment_node_names.count(name))
+ << "Segment " << i << " is missing expected node: " << name;
+ }
+ if (segment_node_names.size() == expected.size()) continue;
+ for (const auto& name : segment_node_names) {
+ EXPECT_TRUE(expected.count(name))
+ << "Unexpected node found in segment " << i << ": " << name;
+ }
+ }
}
-}
-TF_Operation* SegmentTest::Add(TF_Operation* l, TF_Operation* r,
- TF_Graph* graph, TF_Status* s,
- const char* name) {
- TF_Operation* op;
- AddHelper(l, r, graph, s, name, &op, true);
- return op;
+ SegmentOptions default_options_;
+};
+
+std::set<string> operator-(const std::set<string>& lhs, const string& rhs) {
+ std::set<string> result = lhs;
+ CHECK(result.erase(rhs));
+ return result;
}
TEST_F(SegmentTest, Empty) {
- TF_Graph* graph = TF_NewGraph();
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(
- SegmentGraph(graph_def, MakeCandidateFn({}), default_options_, &segments),
- tensorflow::Status::OK());
-
+ Scope s = Scope::NewRootScope();
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
// Expect no segments/subgraphs.
- EXPECT_TRUE(segments.empty());
- TF_DeleteGraph(graph);
+ RunTest(&g, {}, {}, {}, {});
}
TEST_F(SegmentTest, Simple) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
// feed
- // // ||
+ // // \\
// add0 add1
- // | | /
+ // | \ /
// | add2
- // | / ||
+ // | / \\
// add3 add4
- // | /
+ // \ /
// <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
- TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(
- SegmentGraph(graph_def,
- MakeCandidateFn({"add0", "add1", "add2", "add3", "add4"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect all Add operations to be collapsed into a single segment
- ASSERT_EQ(segments.size(), 1);
- std::vector<string> expected{"add0", "add1", "add2", "add3", "add4"};
- for (const auto& ex : expected) {
- EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
- << "Missing expected node " << ex;
- }
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
+ auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
+ auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ // All Add operations are candidates, and we expect all of them to be
+ // collapsed into a single segment
+ const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4"};
+ RunTest(&g, all_adds, all_adds, all_adds, {all_adds});
+
+ // Make add1 not a candidate, and we expect all other Add operations to be
+ // collapsed into a single segment
+ auto without_add1 = all_adds - "add1";
+ RunTest(&g, without_add1, without_add1, without_add1, {without_add1});
+
+ // Make add1 not a candidate and add2 not an input candidate, and we expect
+ // add0 and add2 are removed from the segment.
+ auto without_add2 = all_adds - "add2";
+ RunTest(&g, without_add1, without_add2, without_add1, {{"add3", "add4"}});
+
+ // Making add2 not an input candidate itself won't affect anything.
+ RunTest(&g, all_adds, without_add2, all_adds, {all_adds});
+
+ // Making add1 not an input candidate.
+ RunTest(&g, all_adds, without_add1, all_adds, {without_add1});
+
+ // Making add3 not an output candidate doesn't affect anything, since it's
+ // output is sink.
+ auto without_add3 = all_adds - "add3";
+ RunTest(&g, all_adds, all_adds, without_add3, {all_adds});
}
TEST_F(SegmentTest, AvoidCycle) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
- // add2 is not a TRT candidate so add0/add3 cannot be formed as a
- // subgraph
- //
// feed
- // // ||
+ // // \\
// add0 add1
- // | | /
+ // | \ /
// | add2
- // | / ||
+ // | / \\
// add3 add4
- // | /
+ // \ /
// <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
- TF_Operation* add4 = Add(add2, add2, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(
- SegmentGraph(graph_def, MakeCandidateFn({"add0", "add1", "add3", "add4"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect no subgraphs
- EXPECT_EQ(segments.size(), 0);
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
+ auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
+ auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add2, add2);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ // add2 is not a TRT candidate so there should be no segments generated.
+ const std::set<string> without_add2 = {"add0", "add1", "add3", "add4"};
+ RunTest(&g, without_add2, without_add2, without_add2, {});
}
TEST_F(SegmentTest, Multiple) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
- // add5 is not a TRT candidate so two subgraphs should be formed
- //
- // feed
- // // || ||
- // add0 add1 add7
- // | | / / ||
- // | add2-----add5 add8
- // | / | | | |
- // add3 add4 add6
- // | | /
- // <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(feed, feed, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add7 = Add(feed, feed, graph, s, "add7");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add0, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add5 = Add(add2, add7, graph, s, "add5");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add8 = Add(add7, add7, graph, s, "add8");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add0, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add3"), string(TF_OperationName(add3)));
- TF_Operation* add4 = Add(add2, add5, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add4"), string(TF_OperationName(add4)));
- TF_Operation* add6 = Add(add5, add8, graph, s, "add6");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add6"), string(TF_OperationName(add6)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(SegmentGraph(graph_def,
- MakeCandidateFn({"add0", "add1", "add2", "add3",
- "add4", "add6", "add7", "add8"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect two subgraphs
- EXPECT_EQ(segments.size(), 2);
-
- std::vector<string> expected0{"add6", "add8"};
- for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
- << "Missing expected node " << ex;
- }
-
- std::vector<string> expected1{"add0", "add1", "add2", "add3"};
- for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
- << "Missing expected node " << ex;
- }
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ // feed
+ // // || \\
+ // add0 add1 add7
+ // | \ / / \\
+ // | add2 / \\
+ // | || \ | ||
+ // | || add5 add8
+ // | / \ / \ /
+ // add3 add4 add6
+ // \ | /
+ // <sink>
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), feed, feed);
+ auto add7 = ops::Add(s.WithOpName("add7"), feed, feed);
+ auto add2 = ops::Add(s.WithOpName("add2"), add0, add1);
+ auto add5 = ops::Add(s.WithOpName("add5"), add2, add7);
+ auto add8 = ops::Add(s.WithOpName("add8"), add7, add7);
+ auto add3 = ops::Add(s.WithOpName("add3"), add0, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add2, add5);
+ auto add6 = ops::Add(s.WithOpName("add6"), add5, add8);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4",
+ "add5", "add6", "add7", "add8"};
+ // Make add5 not a TRT candidate, and we expect two segments.
+ auto without_add5 = all_adds - "add5";
+ RunTest(&g, without_add5, without_add5, without_add5,
+ {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}});
+
+ // Make add8 not a candidate and add6 not an input candidate, then all direct
+ // and indirect inputs of add6 will be removed from the segment.
+ auto without_add8 = all_adds - "add8";
+ auto without_add6 = all_adds - "add6";
+ RunTest(&g, without_add8, without_add6, all_adds, {{"add3", "add4"}});
+
+ // Make add3 not a candidate and add0 not an output candidate, then all
+ // direct and indirect outputs of add0 will be removed from the segment.
+ auto without_add3 = all_adds - "add3";
+ auto without_add0 = all_adds - "add0";
+ RunTest(&g, without_add3, all_adds, without_add0, {{"add1", "add7", "add8"}});
}
TEST_F(SegmentTest, BigIfElse) {
- TF_Status* s = TF_NewStatus();
- TF_Graph* graph = TF_NewGraph();
-
- // add2 is not a TRT candidate
- //
// feed
// ||
// add0
- // // ||
+ // // \\
// add1 add4
// || ||
// add2 add5
// || ||
// add3 add6
- // || //
+ // \\ //
// add7
// ||
// <sink>
- //
- TF_Operation* feed = Placeholder(graph, s, "feed");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
-
- TF_Operation* add0 = Add(feed, feed, graph, s, "add0");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add1 = Add(add0, add0, graph, s, "add1");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add2 = Add(add1, add1, graph, s, "add2");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add3 = Add(add2, add2, graph, s, "add3");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add4 = Add(add0, add0, graph, s, "add4");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add5 = Add(add4, add4, graph, s, "add5");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add6 = Add(add5, add5, graph, s, "add6");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- TF_Operation* add7 = Add(add3, add6, graph, s, "add7");
- ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
- EXPECT_EQ(string("add7"), string(TF_OperationName(add7)));
-
- GraphDef graph_def;
- ASSERT_TRUE(GetGraphDef(graph, &graph_def));
-
- SegmentNodesVector segments;
- ASSERT_EQ(SegmentGraph(graph_def,
- MakeCandidateFn({"add0", "add1", "add3", "add4",
- "add5", "add6", "add7"}),
- default_options_, &segments),
- tensorflow::Status::OK());
-
- // Expect 2 subgraphs
- EXPECT_EQ(segments.size(), 2);
-
- std::vector<string> expected0{"add3", "add4", "add5", "add6", "add7"};
- for (const auto& ex : expected0) {
- EXPECT_TRUE(segments[0].first.find(ex) != segments[0].first.end())
- << "Missing expected node " << ex;
- }
-
- std::vector<string> expected1{"add0", "add1"};
- for (const auto& ex : expected1) {
- EXPECT_TRUE(segments[1].first.find(ex) != segments[1].first.end())
- << "Missing expected node " << ex;
- }
- TF_DeleteGraph(graph);
- TF_DeleteStatus(s);
+ Scope s = Scope::NewRootScope();
+ auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT);
+ auto add0 = ops::Add(s.WithOpName("add0"), feed, feed);
+ auto add1 = ops::Add(s.WithOpName("add1"), add0, add0);
+ auto add2 = ops::Add(s.WithOpName("add2"), add1, add1);
+ auto add3 = ops::Add(s.WithOpName("add3"), add2, add2);
+ auto add4 = ops::Add(s.WithOpName("add4"), add0, add0);
+ auto add5 = ops::Add(s.WithOpName("add5"), add4, add4);
+ auto add6 = ops::Add(s.WithOpName("add6"), add5, add5);
+ auto add7 = ops::Add(s.WithOpName("add7"), add3, add6);
+ tensorflow::Graph g(OpRegistry::Global());
+ TF_EXPECT_OK(s.ToGraph(&g));
+
+ // Make add2 not a TRT candidate, and we expect 2 segments.
+ const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
+ "add4", "add5", "add6", "add7"};
+ RunTest(&g, all_adds - "add2", all_adds, all_adds,
+ {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}});
}
} // namespace test