diff options
author | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-08-02 07:32:33 -0700 |
---|---|---|
committer | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-08-02 07:32:33 -0700 |
commit | 26d52994cd3bf16b765799494b1f1c1070231b8c (patch) | |
tree | ee6b0fba6cc3e2a6d287d09c679136b08701f327 /tensorflow/contrib/tensorrt | |
parent | cf235fd3065b80d5bc0d0e6175ffefe113723e58 (diff) |
Fix ci build errors and internal test errors.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 7 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/test/base_test.py | 28 |
5 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 3ad44ca353..6a0feb1aaf 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -422,7 +422,7 @@ cc_library( srcs = ["test/utils.cc"], hdrs = ["test/utils.h"], deps = [ - "@com_googlesource_code_re2//:re2", "//tensorflow/core:lib", + "@com_googlesource_code_re2//:re2", ], ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 42ac4d63dc..15a1f68205 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -615,6 +615,8 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, // Up until this point, graph is not modified. If we return !status.ok() from // here, this segment will be skipped + // TODO(aaroey): let it return proper error status for the following logic + // instead of checking fail. tensorflow::Node* engine_node = graph->AddNode(trt_node, &status); (*engine_nodes)[pos] = engine_node; if (!status.ok()) { @@ -629,9 +631,8 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, } VLOG(1) << "input_nodes size = " << input_nodes.size(); for (int i = 0; i < input_nodes.size(); ++i) { - Node* n = input_nodes[i]; + Node* n = CHECK_NOTNULL(input_nodes[i]); const auto& in = inputs[i]; - CHECK_NOTNULL(n); VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index << " to " << engine_node->name() << ":" << i; graph->AddEdge(n, in.index, engine_node, i); @@ -662,7 +663,7 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos, << output_node->name() << ":" << conn.outside_port; } } - return status; + return Status::OK(); } // Function to construct a funcdef from the segment and add it to the graph. diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 5d20ef2145..b43f1b190f 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph( // A map from the segment identifier (currently the name of the root node of // the segment tree) to the segment nodes set. - std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map; + std::map<string, std::set<const tensorflow::Node*>> sg_map; // A map from the segment identifier (currently the name of the root node of // the segment tree) to the device names that the nodes in the segment are @@ -603,7 +603,7 @@ tensorflow::Status SegmentGraph( for (const auto& itr : sg_map) { const std::set<const tensorflow::Node*>& segment_nodes = itr.second; if (VLOG_IS_ON(1)) { - string s; + string s = "parent=" + itr.first + ":"; for (auto node : segment_nodes) s += " " + node->name(); VLOG(1) << "Segment " << segments->size() << ": " << s; } diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc index 432e7b1c04..5937fa8259 100644 --- a/tensorflow/contrib/tensorrt/segment/segment_test.cc +++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc @@ -206,7 +206,7 @@ TEST_F(SegmentTest, Multiple) { // Make add5 not a TRT candidate, and we expect two segments. auto without_add5 = all_adds - "add5"; RunTest(&g, without_add5, without_add5, without_add5, - {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}}); + {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}}); // Make add8 not a candidate and add6 not an input candidate, then all direct // and indirect inputs of add6 will be removed from the segment. @@ -252,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) { const std::set<string> all_adds = {"add0", "add1", "add2", "add3", "add4", "add5", "add6", "add7"}; RunTest(&g, all_adds - "add2", all_adds, all_adds, - {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}}); + {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}}); } } // namespace test diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index 74186e3d95..8ea5a63735 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -136,8 +136,8 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): def setUp(self): """Setup method.""" super(PartiallyConvertedTestA, self).setUp() - # Let it fail to build the first engine. - trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") + # Let it fail to build the second engine. + trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") def GetParams(self): """Create a graph containing two segment.""" @@ -167,8 +167,8 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): input_names=[input_name], input_dims=[input_dims], expected_engines={ - # Only the second engine is built. - "my_trt_op_1": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + # Only the first engine is built. + "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] }, expected_output_dims=tuple(input_dims), allclose_atol=1.e-06, @@ -180,16 +180,16 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA): def setUp(self): """Setup method.""" super(PartiallyConvertedTestB, self).setUp() - # Let it fail to build the second engine. + # Let it fail to build the first engine. trt_convert.clear_test_values("") - trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail") + trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") def GetParams(self): """Create a graph containing two segment.""" return super(PartiallyConvertedTestB, self).GetParams()._replace( expected_engines={ - # Only the first engine is built. - "my_trt_op_0": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + # Only the second engine is built. + "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] }) @@ -227,8 +227,8 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase): input_names=[input_name], input_dims=[input_dims], expected_engines={ - "my_trt_op_0": ["add2", "add3", "mul1"], - "my_trt_op_1": ["add", "add1", "mul"] + "my_trt_op_0": ["add", "add1", "mul"], + "my_trt_op_1": ["add2", "add3", "mul1"] }, expected_output_dims=tuple(input_dims), allclose_atol=1.e-06, @@ -289,6 +289,10 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): input_dims=[input_dims], expected_engines={ "my_trt_op_0": ["add2", "add3", "mul1"], + # Why segment ["add", "add1", "mul"] was assigned segment id 1 + # instead of 0: the parent node of this segment is actually const + # node 'c', but it's removed later since it's const output of the + # segment which is not allowed. "my_trt_op_1": ["add", "add1", "mul"] }, expected_output_dims=tuple(input_dims), @@ -330,8 +334,8 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): input_names=[input_name], input_dims=[input_dims], expected_engines={ - "my_trt_op_0": ["c2", "add2", "add3", "mul1"], - "my_trt_op_1": ["c1", "add", "add1", "mul"] + "my_trt_op_0": ["c1", "add", "add1", "mul"], + "my_trt_op_1": ["c2", "add2", "add3", "mul1"] }, expected_output_dims=tuple(input_dims), allclose_atol=1.e-06, |