aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert/convert_graph.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert/convert_graph.cc')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc20
1 files changed, 9 insertions, 11 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index ff8cc6374d..eea8c8efa2 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -49,13 +49,12 @@ namespace tensorrt {
namespace convert {
namespace {
-bool IsTensorRTCandidate(const tensorflow::Node* node) {
+bool IsTensorRTCandidate(const tensorflow::NodeDef& node_def) {
// LINT.IfChange
// TODO(jie): Segmentation shouldn't associated with op name.
// Split it into a registration for each kernel.
static const std::set<string> candidate_ops = {
"Identity",
- "Snapshot",
"Const",
"Conv2D",
"MaxPool",
@@ -75,7 +74,7 @@ bool IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(ben,jie): ...
};
// LINT.ThenChange(//tensorflow/contrib/tensorrt/convert/convert_nodes.h)
- return candidate_ops.count(node->type_string());
+ return candidate_ops.count(node_def.op());
}
void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
@@ -85,10 +84,10 @@ void GetSubGraphIncomingEdges(const tensorflow::Graph& graph,
const tensorflow::Node* node = graph.FindNodeId(node_id);
for (const tensorflow::Edge* edge : node->in_edges()) {
if (!subgraph_node_ids.count(edge->src()->id()) &&
- !edge->src()->IsSource() && !edge->IsControlEdge()) {
+ !edge->src()->IsSource()) {
incoming_edges->insert(edge);
} else {
- VLOG(2) << node->name() << " -> " << edge->src()->name() << " N, ";
+ VLOG(2) << edge->src()->name() << " N, ";
}
}
}
@@ -101,11 +100,11 @@ void GetSubGraphOutgoingEdges(const tensorflow::Graph& graph,
const tensorflow::Node* node = graph.FindNodeId(node_id);
for (const tensorflow::Edge* edge : node->out_edges()) {
if (!subgraph_node_ids.count(edge->dst()->id()) &&
- !edge->dst()->IsSink() && !edge->IsControlEdge()) {
- VLOG(2) << node->name() << " -> " << edge->dst()->name() << " Y, ";
+ !edge->dst()->IsSink()) {
+ VLOG(2) << edge->dst()->name() << " Y, ";
outgoing_edges->insert(edge);
} else {
- VLOG(2) << node->name() << " -> " << edge->dst()->name() << " N, ";
+ VLOG(2) << edge->dst()->name() << " N, ";
}
}
}
@@ -410,9 +409,8 @@ tensorflow::Status ConvertGraphDefToTensorRT(
tensorflow::Status status = ConvertSubGraphToTensorRT(&p);
if (status != tensorflow::Status::OK()) {
LOG(WARNING) << "subgraph conversion error for subgraph_index:" << count
- << " due to: \"" << status.ToString()
- << "\" SKIPPING......( " << subgraph_node_names.size()
- << " nodes)";
+ << " due to: \n"
+ << status.ToString() << " SKIPPING......";
}
count++;
}