author     gracehoney <31743510+aaroey@users.noreply.github.com>  2018-06-19 12:14:10 -0700
committer  gracehoney <31743510+aaroey@users.noreply.github.com>  2018-06-19 12:14:10 -0700
commit     b5a8d9ea0ec49b1e3fee5441a78a3fb33cd4d470 (patch)
tree       5e1aa2478541a54baf43539afe6dfbc1b5b5e57a
parent     0fb21f608c334dfcaadab7b918c06b88afa8c592 (diff)
Multiple changes:

1. Use unique_ptr instead of shared_ptr, and fix a bug in the destructor of
   TrtEngineOp where it didn't reset the shared_ptr but a copy of it.
2. Fix the include order.
3. Shorten references to tensorflow::tensorrt::xxx.
4. Remove some code that sets something which will be overwritten later.
5. Fix formatting, including function signatures, variable names, const
   references, etc.
6. Remove some dead code.
7. Add a lot of comments and TODOs.
8. In TrtEngineOp, replace the map of allocators with a single unique_ptr.
9. In TrtEngineOp, remove the parameter ignore_dim_change from GetEngine(),
   since it always uses the member fixed_input_size_.
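
The recurring pattern behind items 1 and 8 above is a std::unique_ptr whose
deleter calls the TensorRT destroy() method, via the TrtUniquePtrType alias
pulled in from convert/utils.h. That header is not part of this diff, so the
sketch below only illustrates what such an alias could look like; TrtDestroyer
and FakeTrtObject are hypothetical names, not code from this change.

    #include <memory>

    // Sketch (assumed shape of TrtUniquePtrType): a unique_ptr whose deleter
    // calls the TensorRT-style destroy() method instead of operator delete.
    template <typename T>
    struct TrtDestroyer {
      void operator()(T* t) const {
        if (t) t->destroy();
      }
    };
    template <typename T>
    using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;

    // Hypothetical stand-in for an nvinfer1 object such as ICudaEngine, kept
    // here only to make the sketch self-contained and runnable.
    struct FakeTrtObject {
      void destroy() { delete this; }
    };

    int main() {
      // Sole ownership: destroy() runs exactly once when the pointer is reset
      // or goes out of scope. Because a unique_ptr cannot be copied, resetting
      // it in a destructor really releases the object, unlike resetting a copy
      // of a shared_ptr (the TrtEngineOp destructor bug fixed in item 1).
      TrtUniquePtrType<FakeTrtObject> obj(new FakeTrtObject);
      obj.reset();
      return 0;
    }

Because a unique_ptr cannot be copied, resetting it necessarily releases the
underlying object, which is why the TRTEngineOp destructor in this change only
resets the per-batch engine/context pairs and then the single allocator_
member, in that order.
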
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc          272
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h              8
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc           214
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.h             61
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc           306
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h             33
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h     32
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_resources.h           37
-rw-r--r--  tensorflow/contrib/tensorrt/segment/segment.h                    7
9 files changed, 514 insertions, 456 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index c17ef5fdab..bd6ed2d593 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
-#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <fstream>
#include <list>
@@ -25,6 +24,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/contrib/tensorrt/segment/segment.h"
@@ -76,6 +77,7 @@ std::vector<int> GetLoadedTensorRTVersion() {
int ver_patch = ver - ver_minor * 100;
return {ver_major, ver_minor, ver_patch};
}
+
namespace {
bool IsTensorRTCandidate(const tensorflow::Node* node) {
@@ -121,13 +123,14 @@ tensorflow::Status BuildNodeMap(
}
} // namespace
+
// Function to get calibration from ResourceMgr and put them into nodedef.
tensorflow::Status ConvertCalibGraphToInferGraph(
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* infer_graph,
bool is_dyn_op) {
VLOG(0) << "Starting Calib Conversion";
infer_graph->CopyFrom(graph_def);
- auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
+ auto trt_rm = TRTResourceManager::instance();
auto calib_rm = trt_rm->getManager("TRTCalibration");
int num_nodes = infer_graph->node_size();
if (!is_dyn_op) {
@@ -139,7 +142,7 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
if (n->op() == "TRTEngineOp") {
VLOG(1) << "Processing " << n->name();
string container_name = n->attr().at("segment_funcdef_name").s();
- tensorflow::tensorrt::TRTCalibrationResource* cres = nullptr;
+ TRTCalibrationResource* cres = nullptr;
auto status = calib_rm->Lookup(container_name, "Calibrator", &cres);
if (!status.ok()) {
LOG(ERROR) << "Could not get Calibration information. Did you run with "
@@ -240,14 +243,16 @@ EngineInfo GetEngineInfo(
const tensorflow::grappler::GraphProperties& graph_properties,
const std::set<string>& segment_nodes,
const std::unordered_map<string, tensorflow::Node*>& node_map,
- const std::vector<tensorflow::Node*>& topological_order) {
+ const std::vector<tensorflow::Node*>& reverse_topo_order) {
std::vector<int> subgraph_node_ids;
EngineInfo info;
std::set<string> segment_devices;
int input_port = 0;
int output_port = 0;
+ // TODO(aaroey): consider using node id and port instead. Also, here we assume
+ // that input edge set and output edge set have no intersection, is this true?
std::unordered_map<string, int> created_edges;
- for (auto it = topological_order.rbegin(); it != topological_order.rend();
+ for (auto it = reverse_topo_order.rbegin(); it != reverse_topo_order.rend();
++it) {
auto node_name = (*it)->name();
@@ -287,9 +292,11 @@ EngineInfo GetEngineInfo(
created_edges.insert({s, port});
input_port++;
}
- EngineConnections ec(input_node->name(), input_node->id(),
+ EngineConnection ec(input_node->name(), input_node->id(),
edge->src_output(), node_name, node_id,
edge->dst_input(), true, port);
+ // TODO(aaroey): this will be rewritten in
+ // ConvertSegmentToSubGraphDef, fix it.
ec.connection_type = input_node->output_type(edge->src_output());
info.connections.emplace_back(std::move(ec));
@@ -317,10 +324,9 @@ EngineInfo GetEngineInfo(
}
}
- ConvertSegmentToGraphDef(g, graph_properties, subgraph_node_ids,
- &info.connections, &info.segment_graph_def,
- &info.engine_name);
- info.engine_type = EngineInfo::EngineType::TRTStatic;
+ ConvertSegmentToSubGraphDef(g, graph_properties, subgraph_node_ids,
+ &info.connections, &info.segment_graph_def,
+ &info.engine_name);
// TODO(sami): This should not happen once segmenter is updated.
if (segment_devices.size() == 1) {
info.device = *segment_devices.begin();
@@ -336,23 +342,27 @@ EngineInfo GetEngineInfo(
}
// Function to insert a TRT node into the graph.
+// 'alloc' is only used for creating a static engine.
tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
const std::vector<EngineInfo>& infos, int pos,
- tensorflow::NodeDef* trt_node,
nvinfer1::IGpuAllocator* alloc,
int max_batch_size) {
- auto& info = infos.at(pos);
+ const auto& info = infos.at(pos);
std::vector<tensorflow::TensorShapeProto> out_shapes;
std::vector<tensorflow::TensorShapeProto> input_shapes;
std::vector<tensorflow::PartialTensorShape> shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
std::vector<tensorflow::DataType> out_types;
VLOG(1) << "Processing " << info.engine_name;
- for (const auto conn : info.connections) {
- if (!conn.is_input_edge) { // output edge
+
+ // Update the shape and data types of input/output nodes, and find all unique
+ // inputs.
+ for (const auto& conn : info.connections) {
+ if (!conn.is_input_edge) {
+ // Set the shapes and data types of the output edge.
tensorflow::TensorShapeProto out_shape;
- conn.inside_shape.AsProto(
- &out_shape); // shape of the output node inside segment
+ // shape of the output node inside segment
+ conn.inside_shape.AsProto(&out_shape);
if (out_shapes.size() <= conn.port_number) {
out_shapes.resize(conn.port_number + 1);
out_types.resize(conn.port_number + 1);
@@ -360,10 +370,11 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
out_shapes.at(conn.port_number) = out_shape;
out_types.at(conn.port_number) = conn.connection_type;
continue;
- } // input edge
+ }
+
+ // Set the shapes and data types of the input edge.
tensorflow::TensorShapeProto in_shape;
conn.outside_shape.AsProto(&in_shape);
-
if (input_shapes.size() <= conn.port_number) {
input_shapes.resize(conn.port_number + 1);
shapes.resize(conn.port_number + 1);
@@ -373,18 +384,13 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
string input_node = conn.outside_node_name;
int input_port = conn.outside_port;
- auto dtype = conn.connection_type;
bool found_engine = false;
// Rewire the inputs to other engines if they contain original input node
for (size_t t = 0; t < infos.size(); ++t) {
- if (t == pos) {
- continue;
- }
+ if (t == pos) continue;
auto& engine_info = infos.at(t);
for (const auto& eng_conn : engine_info.connections) {
- if (eng_conn.is_input_edge) {
- continue;
- }
+ if (eng_conn.is_input_edge) continue;
if (eng_conn.inside_node_name == input_node) {
input_node = engine_info.engine_name;
if (eng_conn.inside_port == input_port) {
@@ -398,6 +404,7 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
}
VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
<< info.engine_name << ":" << inputs.size();
+ // Skip duplicate inputs.
bool new_input = true;
for (const auto& inp : inputs) {
if (inp.node == input_node && inp.index == input_port) {
@@ -406,78 +413,63 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
}
}
if (new_input) {
- inputs.emplace_back(input_node, input_port, dtype);
+ inputs.emplace_back(input_node, input_port, conn.connection_type);
}
}
+
+ // Build the engine and get its serialized representation.
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
info.precision_mode == INT8MODE) {
// Create static engine and for int8 test validity of the engine.
- tensorflow::tensorrt::Logger trt_logger;
- auto builder = std::shared_ptr<nvinfer1::IBuilder>(
- nvinfer1::createInferBuilder(trt_logger), [](nvinfer1::IBuilder* p) {
- if (p) p->destroy();
- });
+ Logger trt_logger;
+ auto builder = std::unique_ptr<
+ nvinfer1::IBuilder, std::function<void(nvinfer1::IBuilder*)>>(
+ nvinfer1::createInferBuilder(trt_logger),
+ [](nvinfer1::IBuilder* p) { if (p) p->destroy(); });
builder->setMaxBatchSize(max_batch_size);
- if (info.precision_mode == tensorflow::tensorrt::convert::FP16MODE) {
- builder->setHalf2Mode(true);
- }
+ if (info.precision_mode == FP16MODE) builder->setHalf2Mode(true);
builder->setMaxWorkspaceSize(info.max_workspace_size_bytes);
#if NV_TENSORRT_MAJOR > 3
builder->setGpuAllocator(alloc);
#endif
- nvinfer1::ICudaEngine* engine = nullptr;
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
// TODO(sami): What happens if 1st dim is not batch?
- auto status = ConvertSubgraphToEngine(info.segment_graph_def, builder.get(),
- shapes, &engine, info.precision_mode);
- if (!status.ok()) {
- if (engine) engine->destroy();
- return status;
- }
- if (engine) {
- auto engine_data = std::shared_ptr<nvinfer1::IHostMemory>(
- engine->serialize(), [](nvinfer1::IHostMemory* p) {
- if (p) p->destroy();
- });
- segment_string =
- string((const char*)engine_data->data(), engine_data->size());
- engine->destroy();
- }
+ TF_RETURN_IF_ERROR(ConvertSubGraphDefToEngine(
+ info.segment_graph_def, info.precision_mode, shapes, builder.get(),
+ &engine, /*convert_successfully=*/nullptr));
+ TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
+ segment_string =
+ string((const char*)engine_data->data(), engine_data->size());
if (info.precision_mode == INT8MODE) {
+ // TODO(aaroey): why not put this inside the 'else' branch?
segment_string = info.segment_graph_def.SerializeAsString();
}
} else {
segment_string = info.segment_graph_def.SerializeAsString();
}
+
+ // TODO(aaroey): use enum instead, and add a helper method to do the
+ // conversion.
string prec_string;
switch (info.precision_mode) {
- case FP32MODE: {
+ case FP32MODE:
prec_string = "FP32";
break;
- }
- case FP16MODE: {
+ case FP16MODE:
prec_string = "FP16";
break;
- }
- case INT8MODE: {
+ case INT8MODE:
prec_string = "INT8";
- auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
- auto calib_rm = trt_rm->getManager("TRTCalibration");
- if (!calib_rm) {
+ if (!TRTResourceManager::instance()->getManager("TRTCalibration")) {
LOG(ERROR) << "Failed to construct calibration storage";
}
break;
- }
- default: {
+ default:
return tensorflow::errors::OutOfRange("Unknown precision mode");
- }
}
- tensorflow::Status status;
- tensorflow::Node* engine_node = nullptr;
tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
- if (!info.device.empty()) {
- node_builder.Device(info.device);
- }
+ if (!info.device.empty()) node_builder.Device(info.device);
if (VLOG_IS_ON(1)) {
string ins=StrCat(info.engine_name," inputs= ");
for (const auto& ii : inputs) {
@@ -486,50 +478,53 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << ins;
}
node_builder.Input(inputs);
- if (info.engine_type == EngineInfo::EngineType::TRTStatic) {
- if (info.cached_engine_batches.size()) {
- LOG(WARNING) << "Cached engine batches are ignored for static engines";
- }
+ if (info.engine_type == EngineInfo::EngineType::TRTStatic &&
+ info.cached_engine_batches.size()) {
+ LOG(WARNING) << "Cached engine batches are ignored for static engines";
}
- status = node_builder.Attr("input_shapes", input_shapes)
- .Attr("output_shapes", out_shapes)
- .Attr("static_engine",
- info.engine_type == EngineInfo::EngineType::TRTStatic)
- .Attr("segment_funcdef_name",
- StrCat(info.engine_name, "_native_segment"))
- .Attr("serialized_segment", segment_string)
- .Attr("calibration_data", "")
- .Attr("max_cached_engines_count", info.maximum_cached_engines)
- .Attr("cached_engine_batches", {max_batch_size})
- .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
- .Attr("precision_mode", prec_string)
- .Attr("OutT", out_types)
- .Finalize(trt_node);
+ tensorflow::NodeDef trt_node;
+ tensorflow::Status status =
+ node_builder.Attr("input_shapes", input_shapes)
+ .Attr("output_shapes", out_shapes)
+ .Attr("static_engine",
+ info.engine_type == EngineInfo::EngineType::TRTStatic)
+ .Attr("segment_funcdef_name",
+ StrCat(info.engine_name, "_native_segment"))
+ .Attr("serialized_segment", segment_string)
+ .Attr("calibration_data", "")
+ .Attr("max_cached_engines_count", info.maximum_cached_engines)
+ .Attr("cached_engine_batches", {max_batch_size})
+ .Attr("workspace_size_bytes", info.max_workspace_size_bytes)
+ .Attr("precision_mode", prec_string)
+ .Attr("OutT", out_types)
+ .Finalize(&trt_node);
if (!status.ok()) {
LOG(ERROR) << "Node construction failed with" << status;
return status;
}
VLOG(1) << "Adding TRTEngine " << info.engine_name << " to graph";
- engine_node = graph->AddNode(*trt_node, &status);
+ tensorflow::Node* engine_node = graph->AddNode(trt_node, &status);
if (!status.ok()) {
LOG(ERROR) << "Adding node failed " << status;
return status;
}
-
+ // Update the inputs of the destination nodes of the output edges, pointing
+ // them to the engine node.
for (auto& conn : info.connections) {
if (conn.is_input_edge) continue;
VLOG(1) << " Updating DBG " << engine_node->name() << " out_port "
<< conn.port_number << " out_id " << conn.outside_id
<< " name=" << conn.outside_node_name;
auto dst_node = graph->FindNodeId(conn.outside_id);
- if (!dst_node) { // node removed skip.
- continue;
- }
+ // TODO(aaroey): node could be removed during construction of other TRT
+ // nodes, but then in that case who is going to update their input nodes?
+ if (!dst_node) continue;
VLOG(1) << "Updating " << engine_node->name() << ":" << conn.port_number
<< " to " << dst_node->name() << ":" << conn.outside_port;
status = graph->UpdateEdge(engine_node, conn.port_number, dst_node,
conn.outside_port);
if (!status.ok()) {
+ // TODO(aaroey): should we return the status?
LOG(ERROR) << "Edge update failed " << engine_node->name() << ":"
<< conn.port_number << " -> " << dst_node->name() << ":"
<< conn.outside_port << " status= " << status;
@@ -631,9 +626,7 @@ tensorflow::Status RegisterSegmentFunctionToFunctionLibrary(
std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
ConversionParams& params, EngineInfo& engine) {
int cuda_device_id = -1;
- // we need to us PM here since in python path there is no way to get
- // to allocators
- auto CheckDeviceID = [](int tfid) -> int {
+ auto check_device_id = [](int tfid) -> int {
tensorflow::TfGpuId tf_gpu_id(tfid);
CudaGpuId cuda_gpu_id;
Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
@@ -646,6 +639,9 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
return -1;
};
tensorflow::Allocator* dev_allocator = nullptr;
+ // We need to use PM here since in the python path there is no way to get
+ // to the allocators.
+ // TODO(aaroey): fix this.
auto pm = tensorflow::ProcessState::singleton();
if (params.cluster) { // get allocator
const tensorflow::Device* device = nullptr;
@@ -653,15 +649,15 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
device = params.cluster->GetDeviceSet()->FindDeviceByName(engine.device);
}
if (device) {
- cuda_device_id = CheckDeviceID(device->parsed_name().id);
+ cuda_device_id = check_device_id(device->parsed_name().id);
if (cuda_device_id < 0) {
- LOG(ERROR) << "Cuda device identification failed, using device "
- "0.";
+ LOG(ERROR) << "Cuda device identification failed, using device 0.";
cuda_device_id = 0;
}
tensorflow::GPUOptions gpuoptions;
// this should be instantiated by now
tensorflow::TfGpuId tf_gpu_id(device->parsed_name().id);
+ // TODO(aaroey): why not using device->GetAllocator()?
dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
VLOG(1) << "Got an allocator for device tf_device=" << tf_gpu_id.value()
<< " cuda device= " << cuda_device_id << " at " << dev_allocator;
@@ -676,19 +672,16 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
// if device is set, try to find the device. Might be a problem for multi
// host case but TensorRT do not support multi host setups yet.
if (!engine.device.empty()) {
- tensorflow::DeviceNameUtils::ParsedName parsed_name;
- if (tensorflow::DeviceNameUtils::ParseFullName(engine.device,
- &parsed_name)) {
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(engine.device, &parsed_name)) {
cuda_device_id = parsed_name.has_id ? parsed_name.id : -1;
}
try_gpu_ids = !parsed_name.has_id;
}
if (try_gpu_ids) {
while (found_device < 100) {
- cuda_device_id = CheckDeviceID(found_device);
- if (cuda_device_id >= 0) {
- break;
- }
+ cuda_device_id = check_device_id(found_device);
+ if (cuda_device_id >= 0) break;
found_device++;
}
}
@@ -698,31 +691,32 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
return std::make_pair(cuda_device_id, dev_allocator);
}
LOG(WARNING)
- << "Can't determine the device constructing an allocator at device "
+ << "Can't determine the device, constructing an allocator at device "
<< found_device;
tensorflow::GPUOptions gpuoptions;
- gpuoptions.set_allow_growth(
- true); // this will be a noop if device is already initialized
+ // this will be a noop if device is already initialized
+ gpuoptions.set_allow_growth(true);
tensorflow::TfGpuId tf_gpu_id(found_device);
dev_allocator = pm->GetGPUAllocator(gpuoptions, tf_gpu_id, 1);
}
return std::make_pair(cuda_device_id, dev_allocator);
}
+
// Entry function from optimization pass.
tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
- // Segment the graph into subgraphs that can be converted to TensorRT
- tensorflow::tensorrt::segment::SegmentOptions segment_options;
+ // Convert graphdef to graph.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
params.input_graph_def->library());
tensorflow::Graph graph(flib);
TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToGraph(
tensorflow::GraphConstructorOptions(), *params.input_graph_def, &graph));
+ // Segment the graph into subgraphs that can be converted to TensorRT
+ tensorflow::tensorrt::segment::SegmentOptions segment_options;
// TODO(ben,jie,sami): exclude output nodes (DISCUSS IT)
for (auto node : *(params.output_names)) {
segment_options.exclude_node_list.insert(node);
}
-
segment_options.minimum_segment_size = params.minimum_segment_size;
tensorflow::tensorrt::segment::SegmentNodesVector segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
@@ -730,34 +724,38 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
if (segments.size() > 1) {
VLOG(0) << "MULTIPLE tensorrt candidate conversion: " << segments.size();
}
+
+ // Get the EngineInfo for each segment.
std::unordered_map<string, tensorflow::Node*> node_map;
TF_RETURN_IF_ERROR(BuildNodeMap(graph, &node_map));
- std::unordered_map<string, std::pair<int, string>> output_edge_map;
float total_num_nodes_in_segments = 0.;
std::vector<EngineInfo> engine_segments;
engine_segments.reserve(segments.size());
- std::vector<tensorflow::Node*> topo_order;
- tensorflow::GetPostOrder(graph, &topo_order);
- size_t total_engine_size = 0;
- std::vector<size_t> engine_sizes;
+ std::vector<tensorflow::Node*> reverse_topo_order;
+ tensorflow::GetPostOrder(graph, &reverse_topo_order);
+ size_t total_engine_bytes_size = 0;
+ std::vector<size_t> engine_bytes_size;
for (size_t t = 0; t < segments.size(); t++) {
auto& s = segments.at(t);
- engine_segments.emplace_back(GetEngineInfo(&graph, *params.graph_properties,
- s.first, node_map, topo_order));
+ engine_segments.emplace_back(GetEngineInfo(
+ &graph, *params.graph_properties, s.first, node_map,
+ reverse_topo_order));
auto& curr_engine = engine_segments.back();
curr_engine.precision_mode = params.precision_mode;
- engine_sizes.push_back(curr_engine.segment_graph_def.ByteSizeLong());
curr_engine.engine_type =
(params.is_dyn_op || params.precision_mode == INT8MODE
? EngineInfo::EngineType::TRTDynamic
: EngineInfo::EngineType::TRTStatic);
curr_engine.cached_engine_batches = params.cached_engine_batches;
curr_engine.maximum_cached_engines = params.max_cached_engines;
- total_engine_size += engine_sizes.back();
- total_num_nodes_in_segments += s.first.size();
StrAppend(&curr_engine.engine_name, "my_trt_op_", t);
RegisterSegmentFunctionToFunctionLibrary(
&graph, curr_engine.segment_graph_def, curr_engine.engine_name);
+
+ engine_bytes_size.push_back(curr_engine.segment_graph_def.ByteSizeLong());
+ total_engine_bytes_size += engine_bytes_size.back();
+ total_num_nodes_in_segments += s.first.size();
+
if (VLOG_IS_ON(8)) {
string fname = curr_engine.engine_name;
StrAppend(&fname, ".pb");
@@ -767,54 +765,54 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
f.close();
}
}
- std::vector<tensorflow::NodeDef*> trt_nodes;
- trt_nodes.reserve(engine_segments.size());
+
+ // Create a TRT node for each segment using its EngineInfo.
int old_cuda_device = 0;
auto err = cudaGetDevice(&old_cuda_device);
if (err != cudaSuccess) {
- LOG(ERROR) << "Couldn't get current device error is "
- << cudaGetErrorString(err);
+ LOG(ERROR) << "Couldn't get current device: " << cudaGetErrorString(err);
}
VLOG(1) << "Current cuda device is " << old_cuda_device;
for (int i = 0; i < engine_segments.size(); ++i) {
- auto trt_node = new tensorflow::NodeDef;
- trt_nodes.push_back(trt_node);
auto& engine = engine_segments.at(i);
// Partition the workspace size by the average of node ratio and segment
// graphdef size
engine.max_workspace_size_bytes =
params.max_workspace_size_bytes *
- (engine_sizes.at(i) / total_engine_size +
+ (engine_bytes_size.at(i) / total_engine_bytes_size +
segments.at(i).first.size() / total_num_nodes_in_segments) /
2.0;
- std::shared_ptr<nvinfer1::IGpuAllocator> alloc;
+ // The allocator is used to build the engine. The builder and the built
+ // engine will be destroyed after we get the serialized engine string, so
+ // it's fine to use unique_ptr here.
+ std::unique_ptr<nvinfer1::IGpuAllocator> alloc;
auto device_alloc = GetDeviceAndAllocator(params, engine);
int cuda_device_id = 0;
if (device_alloc.first >= 0) {
cuda_device_id = device_alloc.first;
alloc.reset(new TRTDeviceAllocator(device_alloc.second));
- } else { // Setting allocator as nullptr should get revert to the
- // cudamalloc
+ } else {
+ // Setting the allocator to nullptr should revert to cudamalloc.
LOG(WARNING) << "Can't identify the cuda device. Running on device 0 ";
}
cudaSetDevice(cuda_device_id);
- auto status = CreateTRTNode(&graph, engine_segments, i, trt_node,
- alloc.get(), params.max_batch_size);
+ auto status = CreateTRTNode(
+ &graph, engine_segments, i, alloc.get(), params.max_batch_size);
if (status.ok()) {
- const auto& internal_nodes = segments.at(i).first;
- for (auto node_id : internal_nodes) {
- graph.RemoveNode(node_map.at(node_id));
+ for (auto node_name : segments.at(i).first) {
+ graph.RemoveNode(node_map.at(node_name));
}
} else {
+ // TODO(aaroey): in this case the graph is already modified; should we
+ // return the status?
LOG(WARNING) << "Engine creation for segment " << i << ", composed of "
- << segments.at(i).first.size() << " nodes failed. Skipping";
- VLOG(1) << "Failure reason " << status;
+ << segments.at(i).first.size() << " nodes failed: "
+ << status << ". Skipping...";
}
}
cudaSetDevice(old_cuda_device);
graph.ToGraphDef(params.output_graph_def);
- for (auto tn : trt_nodes) delete tn;
- VLOG(1)<<"Returning from conversion";
+ VLOG(1) << "Returning from conversion";
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index e2f4c1c83f..9d986e4890 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -64,10 +64,10 @@ tensorflow::Status ConvertCalibGraphToInferGraph(
const tensorflow::GraphDef& graph_def, tensorflow::GraphDef* new_graph_def,
bool is_dyn_op);
-// max_batch_size: maximum batch size which can be used for inference for
-// optimization targets inference run with max batch size.
-// max_workspace_size_bytes: The upper bound of memory allowance for
-// engine building.
+// - max_batch_size: maximum batch size which can be used for inference for
+// optimization targets inference run with max batch size.
+// - max_workspace_size_bytes: The upper bound of memory allowance for engine
+// building.
tensorflow::Status ConvertGraphDefToTensorRT(
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 6ad2d7e68f..a252ea67df 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
-#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include <algorithm>
#include <list>
@@ -25,7 +24,9 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
@@ -125,12 +126,10 @@ static std::vector<std::pair<int, int>> CreateSamePadding(
string GetCommonNameScope(const string& op_name_a, const string& op_name_b) {
size_t last_scope_separator = 0;
- for (size_t i = 0; i < std::min(op_name_a.size(), op_name_b.size()); ++i) {
- if (op_name_a[i] != op_name_b[i]) {
- break;
- } else if (op_name_a[i] == '/') {
- last_scope_separator = i + 1;
- }
+ const size_t min_size = std::min(op_name_a.size(), op_name_b.size());
+ for (size_t i = 0; i < min_size; ++i) {
+ if (op_name_a[i] != op_name_b[i]) break;
+ if (op_name_a[i] == '/') last_scope_separator = i + 1;
}
return op_name_a.substr(0, last_scope_separator);
}
@@ -2144,10 +2143,14 @@ void Converter::register_op_converters() {
} // namespace
-tensorflow::Status ConvertSubgraphToEngine(
- const tensorflow::GraphDef& gdef, nvinfer1::IBuilder* builder,
+tensorflow::Status ConvertSubGraphDefToEngine(
+ const tensorflow::GraphDef& gdef, int precision_mode,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
- nvinfer1::ICudaEngine** engine, int precision_mode) {
+ nvinfer1::IBuilder* builder,
+ TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
+ bool* convert_successfully) {
+ engine->reset();
+ if (convert_successfully) *convert_successfully = false;
auto trt_network = infer_object(builder->createNetwork());
if (!trt_network) {
return tensorflow::errors::Internal(
@@ -2159,7 +2162,7 @@ tensorflow::Status ConvertSubgraphToEngine(
VLOG(1) << "Starting engine conversion ";
Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE);
std::vector<std::pair<string, string>> output_tensors;
- // graph nodes are already topologically sorted during construction
+ // Graph nodes are already topologically sorted during construction
for (const auto& node_def : gdef.node()) {
string node_name = node_def.name();
VLOG(1) << "Converting op name=" << node_name << ", op=" << node_def.op();
@@ -2215,7 +2218,7 @@ tensorflow::Status ConvertSubgraphToEngine(
}
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
(node_def.op() == "Identity")) {
- tensorflow::int32 slot_number = -1;
+ int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
&slot_number)) {
LOG(ERROR) << "Failed to parse slot number from " << node_name
@@ -2248,122 +2251,130 @@ tensorflow::Status ConvertSubgraphToEngine(
converter.network()->markOutput(*tensor);
}
+ if (convert_successfully) *convert_successfully = true;
+
+ // Build the engine.
VLOG(1) << "Starting engine creation";
- *engine = builder->buildCudaEngine(*converter.network());
+ engine->reset(builder->buildCudaEngine(*converter.network()));
+ if (engine->get() == nullptr) {
+ return tensorflow::errors::Internal("Failed to build TensorRT engine");
+ }
VLOG(1) << "Finished conversion";
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertSegmentToGraphDef(
+tensorflow::Status ConvertSegmentToSubGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
- const std::vector<int>& subgraph_node_ids,
- std::vector<EngineConnections>* connections,
+ const std::vector<int>& subgraph_node_ids, // In topological order
+ std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope) {
std::set<string> marker_nodes;
+ // Update connection shapes/data types and add corresponding input/output
+ // nodes in the segment graphdef.
for (size_t i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
auto outside_node = graph->FindNodeId(connection.outside_id);
- if (outside_node) {
- tensorflow::DataType input_type = tensorflow::DT_FLOAT;
- tensorflow::PartialTensorShape partial_shape;
- if (connection.is_input_edge) {
- if (graph_properties.HasOutputProperties(
- connection.outside_node_name)) {
- auto output_params = graph_properties.GetOutputProperties(
- connection.outside_node_name);
- auto out_shape = output_params.at(connection.outside_port);
- input_type = out_shape.dtype();
- std::vector<tensorflow::int64> dims;
- partial_shape = out_shape.shape();
- connection.outside_shape = partial_shape;
- } else {
- VLOG(0) << "Unknown output shape" << outside_node->name();
- input_type = graph->FindNodeId(connection.outside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
-
- } else { // output edge
- if (graph_properties.HasInputProperties(connection.outside_node_name)) {
- auto input_params =
- graph_properties.GetInputProperties(connection.outside_node_name);
- auto in_shape = input_params.at(connection.outside_port);
- input_type = in_shape.dtype();
- partial_shape = in_shape.shape();
- connection.inside_shape = partial_shape;
- } else {
- input_type = graph->FindNodeId(connection.inside_id)
- ->output_type(connection.outside_port);
- }
- connection.connection_type = input_type;
+ if (!outside_node) {
+ // TODO(aaroey): this should never happen, so make it a CHECK?
+ return tensorflow::errors::NotFound(
+ "Cannot find node with id ", connection.outside_id, " in the graph.");
+ }
+ // Updates the shape and data types of input/output connections.
+ tensorflow::DataType input_type = tensorflow::DT_FLOAT;
+ tensorflow::PartialTensorShape partial_shape;
+ if (connection.is_input_edge) {
+ if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
+ auto output_params = graph_properties.GetOutputProperties(
+ connection.outside_node_name);
+ auto out_shape = output_params.at(connection.outside_port);
+ input_type = out_shape.dtype();
+ std::vector<tensorflow::int64> dims;
+ partial_shape = out_shape.shape();
+ connection.outside_shape = partial_shape;
+ } else {
+ VLOG(0) << "Unknown output shape" << outside_node->name();
+ input_type = graph->FindNodeId(connection.outside_id)
+ ->output_type(connection.outside_port);
}
+ connection.connection_type = input_type;
+
+ } else { // output edge
+ if (graph_properties.HasInputProperties(connection.outside_node_name)) {
+ auto input_params =
+ graph_properties.GetInputProperties(connection.outside_node_name);
+ auto in_shape = input_params.at(connection.outside_port);
+ input_type = in_shape.dtype();
+ partial_shape = in_shape.shape();
+ connection.inside_shape = partial_shape;
+ } else {
+ input_type = graph->FindNodeId(connection.inside_id)
+ ->output_type(connection.outside_port);
+ }
+ connection.connection_type = input_type;
+ }
- tensorflow::NodeDef dummy_placeholder;
- string node_name;
- if (connection.is_input_edge) {
- StrAppend(&node_name, kInputPHName, connection.port_number);
- if (marker_nodes.count(node_name)) {
- VLOG(1) << "Reusing input " << node_name << " for the edge "
- << connection.outside_node_name << ":"
- << connection.outside_port << " -> "
- << connection.inside_node_name << ":"
- << connection.inside_port;
- continue;
- }
- marker_nodes.insert(node_name);
- auto seg_node = segment_def->add_node();
- tensorflow::NodeDefBuilder dph_builder(node_name, "Placeholder");
- auto status = dph_builder.Attr("shape", partial_shape)
- .Attr("dtype", input_type)
- .Finalize(seg_node);
- VLOG(1) << "Constructing input " << node_name << " for the edge "
+ // Add dummy input/output nodes to the segment graphdef.
+ if (connection.is_input_edge) {
+ const string node_name = StrCat(kInputPHName, connection.port_number);
+ if (marker_nodes.count(node_name)) {
+ VLOG(1) << "Reusing input " << node_name << " for the edge "
<< connection.outside_node_name << ":"
<< connection.outside_port << " -> "
- << connection.inside_node_name << ":" << connection.inside_port;
- } else {
- StrAppend(&node_name, kOutputPHName, connection.port_number);
- if (marker_nodes.count(node_name)) {
- VLOG(1) << "Reusing output " << node_name << " for the edge "
- << connection.inside_node_name << ":"
- << connection.inside_port << " -> "
- << connection.outside_node_name << ":"
- << connection.outside_port;
- continue;
- }
- marker_nodes.insert(node_name);
- auto seg_node = segment_def->add_node();
- tensorflow::NodeDefBuilder dph_builder(node_name, "Identity");
- auto status =
- dph_builder.Input(connection.inside_node_name, 0, input_type)
- .Finalize(seg_node);
- VLOG(1) << "Constructing output " << node_name << " for the edge "
- << connection.inside_node_name << ":" << connection.inside_port
- << " -> " << connection.outside_node_name << ":"
+ << connection.inside_node_name << ":"
+ << connection.inside_port;
+ continue;
+ }
+ marker_nodes.insert(node_name);
+ auto seg_node = segment_def->add_node();
+ tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
+ auto status = builder.Attr("shape", partial_shape)
+ .Attr("dtype", input_type).Finalize(seg_node);
+ VLOG(1) << "Constructing input " << node_name << " for the edge "
+ << connection.outside_node_name << ":"
+ << connection.outside_port << " -> "
+ << connection.inside_node_name << ":" << connection.inside_port;
+ } else {
+ const string node_name = StrCat(kOutputPHName, connection.port_number);
+ if (marker_nodes.count(node_name)) {
+ VLOG(1) << "Reusing output " << node_name << " for the edge "
+ << connection.inside_node_name << ":"
+ << connection.inside_port << " -> "
+ << connection.outside_node_name << ":"
<< connection.outside_port;
+ continue;
}
+ marker_nodes.insert(node_name);
+ auto seg_node = segment_def->add_node();
+ tensorflow::NodeDefBuilder builder(node_name, "Identity");
+ auto status = builder.Input(connection.inside_node_name, 0, input_type)
+ .Finalize(seg_node);
+ VLOG(1) << "Constructing output " << node_name << " for the edge "
+ << connection.inside_node_name << ":" << connection.inside_port
+ << " -> " << connection.outside_node_name << ":"
+ << connection.outside_port;
}
- }
- std::unordered_map<int, int> newIdMap;
- // Copy nodes to new graphdef
+ } // for each connection.
+
+ std::unordered_map<int, int> old_to_new_id_map;
+ // Copy internal nodes to new graphdef
string local_scope = graph->FindNodeId(*subgraph_node_ids.begin())->name();
for (const auto node_id : subgraph_node_ids) {
const auto node = graph->FindNodeId(node_id);
local_scope = GetCommonNameScope(local_scope, node->name());
- if (node) {
- newIdMap[node_id] = segment_def->node_size();
- auto snode = segment_def->add_node();
- snode->CopyFrom(node->def());
- VLOG(1) << "Copying " << snode->name() << " to subgraph";
- }
+ old_to_new_id_map[node_id] = segment_def->node_size();
+ auto snode = segment_def->add_node();
+ snode->CopyFrom(node->def());
+ VLOG(1) << "Copying " << snode->name() << " to subgraph";
}
- // update the inputs of the new nodes to point to dummy inputs
+ // Update the inputs of the new input nodes to point to placeholder nodes.
for (int i = 0; i < connections->size(); ++i) {
auto& connection = connections->at(i);
if (!connection.is_input_edge) continue;
- auto snode = segment_def->mutable_node(newIdMap[connection.inside_id]);
- string placeholder_name(kInputPHName);
- StrAppend(&placeholder_name, connection.port_number);
+ auto snode = segment_def->mutable_node(
+ old_to_new_id_map[connection.inside_id]);
+ const string placeholder_name =
+ StrCat(kInputPHName, connection.port_number);
VLOG(1) << "Updating " << snode->name() << ":" << connection.inside_port
<< " from " << snode->input(connection.inside_port) << " to "
<< placeholder_name;
@@ -2373,6 +2384,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
VLOG(0) << "Segment @scope '" << local_scope << "', converted to graph";
return tensorflow::Status::OK();
}
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index 971322d07c..b8d6012df2 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -22,11 +22,13 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
+
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -36,11 +38,13 @@ static const char* kInputPHName = "InputPH_";
static const char* kOutputPHName = "OutputPH_";
namespace convert {
+// TODO(aaroey): use an enum instead.
const int FP32MODE = 0;
const int FP16MODE = 1;
const int INT8MODE = 2;
-struct EngineConnections {
- EngineConnections(const string& outside, int out_id, int out_port,
+
+struct EngineConnection {
+ EngineConnection(const string& outside, int out_id, int out_port,
const string& inside, int in_id, int in_port,
bool input_edge, int port)
: outside_node_name(outside),
@@ -51,16 +55,21 @@ struct EngineConnections {
inside_port(in_port),
is_input_edge(input_edge),
port_number(port) {}
+
const string outside_node_name;
const int outside_id;
const int outside_port;
tensorflow::PartialTensorShape outside_shape;
- tensorflow::DataType connection_type;
+
const string inside_node_name;
const int inside_id;
const int inside_port;
tensorflow::PartialTensorShape inside_shape;
+
+ tensorflow::DataType connection_type;
bool is_input_edge;
+
+ // The port number of the TRT node connecting to this edge.
int port_number;
};
@@ -68,36 +77,54 @@ struct EngineInfo {
EngineInfo()
: engine_type(EngineType::TRTStatic),
max_workspace_size_bytes(0),
- precision_mode(FP32MODE){};
+ precision_mode(FP32MODE) {};
+
string engine_name;
string device;
tensorflow::GraphDef segment_graph_def;
- std::vector<EngineConnections> connections; // order matters!
+
+ // The segment nodes that are on one side of the edges are topologically sorted.
+ std::vector<EngineConnection> connections;
+
enum class EngineType { TRTStatic = 0, TRTDynamic = 1 };
EngineType engine_type;
- tensorflow::int64 max_workspace_size_bytes;
+ int64 max_workspace_size_bytes;
int maximum_cached_engines;
std::vector<int> cached_engine_batches;
int precision_mode;
};
-;
-// Constructs a graphdef from the segment in the given graph. Adds placeholder
-// nodes for input edges (InputPH_*) and identity nodes for output edges
-// (OutputPH_*). This function needs to be called before TensorRT nodes
-// inserted in order to correctly get sizes from the original graph.
-tensorflow::Status ConvertSegmentToGraphDef(
+// Constructs a graphdef from the segment in the given graph. Adds placeholder
+// nodes for input edges (InputPH_*) and identity nodes for output edges
+// (OutputPH_*). This function needs to be called before TensorRT nodes are
+// inserted in order to correctly get sizes from the original graph.
+//
+// - subgraph_node_ids: the node ids of the subgraph, must be sorted in
+// topological order.
+// - segment_def: the output GraphDef, whose non-input/output nodedefs will be
+// sorted in topological order.
+tensorflow::Status ConvertSegmentToSubGraphDef(
const tensorflow::Graph* graph,
const tensorflow::grappler::GraphProperties& graph_properties,
const std::vector<int>& subgraph_node_ids,
- std::vector<EngineConnections>* connections,
+ std::vector<EngineConnection>* connections,
tensorflow::GraphDef* segment_def, string* common_scope);
-// Converts given subgraph to a TRT engine.
-tensorflow::Status ConvertSubgraphToEngine(
- const tensorflow::GraphDef& gdef, nvinfer1::IBuilder* builder,
+// Converts the given subgraph to a TRT engine saved in 'engine'. Returns ok
+// iff 'builder' successfully builds the engine. If the result is not ok,
+// 'engine' will be set to nullptr.
+// Once returned, 'builder' is not needed any more and can be safely destroyed.
+//
+// - convert_successfully: indicates whether the conversion to the TensorRT
+// network is successful. This is different from successfully building the
+// engine: building can still fail afterwards.
+tensorflow::Status ConvertSubGraphDefToEngine(
+ const tensorflow::GraphDef& gdef, int precision_mode,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
- nvinfer1::ICudaEngine** engine, int precision_mode);
+ nvinfer1::IBuilder* builder,
+ TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
+ bool* convert_successfully);
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 2dddc4541c..0d1d7e3b0e 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h"
#include <algorithm>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h"
@@ -32,14 +33,14 @@ limitations under the License.
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
-static ::tensorflow::tensorrt::Logger logger;
-using IRuntime = nvinfer1::IRuntime;
-using Dims = nvinfer1::Dims;
-
namespace tensorrt {
-using tensorflow::strings::StrAppend;
-using tensorflow::strings::StrCat;
-// A helper class to call done() for asynchronous execution.
+static Logger logger;
+using ::nvinfer1::IRuntime;
+using ::nvinfer1::Dims;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
+// A helper class to call done() when destructed for asynchronous execution.
// Helps simultaneous execution of native and TRT engines.
class AsyncHelper : public tensorflow::core::RefCounted {
public:
@@ -78,8 +79,8 @@ tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
if (fdef == nullptr) {
return tensorflow::errors::Internal(
- StrCat("Native FunctionDef ", funcdef_name_,
- " can't be found in function library"));
+ "Native FunctionDef ", funcdef_name_,
+ " can't be found in function library");
}
tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops;
inst_ops.overlay_lib = nullptr;
@@ -122,15 +123,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
if (precision_string == "FP32") {
- precision_mode_ = tensorflow::tensorrt::convert::FP32MODE;
+ precision_mode_ = convert::FP32MODE;
} else if (precision_string == "FP16") {
- precision_mode_ = tensorflow::tensorrt::convert::FP16MODE;
+ precision_mode_ = convert::FP16MODE;
} else if (precision_string == "INT8") {
- precision_mode_ = tensorflow::tensorrt::convert::INT8MODE;
+ precision_mode_ = convert::INT8MODE;
}
- calibration_mode_ =
- precision_mode_ == tensorflow::tensorrt::convert::INT8MODE &&
- calibration_data.size() == 0;
+ calibration_mode_ = (precision_mode_ == convert::INT8MODE &&
+ calibration_data.size() == 0);
if (calibration_data.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
@@ -190,21 +190,20 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
ctx->set_output(t, outputs->at(t));
}
delete outputs;
- return;
});
- return;
}
void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
AsyncHelper* helper) {
+ helper->Ref();
tensorflow::core::ScopedUnref sc(helper);
- auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance();
+ // TODO(aaroey): remove the ResourceMgr singleton.
+ auto trt_rm = TRTResourceManager::instance();
auto res_mgr = trt_rm->getManager("TRTCalibration");
- tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
+ TRTCalibrationResource* calib_res = nullptr;
auto status = res_mgr->LookupOrCreate(
funcdef_name_, "Calibrator", &calib_res,
- {[ctx, this](tensorflow::tensorrt::TRTCalibrationResource** cr)
- -> tensorflow::Status {
+ {[ctx, this](TRTCalibrationResource** cr) -> tensorflow::Status {
return this->AllocateCalibrationResources(ctx, cr);
}});
if (!status.ok()) {
@@ -219,7 +218,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
void* data_address = GetTensorAddress(&t);
if (data_address == nullptr) {
ctx->SetStatus(tensorflow::errors::InvalidArgument(
- StrCat("Unsupported data type encountered in input ", i)));
+ "Unsupported data type encountered in input ", i));
return;
}
// Check the allocated buffer is sufficient for input
@@ -237,7 +236,6 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
calib_res->calibrator_->setBatch(input_data, *stream);
VLOG(2) << "Passed calibration data";
ExecuteNativeSegment(ctx, helper);
- return;
}
int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
@@ -274,27 +272,28 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
auto helper = new AsyncHelper(done);
tensorflow::core::ScopedUnref sc(helper);
if (calibration_mode_) {
- helper->Ref();
ExecuteCalibration(ctx, helper);
return;
}
- int num_binding = ctx->num_inputs() + ctx->num_outputs();
- std::vector<void*> buffers(num_binding);
- int smallest_engine = GetEngineBatch(ctx);
- if (smallest_engine < 0) return;
- int num_batch = ctx->input(0).shape().dim_size(0);
- size_t binding_index;
- auto engine_ctx_pair = GetEngine(smallest_engine, ctx, fixed_input_size_);
- auto trt_engine_ptr = engine_ctx_pair.first;
+ const int smallest_engine = GetEngineBatch(ctx);
+ if (smallest_engine < 0) return; // GetEngineBatch already set the status.
+
+ const int num_batch = ctx->input(0).shape().dim_size(0);
+ auto& engine_ctx_pair = GetEngine(smallest_engine, ctx);
+ auto& trt_engine_ptr = engine_ctx_pair.first;
if (!trt_engine_ptr) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
<< " failed Running native segment";
ExecuteNativeSegment(ctx, helper);
return;
}
+
+ const int num_binding = ctx->num_inputs() + ctx->num_outputs();
+ std::vector<void*> buffers(num_binding);
for (int i = 0; i < ctx->num_inputs(); i++) {
- string inp_name = StrCat(kInputPHName, i);
- binding_index = trt_engine_ptr->getBindingIndex(inp_name.c_str());
+ const string inp_name = StrCat(kInputPHName, i);
+ const size_t binding_index = trt_engine_ptr->getBindingIndex(
+ inp_name.c_str());
const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
@@ -322,17 +321,16 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unknown ouput TRT data type! " + int(dtype)));
+ "Unknown ouput TRT data type! ", int(dtype)));
return;
}
}
for (int i = 0; i < ctx->num_outputs(); i++) {
- // This is bad that we have to reallocate output buffer every run.
// Create an output tensor
-
- auto output_name = StrCat(kOutputPHName, i);
- binding_index = trt_engine_ptr->getBindingIndex(output_name.c_str());
+ const string output_name = StrCat(kOutputPHName, i);
+ const size_t binding_index = trt_engine_ptr->getBindingIndex(
+ output_name.c_str());
Tensor* output_tensor = nullptr;
TensorShape output_shape;
@@ -346,8 +344,8 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
&output_shape));
} else {
LOG(ERROR) << "output node not found, at " << output_name;
- ctx->SetStatus(tensorflow::errors::Internal("output " + output_name +
- " but couldn't be found!"));
+ ctx->SetStatus(tensorflow::errors::Internal(
+ "output ", output_name, " couldn't be found!"));
return;
}
auto status = ctx->allocate_output(i, output_shape, &output_tensor);
@@ -375,7 +373,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
default:
LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
ctx->SetStatus(tensorflow::errors::InvalidArgument(
- "Unsupported output data type! " + int(dtype)));
+ "Unsupported output data type! ", int(dtype)));
return;
}
}
@@ -387,46 +385,47 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
->CudaStreamMemberHack()));
// TODO(jie): trt enqueue does not return error
- auto trt_execution_context_ptr = engine_ctx_pair.second;
+ auto& trt_execution_context_ptr = engine_ctx_pair.second;
auto ret = trt_execution_context_ptr->enqueue(num_batch, &buffers[0], *stream,
nullptr);
if (!ret) {
- LOG(ERROR) << "Enqueueing of TRT execution failed!";
+ LOG(ERROR) << "Failed to enqueue batch for TRT engine: " << name();
+ ctx->SetStatus(tensorflow::errors::Internal(
+ "Failed to enqueue batch for TRT engine: ", name()));
}
// sync should be done by TF.
}
TRTEngineOp::~TRTEngineOp() {
- // Order matters!
- for (auto eng : engine_map_) {
+ // We need to manually destroy the engine and execution context before
+ // the allocator is destructed.
+ for (auto& eng : engine_map_) {
eng.second.first.reset();
eng.second.second.reset();
}
- for (auto alloc : allocators_) alloc.second.reset();
+ allocator_.reset();
}
nvinfer1::IGpuAllocator* TRTEngineOp::GetAllocator(OpKernelContext* ctx) {
+ if (allocator_) return allocator_.get();
auto device = ctx->device();
- const auto& device_name = device->name();
- if (allocators_.count(device_name)) {
- return allocators_.at(device_name).get();
- }
- auto dev_allocator = device->GetAllocator(tensorflow::AllocatorAttributes());
- if (!dev_allocator) {
+ auto alloc = device->GetAllocator(tensorflow::AllocatorAttributes());
+ if (!alloc) {
LOG(ERROR) << "Can't find device allocator for gpu device "
<< device->name();
ctx->SetStatus(tensorflow::errors::Internal(
- StrCat("Can't get device allocator for device ", device_name)));
+ "Can't get device allocator for device ", device->name()));
return nullptr;
}
- auto allocator = std::make_shared<TRTDeviceAllocator>(dev_allocator);
- allocators_.insert({device_name, allocator});
- return allocator.get();
+ allocator_.reset(new TRTDeviceAllocator(alloc));
+ return allocator_.get();
}
-TRTEngineOp::EngineCtxPair TRTEngineOp::GetEngine(int batch_size,
- OpKernelContext* ctx,
- bool ignore_dim_change) {
+TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
+ OpKernelContext* ctx) {
+ static EngineCtxPair null_pair = {
+ TrtUniquePtrType<nvinfer1::ICudaEngine>(nullptr),
+ TrtUniquePtrType<nvinfer1::IExecutionContext>(nullptr)};
// TODO(sami): This method needs to be re-written to use resource manager and
// with LRU mechanism option.
tensorflow::mutex_lock lock(engine_mutex_);
@@ -435,113 +434,106 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::GetEngine(int batch_size,
if (engine_map_.size()) {
if (engine_map_.begin()->first >= batch_size) {
return engine_map_.begin()->second;
- } else {
- return {nullptr, nullptr};
}
- } else {
- std::shared_ptr<IRuntime> infer(nvinfer1::createInferRuntime(logger),
- [](IRuntime* p) {
- if (p) p->destroy();
- });
+ return null_pair;
+ }
+ TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
#if NV_TENSORRT_MAJOR > 3
- auto allocator = GetAllocator(ctx);
- if (allocator == nullptr) {
- return {nullptr, nullptr};
- };
- infer->setGpuAllocator(allocator);
+ auto allocator = GetAllocator(ctx);
+ if (allocator == nullptr) {
+ return null_pair;
+ };
+ infer->setGpuAllocator(allocator);
#endif
- std::shared_ptr<nvinfer1::ICudaEngine> static_engine(
- infer->deserializeCudaEngine(serialized_segment_.c_str(),
- serialized_segment_.size(), nullptr),
- Destroyer<nvinfer1::ICudaEngine>());
- engine_map_.insert({static_engine->getMaxBatchSize(),
- {static_engine,
- {static_engine->createExecutionContext(),
- Destroyer<nvinfer1::IExecutionContext>()}}});
- // Runtime is safe to delete after engine creation
- serialized_segment_.clear();
- if (static_engine->getMaxBatchSize() < batch_size) {
- return {nullptr, nullptr};
- }
- return engine_map_.at(static_engine->getMaxBatchSize());
- }
- } else {
- auto engine_it = engine_map_.find(batch_size);
- if (engine_it == engine_map_.end() &&
- engine_map_.size() < (size_t)max_cached_engines_) {
- auto builder = std::shared_ptr<nvinfer1::IBuilder>(
- nvinfer1::createInferBuilder(logger),
- Destroyer<nvinfer1::IBuilder>()); // reset the builder to ensure
- // device is correct
+ TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
+ infer->deserializeCudaEngine(serialized_segment_.c_str(),
+ serialized_segment_.size(), nullptr));
+ auto raw_static_engine = static_engine.get();
+ const auto max_batch_size = raw_static_engine->getMaxBatchSize();
+ engine_map_[max_batch_size] = {
+ std::move(static_engine),
+ TrtUniquePtrType<nvinfer1::IExecutionContext>(
+ raw_static_engine->createExecutionContext())};
+ // Runtime is safe to delete after engine creation
+ serialized_segment_.clear();
+ if (max_batch_size < batch_size) return null_pair;
+ return engine_map_.at(max_batch_size);
+ } // static_engine_
+
+ // Handle the dynamic engine case.
+ auto engine_it = engine_map_.find(batch_size);
+ if (engine_it == engine_map_.end() &&
+ engine_map_.size() < (size_t)max_cached_engines_) {
+ TrtUniquePtrType<nvinfer1::IBuilder> builder(
+ nvinfer1::createInferBuilder(logger));
#if NV_TENSORRT_MAJOR > 3
- auto allocator = GetAllocator(ctx);
- if (allocator == nullptr) {
- return {nullptr, nullptr};
- }
- builder->setGpuAllocator(allocator);
+ auto allocator = GetAllocator(ctx);
+ if (allocator == nullptr) {
+ // GetAllocator already set the Status.
+ return null_pair;
+ }
+ builder->setGpuAllocator(allocator);
#endif
- VLOG(0) << name() << " Constructing a new engine with batch size "
- << batch_size;
- builder->setMaxBatchSize(batch_size);
- if (precision_mode_ == tensorflow::tensorrt::convert::FP16MODE) {
- builder->setHalf2Mode(true);
- } else if (precision_mode_ == tensorflow::tensorrt::convert::INT8MODE) {
- builder->setInt8Mode(true);
- builder->setInt8Calibrator(calibrator_.get());
- }
- builder->setMaxWorkspaceSize(workspace_size_);
- nvinfer1::ICudaEngine* engine = nullptr;
- std::vector<tensorflow::PartialTensorShape> shapes;
- for (int i = 0; i < ctx->num_inputs(); ++i) {
- shapes.emplace_back(ctx->input(i).shape());
- }
- VLOG(1) << "Calling conversion for " << batch_size << " " << name();
- auto status = tensorflow::tensorrt::convert::ConvertSubgraphToEngine(
- segment_graph_, builder.get(), shapes, &engine, precision_mode_);
- VLOG(1) << "Conversion is done";
- if (engine) {
- engine_map_[batch_size] = {
- std::shared_ptr<nvinfer1::ICudaEngine>(
- engine, Destroyer<nvinfer1::ICudaEngine>()),
- std::shared_ptr<nvinfer1::IExecutionContext>(
- engine->createExecutionContext(),
- Destroyer<nvinfer1::IExecutionContext>())};
- } else {
- LOG(ERROR) << "Engine creation for batch size " << batch_size
- << " failed";
- ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
+ VLOG(0) << name() << " Constructing a new engine with batch size "
+ << batch_size;
+ builder->setMaxBatchSize(batch_size);
+ if (precision_mode_ == convert::FP16MODE) {
+ builder->setHalf2Mode(true);
+ } else if (precision_mode_ == convert::INT8MODE) {
+ builder->setInt8Mode(true);
+ // TODO(aaroey): what if the calibration data is empty?
+ builder->setInt8Calibrator(calibrator_.get());
+ }
+ // TODO(aaroey): use the allocator to allocate the TRT workspace.
+ builder->setMaxWorkspaceSize(workspace_size_);
+ std::vector<tensorflow::PartialTensorShape> shapes;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ shapes.emplace_back(ctx->input(i).shape());
+ }
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
+ bool convert_successfully = false;
+ VLOG(1) << "Calling conversion for " << batch_size << " " << name();
+ auto status = convert::ConvertSubGraphDefToEngine(
+ segment_graph_, precision_mode_, shapes, builder.get(), &engine,
+ &convert_successfully);
+ if (!status.ok()) {
+ if (convert_successfully) {
+ // This means it failed to build the engine even though the network was
+ // built successfully, probably due to internal issues. In this case we
+ // don't retry in the future.
engine_map_[batch_size] = {nullptr, nullptr};
- return {nullptr, nullptr};
}
+ LOG(ERROR) << "Engine creation for batch size " << batch_size
+ << " failed " << status;
+ ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
+ return null_pair;
}
- return engine_map_.at(batch_size);
+ VLOG(1) << "Conversion is done";
+ TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
+ engine->createExecutionContext());
+ engine_map_[batch_size] = {std::move(engine), std::move(exec_context)};
}
+ return engine_map_.at(batch_size);
}
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
tensorflow::OpKernelContext* ctx,
- tensorflow::tensorrt::TRTCalibrationResource** cr) {
+ TRTCalibrationResource** cr) {
auto cres = new TRTCalibrationResource();
*cr = cres;
- cres->logger_ = new tensorflow::tensorrt::Logger();
+ cres->logger_ = new Logger();
#if NV_TENSORRT_MAJOR > 3
- auto dev = ctx->device();
- auto dev_allocator = dev->GetAllocator(tensorflow::AllocatorAttributes());
- if (!dev_allocator) {
+ auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
+ if (!alloc) {
LOG(WARNING) << "Can't get device allocator will not be able to "
"allocate memory from TensorFlow memory pool";
- cres->allocator_ =
- std::make_shared<tensorflow::tensorrt::TRTCudaAllocator>();
+ cres->allocator_.reset(new TRTCudaAllocator);
} else {
- cres->allocator_ =
- std::make_shared<tensorflow::tensorrt::TRTDeviceAllocator>(
- dev_allocator);
+ cres->allocator_.reset(new TRTDeviceAllocator(alloc));
}
-
#endif
int batch_size = ctx->input(0).dim_size(0);
- cres->engine_ = nullptr;
std::vector<tensorflow::PartialTensorShape> shapes;
int num_inputs = ctx->num_inputs();
// first run instantiate calibrator
@@ -558,7 +550,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
void* device_address = GetTensorAddress(device_tensor);
if (device_address == nullptr) {
return tensorflow::errors::InvalidArgument(
- StrCat("Unsupported data type encountered in input ", i));
+ "Unsupported data type encountered in input ", i);
}
device_buffers_.emplace(
StrCat(kInputPHName, i),
@@ -579,26 +571,29 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
batch_size, workspace_size]() {
VLOG(0) << "Starting calibration thread on device " << cuda_device
<< ", Calibration Resource @ " << cres;
- // ConvertSubgraphToEngine() will try to build the engine and this thread
- // will be consuming the calibration data that is set by the TF op, driving
- // the builder until calibrator returns false; Engine is discarded after
- // calibration table is generated
auto err = cudaSetDevice(cuda_device);
if (err != cudaSuccess) {
VLOG(0) << "Couldn't set cuda device to " << cuda_device
<< " in calibration thread";
}
// initialize builder here
- cres->builder_ = nvinfer1::createInferBuilder(*(cres->logger_));
- cres->builder_->setGpuAllocator(cres->allocator_.get());
+ cres->builder_.reset(nvinfer1::createInferBuilder(*(cres->logger_)));
+ // TODO(aaroey): maybe set the max batch size using the python
+ // calibration wrapper class.
cres->builder_->setMaxBatchSize(batch_size);
+#if NV_TENSORRT_MAJOR > 3
+ cres->builder_->setGpuAllocator(cres->allocator_.get());
+#endif
cres->builder_->setInt8Mode(true);
cres->builder_->setMaxWorkspaceSize(workspace_size);
cres->builder_->setInt8Calibrator(cres->calibrator_);
- auto s = tensorflow::tensorrt::convert::ConvertSubgraphToEngine(
- *segment_graph, cres->builder_, shapes, &cres->engine_,
- tensorflow::tensorrt::convert::INT8MODE); // calibrator will loop until
- // we terminate calibration
+ // ConvertSubGraphDefToEngine() will try to build the engine. This thread
+ // will loop inside buildCudaEngine(), consuming the calibration data that
+ // is set by the TF op, and drive the builder until the calibrator returns
+ // false. The engine is discarded after the calibration table is generated.
+ auto s = convert::ConvertSubGraphDefToEngine(
+ *segment_graph, convert::INT8MODE, shapes, cres->builder_.get(),
+ &cres->engine_, /*convert_successfully=*/nullptr);
if (!s.ok()) {
LOG(ERROR)
<< "Calibration failed. Engine will not be calibrated! Error is" << s;
@@ -609,6 +604,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
VLOG(1) << "initialized calibrator resource";
return tensorflow::Status::OK();
}
+
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt
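
Note: the kernel above now relies on TrtUniquePtrType from convert/utils.h instead of shared_ptr plus the per-op Destroyer template. A minimal sketch of what such an alias typically looks like, assuming it simply wraps std::unique_ptr with a deleter that calls TensorRT's destroy(); the exact definition lives in utils.h and may differ:

#include <memory>

// Illustrative only: a deleter that releases TensorRT objects via destroy()
// rather than delete, plus the alias used throughout this patch.
struct TrtDestroyer {
  template <typename T>
  void operator()(T* t) const {
    if (t) t->destroy();
  }
};

template <typename T>
using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer>;

With this, engine_map_[batch] = {std::move(engine), std::move(context)} transfers ownership into the cache, and each object is destroyed exactly once when its entry is erased or the op is destructed.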
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 6faef09b62..cb43403130 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -33,7 +34,6 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-class Logger;
class TRTInt8Calibrator;
class TRTCalibrationResource;
class AsyncHelper;
@@ -50,13 +50,6 @@ class TRTEngineOp : public AsyncOpKernel {
~TRTEngineOp();
private:
- template <typename T>
- struct Destroyer {
- void operator()(T* d) {
- if (d) d->destroy();
- }
- };
-
// Execute calibration
void ExecuteCalibration(tensorflow::OpKernelContext* ctx,
AsyncHelper* helper);
@@ -74,11 +67,10 @@ class TRTEngineOp : public AsyncOpKernel {
tensorflow::tensorrt::TRTCalibrationResource** cr);
// TODO(samikama): context should go to a resource manager!
- typedef std::pair<std::shared_ptr<nvinfer1::ICudaEngine>,
- std::shared_ptr<nvinfer1::IExecutionContext>>
+ typedef std::pair<TrtUniquePtrType<nvinfer1::ICudaEngine>,
+ TrtUniquePtrType<nvinfer1::IExecutionContext>>
EngineCtxPair;
- EngineCtxPair GetEngine(int batch_size, OpKernelContext* ctx,
- bool ignore_dim_change = true);
+ EngineCtxPair& GetEngine(int batch_size, OpKernelContext* ctx);
// Return engine batch closest to input batch.
int GetEngineBatch(OpKernelContext* ctx);
@@ -89,32 +81,45 @@ class TRTEngineOp : public AsyncOpKernel {
std::unordered_map<int, EngineCtxPair> engine_map_;
std::vector<string> input_nodes_;
std::vector<string> output_nodes_;
+
// keep device allocator for TRT.
- std::unordered_map<string, std::shared_ptr<TRTDeviceAllocator>> allocators_;
+ std::unique_ptr<TRTDeviceAllocator> allocator_;
+
// serialized protobuf segment or trt engine depending on static_engine_ flag.
string serialized_segment_;
+
// Name of the function for TF native execution of the segment.
string funcdef_name_;
+
// GraphDef representation of the segment.
tensorflow::GraphDef segment_graph_;
+
// Lookup table for temporary staging areas of input tensors for calibration.
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
+
// Temporary staging areas for calibration inputs.
std::vector<tensorflow::PersistentTensor> dev_tensors_;
+
// Engine Precision mode.
int precision_mode_;
+
// Whether engine is constructed during the conversion or needs to be
// constructed from protobuf segment.
bool static_engine_;
+
// Whether to calibrate INT8 engine.
bool calibration_mode_;
+
// Whether non-batch ranks of the inputs are assumed to be fixed or not for
- // engine construction
+ // engine construction.
bool fixed_input_size_;
+
// Batches of the cached engines
std::vector<int> cached_engine_batches_;
+
// Maximum number of cached engines
int max_cached_engines_;
+
tensorflow::int64 workspace_size_;
tensorflow::mutex engine_mutex_;
tensorflow::FunctionLibraryRuntime::Handle native_func_;
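
Note: because EngineCtxPair now holds two move-only unique_ptrs, GetEngine() is changed to return a reference into engine_map_ rather than a pair by value. A rough sketch of the access pattern this implies; names other than EngineCtxPair and engine_map_ are illustrative:

// Sketch: look up (or lazily create) the cached pair and hand back a
// reference owned by the map; copying the pair would be ill-formed for
// unique_ptr members.
EngineCtxPair& GetOrCreate(int batch_size) {
  auto it = engine_map_.find(batch_size);
  if (it == engine_map_.end()) {
    it = engine_map_.emplace(batch_size, EngineCtxPair{nullptr, nullptr}).first;
  }
  return it->second;
}

Callers must treat the returned pair as borrowed: check the engine and context for nullptr, and never move out of the cached entry.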
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index 894e9d6e85..994312d7c3 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -39,30 +39,46 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
TRTInt8Calibrator(
const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
int batch_size, string engine_name);
+
TRTInt8Calibrator(const string& calibration_data);
+
+ ~TRTInt8Calibrator();
+
int getBatchSize() const override;
+
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
+
bool setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream);
+
void setDone();
+
+ // If the returned cache is not null, calibration is skipped.
const void* readCalibrationCache(std::size_t& length) override;
+
void writeCalibrationCache(const void* ptr, std::size_t length) override;
+
const string& getCalibrationTableAsString() { return calibration_table_; }
- ~TRTInt8Calibrator();
private:
const int batch_size_;
- tensorflow::mutex cond_mtx_; // mutex for condition_variable
- tensorflow::condition_variable cond_; // condition variable to implement
- // producer-consumer queue for
- // calibration
+
+ // mutex for condition_variable
+ tensorflow::mutex cond_mtx_;
+
+ // condition variable to implement producer-consumer queue for calibration
+ tensorflow::condition_variable cond_;
+
+ // Is calibration finished?
bool done_;
- const std::unordered_map<string, std::pair<void*, size_t>>
- dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with
- // buffer names
+
+ // Map to keep tensorrt input buffers and sizes keyed with buffer names
+ const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_;
+
bool calib_running_;
bool batch_is_set_;
+
string engine_name_;
string calibration_table_;
};
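
Note: the reshuffled calibrator header above documents the members (cond_mtx_, cond_, batch_is_set_, done_) that implement the producer-consumer handshake between the TF op feeding batches and the TensorRT builder consuming them. A condensed, self-contained sketch of that handshake, using std:: primitives and hypothetical names rather than the actual class:

#include <condition_variable>
#include <mutex>

std::mutex mu;
std::condition_variable cv;
bool batch_is_set = false;   // a fresh batch is waiting to be consumed
bool done = false;           // setDone() has been called

// Producer side, roughly what setBatch() does: wait until the previous
// batch was consumed, publish the new one, wake the builder thread.
bool Produce() {
  std::unique_lock<std::mutex> lock(mu);
  cv.wait(lock, [] { return !batch_is_set || done; });
  if (done) return false;
  batch_is_set = true;       // the real code also copies the input buffers
  cv.notify_all();
  return true;
}

// Consumer side, roughly what getBatch() does inside buildCudaEngine().
bool Consume() {
  std::unique_lock<std::mutex> lock(mu);
  cv.wait(lock, [] { return batch_is_set || done; });
  if (done) return false;    // calibration finished, builder stops looping
  batch_is_set = false;      // the real code binds the device buffers first
  cv.notify_all();
  return true;
}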
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index 022639dc01..43734bbdd8 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <thread>
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
@@ -34,21 +35,21 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
+
class TRTCalibrationResource : public tensorflow::ResourceBase {
public:
TRTCalibrationResource()
: calibrator_(nullptr),
- builder_(nullptr),
- network_(nullptr),
- engine_(nullptr),
logger_(nullptr),
thr_(nullptr) {}
~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
- builder_->destroy();
- network_->destroy();
- engine_->destroy();
+ builder_.reset();
+ engine_.reset();
+ // We need to manually destroy the builder and engine before the allocator
+ // is destroyed.
+ allocator_.reset();
delete thr_;
delete logger_;
delete calibrator_;
@@ -56,22 +57,22 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
string DebugString() override {
std::stringstream oss;
- oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl
- << " Builder = " << std::hex << builder_ << std::dec << std::endl
- << " Network = " << std::hex << network_ << std::dec << std::endl
- << " Engine = " << std::hex << engine_ << std::dec << std::endl
- << " Logger = " << std::hex << logger_ << std::dec << std::endl
- << " Allocator = " << std::hex << allocator_.get() << std::dec
- << std::endl
- << " Thread = " << std::hex << thr_ << std::dec << std::endl;
+ using std::hex;
+ using std::dec;
+ using std::endl;
+ oss << " Calibrator = " << hex << calibrator_ << dec << endl
+ << " Builder = " << hex << builder_.get() << dec << endl
+ << " Engine = " << hex << engine_.get() << dec << endl
+ << " Logger = " << hex << logger_ << dec << endl
+ << " Allocator = " << hex << allocator_.get() << dec << endl
+ << " Thread = " << hex << thr_ << dec << endl;
return oss.str();
}
TRTInt8Calibrator* calibrator_;
- nvinfer1::IBuilder* builder_;
- nvinfer1::INetworkDefinition* network_;
- nvinfer1::ICudaEngine* engine_;
- std::shared_ptr<nvinfer1::IGpuAllocator> allocator_;
+ TrtUniquePtrType<nvinfer1::IBuilder> builder_;
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
+ std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
tensorflow::tensorrt::Logger* logger_;
// TODO(sami): Use threadpool threads!
std::thread* thr_;
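
Note: the destructor change above exists because C++ destroys members in reverse declaration order; with allocator_ declared after builder_ and engine_, default destruction would tear down the allocator while live TensorRT objects may still reference it. A small illustration of the pattern, with the surrounding structure hypothetical:

// Sketch: explicit resets pin the teardown order regardless of how the
// members happen to be declared.
~CalibrationResourceLike() {
  builder_.reset();    // drop the TensorRT objects first...
  engine_.reset();
  allocator_.reset();  // ...then the allocator they were using.
}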
diff --git a/tensorflow/contrib/tensorrt/segment/segment.h b/tensorflow/contrib/tensorrt/segment/segment.h
index 1568dd9153..81b4bfe49f 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.h
+++ b/tensorflow/contrib/tensorrt/segment/segment.h
@@ -29,8 +29,9 @@ namespace tensorflow {
namespace tensorrt {
namespace segment {
-// vector of segments, each entry contains a device name and a set of nodes in
-// segment
+// Vector of segments; each entry contains the set of node names in the
+// segment and a device name.
+// TODO(aaroey): use node pointer instead of node name.
using SegmentNodesVector = std::vector<std::pair<std::set<string>, string>>;
struct SegmentOptions {
@@ -48,6 +49,8 @@ struct SegmentOptions {
// in the vector describes a subgraph by giving a set of the names of
// all the NodeDefs in that subgraph.
// @return the status.
+//
+// TODO(aaroey): remove this method.
tensorflow::Status SegmentGraph(
const tensorflow::GraphDef& gdef,
const std::function<bool(const tensorflow::Node*)>& candidate_fn,
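
Note: for context on the SegmentNodesVector documentation touched above, a short illustrative loop over the structure; variable names are hypothetical:

// Each entry pairs the node names in one segment with the device the
// segment was assigned to.
SegmentNodesVector segments;
for (const auto& segment : segments) {
  const std::set<string>& node_names = segment.first;
  const string& device = segment.second;
  VLOG(1) << "Segment on device '" << device << "' contains "
          << node_names.size() << " nodes.";
}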