diff options
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/segment/segment.cc | 16 |
2 files changed, 12 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index cb2daa7b12..e3248699dd 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -278,11 +278,14 @@ tf_cc_test( tags = ["no_windows"], deps = [ ":segment", - "//tensorflow/c:c_api", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:scope", + "//tensorflow/core:core_cpu", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc index 92807bed14..008fffc954 100644 --- a/tensorflow/contrib/tensorrt/segment/segment.cc +++ b/tensorflow/contrib/tensorrt/segment/segment.cc @@ -562,8 +562,8 @@ tensorflow::Status SegmentGraph( // input, for output nodes remove all its output. In this way, for common // cases the number of removed nodes should be minimum. auto remove_nodes = [&segment_nodes]( - bool is_input_nodes, - std::deque<const tensorflow::Node*>* que) { + bool is_input_nodes, + std::deque<const tensorflow::Node*>* que) { // Run a BFS on the queue to find all the input/output nodes. std::set<const tensorflow::Node*> visited; while (!que->empty()) { @@ -571,13 +571,14 @@ tensorflow::Status SegmentGraph( que->pop_front(); if (!visited.insert(node).second) continue; segment_nodes.erase(node); - for (auto in : is_input_nodes ? node->in_nodes() : node->out_nodes()) { + for (auto in : + is_input_nodes ? node->in_nodes() : node->out_nodes()) { if (segment_nodes.count(in)) { que->push_back(in); VLOG(2) << "Need to remove node " << in->name() - << " because one of its " - << (is_input_nodes ? "output" : "input") - << " nodes in the graph was removed: " << node->name(); + << " because one of its " + << (is_input_nodes ? "output" : "input") + << " nodes in the graph was removed: " << node->name(); } } } @@ -599,8 +600,7 @@ tensorflow::Status SegmentGraph( } // Don't use small segments. - if (static_cast<int>(segment_nodes.size()) < - options.minimum_segment_size) { + if (static_cast<int>(segment_nodes.size()) < options.minimum_segment_size) { VLOG(1) << "Segment " << segments->size() << " has only " << segment_nodes.size() << " nodes, dropping"; continue; |