aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-02 07:32:33 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-02 07:32:33 -0700
commit26d52994cd3bf16b765799494b1f1c1070231b8c (patch)
treeee6b0fba6cc3e2a6d287d09c679136b08701f327 /tensorflow/contrib/tensorrt
parentcf235fd3065b80d5bc0d0e6175ffefe113723e58 (diff)
Fix ci build errors and internal test errors.
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/BUILD2
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc7
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment.cc4
-rw-r--r--tensorflow/contrib/tensorrt/segment/segment_test.cc4
-rw-r--r--tensorflow/contrib/tensorrt/test/base_test.py28
5 files changed, 25 insertions, 20 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 3ad44ca353..6a0feb1aaf 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -422,7 +422,7 @@ cc_library(
srcs = ["test/utils.cc"],
hdrs = ["test/utils.h"],
deps = [
- "@com_googlesource_code_re2//:re2",
"//tensorflow/core:lib",
+ "@com_googlesource_code_re2//:re2",
],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 42ac4d63dc..15a1f68205 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -615,6 +615,8 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
// Up until this point, graph is not modified. If we return !status.ok() from
// here, this segment will be skipped
+ // TODO(aaroey): let it return proper error status for the following logic
+ // instead of checking fail.
tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
(*engine_nodes)[pos] = engine_node;
if (!status.ok()) {
@@ -629,9 +631,8 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
}
VLOG(1) << "input_nodes size = " << input_nodes.size();
for (int i = 0; i < input_nodes.size(); ++i) {
- Node* n = input_nodes[i];
+ Node* n = CHECK_NOTNULL(input_nodes[i]);
const auto& in = inputs[i];
- CHECK_NOTNULL(n);
VLOG(1) << "Connecting data edge from " << n->name() << ":" << in.index
<< " to " << engine_node->name() << ":" << i;
graph->AddEdge(n, in.index, engine_node, i);
@@ -662,7 +663,7 @@ tensorflow::Status CreateTRTNode(const std::vector<EngineInfo>& infos, int pos,
<< output_node->name() << ":" << conn.outside_port;
}
}
- return status;
+ return Status::OK();
}
// Function to construct a funcdef from the segment and add it to the graph.
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 5d20ef2145..b43f1b190f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -478,7 +478,7 @@ tensorflow::Status SegmentGraph(
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the segment nodes set.
- std::unordered_map<string, std::set<const tensorflow::Node*>> sg_map;
+ std::map<string, std::set<const tensorflow::Node*>> sg_map;
// A map from the segment identifier (currently the name of the root node of
// the segment tree) to the device names that the nodes in the segment are
@@ -603,7 +603,7 @@ tensorflow::Status SegmentGraph(
for (const auto& itr : sg_map) {
const std::set<const tensorflow::Node*>& segment_nodes = itr.second;
if (VLOG_IS_ON(1)) {
- string s;
+ string s = "parent=" + itr.first + ":";
for (auto node : segment_nodes) s += " " + node->name();
VLOG(1) << "Segment " << segments->size() << ": " << s;
}
diff --git a/tensorflow/contrib/tensorrt/segment/segment_test.cc b/tensorflow/contrib/tensorrt/segment/segment_test.cc
index 432e7b1c04..5937fa8259 100644
--- a/tensorflow/contrib/tensorrt/segment/segment_test.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment_test.cc
@@ -206,7 +206,7 @@ TEST_F(SegmentTest, Multiple) {
// Make add5 not a TRT candidate, and we expect two segments.
auto without_add5 = all_adds - "add5";
RunTest(&g, without_add5, without_add5, without_add5,
- {{"add6", "add8"}, {"add0", "add1", "add2", "add3"}});
+ {{"add0", "add1", "add2", "add3"}, {"add6", "add8"}});
// Make add8 not a candidate and add6 not an input candidate, then all direct
// and indirect inputs of add6 will be removed from the segment.
@@ -252,7 +252,7 @@ TEST_F(SegmentTest, BigIfElse) {
const std::set<string> all_adds = {"add0", "add1", "add2", "add3",
"add4", "add5", "add6", "add7"};
RunTest(&g, all_adds - "add2", all_adds, all_adds,
- {{"add3", "add4", "add5", "add6", "add7"}, {"add0", "add1"}});
+ {{"add0", "add1"}, {"add3", "add4", "add5", "add6", "add7"}});
}
} // namespace test
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index 74186e3d95..8ea5a63735 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -136,8 +136,8 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
def setUp(self):
"""Setup method."""
super(PartiallyConvertedTestA, self).setUp()
- # Let it fail to build the first engine.
- trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
+ # Let it fail to build the second engine.
+ trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
def GetParams(self):
"""Create a graph containing two segment."""
@@ -167,8 +167,8 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase):
input_names=[input_name],
input_dims=[input_dims],
expected_engines={
- # Only the second engine is built.
- "my_trt_op_1": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
+ # Only the first engine is built.
+ "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"]
},
expected_output_dims=tuple(input_dims),
allclose_atol=1.e-06,
@@ -180,16 +180,16 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA):
def setUp(self):
"""Setup method."""
super(PartiallyConvertedTestB, self).setUp()
- # Let it fail to build the second engine.
+ # Let it fail to build the first engine.
trt_convert.clear_test_values("")
- trt_convert.add_test_value("my_trt_op_1:CreateTRTNode", "fail")
+ trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail")
def GetParams(self):
"""Create a graph containing two segment."""
return super(PartiallyConvertedTestB, self).GetParams()._replace(
expected_engines={
- # Only the first engine is built.
- "my_trt_op_0": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
+ # Only the second engine is built.
+ "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"]
})
@@ -227,8 +227,8 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
input_names=[input_name],
input_dims=[input_dims],
expected_engines={
- "my_trt_op_0": ["add2", "add3", "mul1"],
- "my_trt_op_1": ["add", "add1", "mul"]
+ "my_trt_op_0": ["add", "add1", "mul"],
+ "my_trt_op_1": ["add2", "add3", "mul1"]
},
expected_output_dims=tuple(input_dims),
allclose_atol=1.e-06,
@@ -289,6 +289,10 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
input_dims=[input_dims],
expected_engines={
"my_trt_op_0": ["add2", "add3", "mul1"],
+ # Why segment ["add", "add1", "mul"] was assigned segment id 1
+ # instead of 0: the parent node of this segment is actually const
+ # node 'c', but it's removed later since it's const output of the
+ # segment which is not allowed.
"my_trt_op_1": ["add", "add1", "mul"]
},
expected_output_dims=tuple(input_dims),
@@ -330,8 +334,8 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
input_names=[input_name],
input_dims=[input_dims],
expected_engines={
- "my_trt_op_0": ["c2", "add2", "add3", "mul1"],
- "my_trt_op_1": ["c1", "add", "add1", "mul"]
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
},
expected_output_dims=tuple(input_dims),
allclose_atol=1.e-06,