-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                              |   1
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc           |  16
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.h            |   2
-rw-r--r--  tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc   |  20
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc           | 211
-rw-r--r--  tensorflow/contrib/tensorrt/kernels/trt_engine_op.h            |  37
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc   |  17
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h    |   7
-rw-r--r--  tensorflow/contrib/tensorrt/resources/trt_resources.h          |   6
-rw-r--r--  tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc               |  60
-rw-r--r--  tensorflow/contrib/tensorrt/trt_conversion.i                   |  24
11 files changed, 178 insertions(+), 223 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 55a5a45692..fd0f97f3af 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -187,7 +187,6 @@ tf_py_wrap_cc(
deps = [
":trt_conversion",
":trt_engine_op_kernel",
- #"//tensorflow/core:framework_lite",
"//third_party/python_runtime:headers",
],
)
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 36191b5cc6..6ddfb01d9f 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -189,6 +189,11 @@ tensorflow::Status ConvertGraphDefToTensorRT(
VLOG(2) << "cpu_cores: " << num_cpu_cores;
VLOG(2) << "gpus: " << num_gpus;
tensorflow::RewriterConfig rw_cfg;
+ // Use only constant folding and layout optimization for the time being,
+ // since the newer optimizers break the graph for us.
+ rw_cfg.add_optimizers("constfold");
+ rw_cfg.add_optimizers("layout");
+
tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg);
tensorflow::GraphDef gdef;
TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef));
@@ -210,10 +215,13 @@ tensorflow::Status ConvertGraphDefToTensorRT(
cp.minimum_segment_size = minimum_segment_size;
cp.graph_properties = &static_graph_properties;
cp.max_workspace_size_bytes = max_workspace_size_bytes;
- // return ConvertAfterShapes(gdef, output_names, max_batch_size,
- // max_workspace_size_bytes, new_graph_def,
- // precision_mode, minimum_segment_size,
- // static_graph_properties, nullptr);
+ if (VLOG_IS_ON(5)) {
+ std::fstream f;
+ f.open("TRTConversionInput.pb",
+ std::fstream::out | std::fstream::binary | std::fstream::trunc);
+ f << gdef.SerializeAsString();
+ f.close();
+ }
return ConvertAfterShapes(cp);
}
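
Review note: the two additions in this hunk, gathered into a standalone sketch. Helper names are illustrative; the proto API is the TF 1.x RewriterConfig/GraphDef already used in this file.

#include <fstream>
#include <string>

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Restrict grappler to the two passes the converter currently tolerates.
void ConfigureRewriter(tensorflow::RewriterConfig* rw_cfg) {
  rw_cfg->add_optimizers("constfold");
  rw_cfg->add_optimizers("layout");
}

// Dump a GraphDef for offline inspection, as the VLOG(5) branch above does.
// The dump can be reloaded later with GraphDef::ParseFromString().
void DumpGraphDef(const tensorflow::GraphDef& gdef, const std::string& path) {
  std::fstream f(path,
                 std::fstream::out | std::fstream::binary | std::fstream::trunc);
  f << gdef.SerializeAsString();
}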
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 9dd4a69965..f742b8acbc 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -30,8 +30,6 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
-// This method converts an already generated calibration graph which was used in
-// calibration runs to an inference graph
struct ConversionParams {
ConversionParams()
: input_graph_def(nullptr),
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index af7830c4e9..68659e4ab5 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -205,16 +205,16 @@ tensorflow::Status TRTOptimizationPass::Optimize(
tensorflow::grappler::GraphProperties static_graph_properties(item);
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
tensorflow::tensorrt::convert::ConversionParams cp;
- cp.input_graph_def=&item.graph;
- cp.output_names=&item.fetch;
- cp.max_batch_size=maximum_batch_size_;
- cp.max_workspace_size_bytes=maximum_workspace_size_;
- cp.output_graph_def=optimized_graph;
- cp.precision_mode=precision_mode_;
- cp.minimum_segment_size=minimum_segment_size_;
- cp.graph_properties=&static_graph_properties;
- cp.cluster=cluster;
- cp.is_dyn_op=false;
+ cp.input_graph_def = &item.graph;
+ cp.output_names = &item.fetch;
+ cp.max_batch_size = maximum_batch_size_;
+ cp.max_workspace_size_bytes = maximum_workspace_size_;
+ cp.output_graph_def = optimized_graph;
+ cp.precision_mode = precision_mode_;
+ cp.minimum_segment_size = minimum_segment_size_;
+ cp.graph_properties = &static_graph_properties;
+ cp.cluster = cluster;
+ cp.is_dyn_op = false;
auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp);
VLOG(2) << optimized_graph->DebugString();
return status;
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index c1371d4830..76153886a8 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -39,6 +39,8 @@ using Dims = nvinfer1::Dims;
namespace tensorrt {
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
+// A helper class that calls done() when asynchronous execution completes.
+// It lets the native segment and the TRT engine execute simultaneously.
class AsyncHelper : public tensorflow::core::RefCounted {
public:
AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; }
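
Review note: AsyncHelper's body beyond the constructor is outside this diff. A minimal sketch of the ref-counted completion pattern it implements, assuming the destructor is what fires done_ (consistent with the Ref()/ScopedUnref usage below):

#include <utility>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"

// Each in-flight sub-task (native segment run, calibration batch) takes a
// Ref(); when the last holder calls Unref(), the helper is destroyed and
// the kernel's done() callback fires exactly once.
class AsyncHelperSketch : public tensorflow::core::RefCounted {
 public:
  explicit AsyncHelperSketch(tensorflow::AsyncOpKernel::DoneCallback done)
      : done_(std::move(done)) {}
  ~AsyncHelperSketch() override { done_(); }

 private:
  tensorflow::AsyncOpKernel::DoneCallback done_;
};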
@@ -100,8 +102,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
context->GetAttr("serialized_segment", &serialized_segment_));
OP_REQUIRES_OK(context,
context->GetAttr("workspace_size_bytes", &workspace_size_));
- OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine));
- if (!static_engine) {
+ OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));
+ if (!static_engine_) {
if (!segment_graph_.ParseFromString(serialized_segment_)) {
LOG(ERROR) << "Parsing segment graph failed!";
context->SetStatus(tensorflow::errors::InvalidArgument(
@@ -119,14 +121,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
if (precision_string == "FP32") {
- precision_mode = tensorflow::tensorrt::convert::FP32MODE;
+ precision_mode_ = tensorflow::tensorrt::convert::FP32MODE;
} else if (precision_string == "FP16") {
- precision_mode = tensorflow::tensorrt::convert::FP16MODE;
+ precision_mode_ = tensorflow::tensorrt::convert::FP16MODE;
} else if (precision_string == "INT8") {
- precision_mode = tensorflow::tensorrt::convert::INT8MODE;
+ precision_mode_ = tensorflow::tensorrt::convert::INT8MODE;
}
- calibration_mode =
- precision_mode == tensorflow::tensorrt::convert::INT8MODE &&
+ calibration_mode_ =
+ precision_mode_ == tensorflow::tensorrt::convert::INT8MODE &&
calibration_data_.size() == 0;
if (calibration_data_.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data_));
@@ -134,15 +136,15 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
}
native_func_ = tensorflow::kInvalidHandle;
OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
- &max_cached_engines));
+ &max_cached_engines_));
OP_REQUIRES_OK(context,
- context->GetAttr("fixed_input_size", &fixed_input_size));
+ context->GetAttr("fixed_input_size", &fixed_input_size_));
OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
- &cached_engine_batches));
- std::sort(cached_engine_batches.begin(), cached_engine_batches.end());
+ &cached_engine_batches_));
+ std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());
if (VLOG_IS_ON(1)) {
string s("Engine Batches= ");
- for (auto i : cached_engine_batches) {
+ for (auto i : cached_engine_batches_) {
StrAppend(&s, i, " ");
}
VLOG(1) << s;
@@ -150,8 +152,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
}
void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
- AsyncHelper* ah) {
- if (!calibration_mode) {
+ AsyncHelper* helper) {
+ if (!calibration_mode_) {
VLOG(1) << "Executing native engine";
}
std::vector<Tensor> inputs;
@@ -173,11 +175,11 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
for (int i = 0; i < ctx->num_inputs(); i++) {
inputs.push_back(ctx->input(i));
}
- ah->Ref(); // Increment count for calculating native graph
+ helper->Ref(); // Increment count for calculating native graph
VLOG(1) << "Executing native segment " << name();
lib->Run(opts, native_func_, inputs, outputs,
- [ctx, outputs, ah](const tensorflow::Status& s) {
- tensorflow::core::ScopedUnref SC(ah);
+ [ctx, outputs, helper](const tensorflow::Status& s) {
+ tensorflow::core::ScopedUnref SC(helper);
VLOG(1) << "Native Segment completed";
if (!s.ok()) {
ctx->SetStatus(s);
@@ -192,55 +194,50 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
return;
}
-void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
- tensorflow::AsyncOpKernel::DoneCallback done) {
- auto ah = new AsyncHelper(done);
- tensorflow::core::ScopedUnref SC(ah);
- if (calibration_mode) {
- auto TRT_RM = tensorflow::tensorrt::TRTResourceManager::instance();
- auto res_mgr = TRT_RM->getManager("TRTCalibration");
- tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
- auto status = res_mgr->LookupOrCreate(
- funcdef_name_, "Calibrator", &calib_res,
- {[ctx, this](tensorflow::tensorrt::TRTCalibrationResource** cr)
- -> tensorflow::Status {
- return this->AllocateCalibrationResources(ctx, cr);
- }});
- if (!status.ok()) {
- ctx->SetStatus(status);
- return;
- }
- ExecuteNativeSegment(ctx, ah);
- int num_inputs = ctx->num_inputs();
- // Pass input data to calibrator
- std::unordered_map<string, void*> input_data;
- for (int i = 0; i < num_inputs; i++) {
- const Tensor& t = ctx->input(i);
- void* data_address = GetTensorAddress(&t);
- const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
- CHECK_EQ(t.TotalBytes(),
- device_tensor->TotalBytes()); // use the tensor so FW keeps it
- input_data.emplace(StrCat("InputPH_", i), data_address);
- }
- VLOG(2) << "Filled map for sending";
- // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
- const cudaStream_t* stream = CHECK_NOTNULL(
- reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
- ->stream()
- ->implementation()
- ->CudaStreamMemberHack()));
- ah->Ref(); // Increment count for calculating calibration data
- calib_res->calibrator_->setBatch(input_data, *stream, ah);
- VLOG(2) << "Passed calibration data";
+void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx,
+ AsyncHelper* helper) {
+ tensorflow::core::ScopedUnref SC(helper);
+ auto TRT_RM = tensorflow::tensorrt::TRTResourceManager::instance();
+ auto res_mgr = TRT_RM->getManager("TRTCalibration");
+ tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr;
+ auto status = res_mgr->LookupOrCreate(
+ funcdef_name_, "Calibrator", &calib_res,
+ {[ctx, this](tensorflow::tensorrt::TRTCalibrationResource** cr)
+ -> tensorflow::Status {
+ return this->AllocateCalibrationResources(ctx, cr);
+ }});
+ if (!status.ok()) {
+ ctx->SetStatus(status);
return;
}
- int num_binding = ctx->num_inputs() + ctx->num_outputs();
- std::vector<void*> buffers(num_binding);
+ ExecuteNativeSegment(ctx, helper);
+ int num_inputs = ctx->num_inputs();
+ // Pass input data to calibrator
+ std::unordered_map<string, void*> input_data;
+ for (int i = 0; i < num_inputs; i++) {
+ const Tensor& t = ctx->input(i);
+ void* data_address = GetTensorAddress(&t);
+ const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx);
+ CHECK_EQ(t.TotalBytes(),
+ device_tensor->TotalBytes()); // use the tensor so FW keeps it
+ input_data.emplace(StrCat("InputPH_", i), data_address);
+ }
+ VLOG(2) << "Filled map for sending";
+ // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
+ const cudaStream_t* stream = CHECK_NOTNULL(
+ reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
+ ->stream()
+ ->implementation()
+ ->CudaStreamMemberHack()));
+ calib_res->calibrator_->setBatch(input_data, *stream);
+ VLOG(2) << "Passed calibration data";
+ return;
+}
- size_t binding_index;
+int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext* ctx) {
int num_batch = ctx->input(0).shape().dim_size(0);
int smallest_engine = 0;
- for (const auto i : cached_engine_batches) {
+ for (const auto i : cached_engine_batches_) {
if (i >= num_batch) {
smallest_engine = i;
break;
@@ -248,32 +245,46 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
}
// TODO(sami): Need an LRU here
if (smallest_engine == 0) {
- if (max_cached_engines > cached_engine_batches.size()) {
+ if (max_cached_engines_ > cached_engine_batches_.size()) {
smallest_engine = num_batch;
- cached_engine_batches.push_back(num_batch);
- std::sort(cached_engine_batches.begin(), cached_engine_batches.end());
+ cached_engine_batches_.push_back(num_batch);
VLOG(1) << "Running with batch size " << num_batch;
} else {
string s("Engine buffer is full. buffer limit= ");
- StrAppend(&s, max_cached_engines, ", current entries= ");
- for (auto i : cached_engine_batches) StrAppend(&s, i, ", ");
+ StrAppend(&s, max_cached_engines_, ", current entries= ");
+ for (auto i : cached_engine_batches_) StrAppend(&s, i, ", ");
StrAppend(&s, "Requested batch= ", num_batch);
LOG(ERROR) << s;
ctx->SetStatus(tensorflow::errors::ResourceExhausted(
"Requested batch size is not available and engine cache is full"));
- return;
+ return -1;
}
}
- auto engine_ctx_pair = get_engine(smallest_engine, ctx, fixed_input_size);
+ return smallest_engine;
+}
+
+void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
+ tensorflow::AsyncOpKernel::DoneCallback done) {
+ auto ah = new AsyncHelper(done);
+ tensorflow::core::ScopedUnref SC(ah);
+ if (calibration_mode_) {
+ ah->Ref();
+ ExecuteCalibration(ctx, ah);
+ return;
+ }
+ int num_binding = ctx->num_inputs() + ctx->num_outputs();
+ std::vector<void*> buffers(num_binding);
+ int smallest_engine = GetEngineBatch(ctx);
+ if (smallest_engine < 0) return;
+ int num_batch = ctx->input(0).shape().dim_size(0);
+ size_t binding_index;
+ auto engine_ctx_pair = GetEngine(smallest_engine, ctx, fixed_input_size_);
auto trt_engine_ptr_ = engine_ctx_pair.first;
if (!trt_engine_ptr_) {
LOG(WARNING) << "Engine retrieval for batch size " << num_batch
<< " failed Running native segment";
ExecuteNativeSegment(ctx, ah);
return;
- // ctx->SetStatus(tensorflow::errors::Unavailable(
- // StrCat("Engine retrieval for batch ", num_batch, " Failed")));
- // return;
}
for (int i = 0; i < ctx->num_inputs(); i++) {
string inp_name = "InputPH_";
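
Review note: the selection policy that GetEngineBatch factors out above, as a standalone sketch. Names are illustrative, and this sketch re-sorts after insertion so the smallest-fit scan stays valid.

#include <algorithm>
#include <vector>

// Pick the smallest cached batch size that can serve the request. If none
// fits and the cache still has room, adopt the requested size; return -1
// when the cache is full, in which case the op rejects the request.
int PickEngineBatch(std::vector<int>* cached, size_t max_cached, int requested) {
  for (int b : *cached) {
    if (b >= requested) return b;  // list is kept sorted ascending
  }
  if (cached->size() < max_cached) {
    cached->push_back(requested);
    std::sort(cached->begin(), cached->end());
    return requested;
  }
  return -1;
}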
@@ -283,17 +294,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
const Tensor& input_tensor = ctx->input(i);
const TensorShape& input_shape = input_tensor.shape();
- if (i == 0) {
- num_batch = input_shape.dim_size(0);
- if (num_batch > trt_engine_ptr_->getMaxBatchSize()) {
- LOG(ERROR) << "input tensor batch " << num_batch
- << " larger than max_batch_size: "
- << trt_engine_ptr_->getMaxBatchSize();
- ctx->SetStatus(tensorflow::errors::FailedPrecondition(
- StrCat("Invalid batch size ", num_batch)));
- return;
- }
- } else if (num_batch != input_shape.dim_size(0)) {
+ if (num_batch != input_shape.dim_size(0)) {
LOG(ERROR) << "input data inconsistent batch size";
ctx->SetStatus(tensorflow::errors::FailedPrecondition(
"Different batch sizes between input tensors"));
@@ -393,25 +394,25 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx,
nullptr);
VLOG(2) << "enqueue returns: " << ret;
// sync should be done by TF.
-} // namespace tensorrt
+}
+
TRTEngineOp::~TRTEngineOp() {
// Order matters!
- for (auto eng : engine_map) {
+ for (auto eng : engine_map_) {
eng.second.first.reset();
eng.second.second.reset();
}
for (auto alloc : allocators_) alloc.second.reset();
}
-// template <typename T>
-// using destroyed_ptr = std::shared_ptr<T, TRTEngineOp::Destroyer<T>>;
-TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size,
+
+TRTEngineOp::EngineCtxPair TRTEngineOp::GetEngine(int batch_size,
OpKernelContext* ctx,
bool ignore_dim_change) {
tensorflow::mutex_lock lock(engine_mutex_);
- if (static_engine) {
- if (engine_map.size()) {
- if (engine_map.begin()->first >= batch_size) {
- return engine_map.begin()->second;
+ if (static_engine_) {
+ if (engine_map_.size()) {
+ if (engine_map_.begin()->first >= batch_size) {
+ return engine_map_.begin()->second;
} else {
return {nullptr, nullptr};
}
@@ -432,22 +433,22 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size,
infer->deserializeCudaEngine(serialized_segment_.c_str(),
serialized_segment_.size(), nullptr),
Destroyer<nvinfer1::ICudaEngine>());
- engine_map.insert({static_engine->getMaxBatchSize(),
- {static_engine,
- {static_engine->createExecutionContext(),
- Destroyer<nvinfer1::IExecutionContext>()}}});
+ engine_map_.insert({static_engine->getMaxBatchSize(),
+ {static_engine,
+ {static_engine->createExecutionContext(),
+ Destroyer<nvinfer1::IExecutionContext>()}}});
// Runtime is safe to delete after engine creation
infer->destroy();
serialized_segment_.clear();
if (static_engine->getMaxBatchSize() < batch_size) {
return {nullptr, nullptr};
}
- return engine_map.at(static_engine->getMaxBatchSize());
+ return engine_map_.at(static_engine->getMaxBatchSize());
}
} else {
- auto engine_it = engine_map.find(batch_size);
- if (engine_it == engine_map.end() &&
- engine_map.size() < (size_t)max_cached_engines) {
+ auto engine_it = engine_map_.find(batch_size);
+ if (engine_it == engine_map_.end() &&
+ engine_map_.size() < (size_t)max_cached_engines_) {
auto builder_ = std::shared_ptr<nvinfer1::IBuilder>(
nvinfer1::createInferBuilder(logger),
Destroyer<nvinfer1::IBuilder>()); // reset the builder to ensure
@@ -475,9 +476,9 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size,
VLOG(1) << name() << " Constructing a new engine with batch size "
<< batch_size;
builder_->setMaxBatchSize(batch_size);
- if (precision_mode == tensorflow::tensorrt::convert::FP16MODE) {
+ if (precision_mode_ == tensorflow::tensorrt::convert::FP16MODE) {
builder_->setHalf2Mode(true);
- } else if (precision_mode == tensorflow::tensorrt::convert::INT8MODE) {
+ } else if (precision_mode_ == tensorflow::tensorrt::convert::INT8MODE) {
builder_->setInt8Mode(true);
builder_->setInt8Calibrator(calibrator_.get());
}
@@ -488,9 +489,9 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size,
shapes.emplace_back(ctx->input(i).shape());
}
auto status = tensorflow::tensorrt::convert::ConvertSubgraphToEngine(
- segment_graph_, builder_.get(), shapes, &engine, precision_mode);
+ segment_graph_, builder_.get(), shapes, &engine, precision_mode_);
if (engine) {
- engine_map[batch_size] = {
+ engine_map_[batch_size] = {
std::shared_ptr<nvinfer1::ICudaEngine>(
engine, Destroyer<nvinfer1::ICudaEngine>()),
std::shared_ptr<nvinfer1::IExecutionContext>(
@@ -500,11 +501,11 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size,
LOG(ERROR) << "Engine creation for batch size " << batch_size
<< " failed";
ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!"));
- engine_map[batch_size] = {nullptr, nullptr};
+ engine_map_[batch_size] = {nullptr, nullptr};
return {nullptr, nullptr};
}
}
- return engine_map.at(batch_size);
+ return engine_map_.at(batch_size);
}
}
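
Review note: NvInfer objects must be released with destroy() rather than delete, which is why every shared_ptr in GetEngine carries a custom deleter. Destroyer<T> is declared in trt_engine_op.h; a sketch of what such a deleter looks like:

#include <memory>

// Custom deleter for TensorRT objects, which are freed via destroy().
template <typename T>
struct DestroyDeleter {
  void operator()(T* obj) const {
    if (obj) obj->destroy();
  }
};

// Usage, assuming a raw nvinfer1::ICudaEngine* named raw_engine:
//   std::shared_ptr<nvinfer1::ICudaEngine> engine(
//       raw_engine, DestroyDeleter<nvinfer1::ICudaEngine>());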
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index 5c9cd98cb3..1e6d7fbe93 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -54,24 +54,37 @@ class TRTEngineOp : public AsyncOpKernel {
}
};
+ // Execute the calibration step.
+ void ExecuteCalibration(tensorflow::OpKernelContext* ctx,
+ AsyncHelper* helper);
+
+ // Construct a function handle for executing the native funcdef graph.
tensorflow::Status ConstructFunctionHandle(tensorflow::OpKernelContext* ctx);
- void ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, AsyncHelper* ah);
+
+ // Execute the replaced native segment as a function Op.
+ void ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
+ AsyncHelper* helper);
+
+ // Allocate the resources necessary for calibration.
tensorflow::Status AllocateCalibrationResources(
tensorflow::OpKernelContext* ctx,
tensorflow::tensorrt::TRTCalibrationResource** cr);
// TODO(samikama): context should go to a resource manager!
- // std::shared_ptr<nvinfer1::IExecutionContext> get_execution_context(
- // int batch_size);
typedef std::pair<std::shared_ptr<nvinfer1::ICudaEngine>,
std::shared_ptr<nvinfer1::IExecutionContext>>
EngineCtxPair;
- EngineCtxPair get_engine(int batch_size, OpKernelContext* ctx,
- bool ignore_dim_change = true);
+ EngineCtxPair GetEngine(int batch_size, OpKernelContext* ctx,
+ bool ignore_dim_change = true);
+
+ // Return the smallest cached engine batch that can serve the input batch,
+ // or -1 if none fits and the cache is full.
+ int GetEngineBatch(OpKernelContext* ctx);
- std::unordered_map<int, EngineCtxPair> engine_map;
+ // Map from batch size to an engine and its execution context.
+ std::unordered_map<int, EngineCtxPair> engine_map_;
std::vector<string> input_nodes_;
std::vector<string> output_nodes_;
+ // Keep the device allocators alive for TRT.
std::unordered_map<string, std::shared_ptr<nvinfer1::IGpuAllocator>>
allocators_;
string serialized_segment_;
@@ -80,12 +93,12 @@ class TRTEngineOp : public AsyncOpKernel {
tensorflow::GraphDef segment_graph_;
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
std::vector<tensorflow::PersistentTensor> dev_tensors_;
- int precision_mode;
- bool static_engine;
- bool calibration_mode;
- bool fixed_input_size;
- std::vector<int> cached_engine_batches;
- int max_cached_engines;
+ int precision_mode_;
+ bool static_engine_;
+ bool calibration_mode_;
+ bool fixed_input_size_;
+ std::vector<int> cached_engine_batches_;
+ int max_cached_engines_;
tensorflow::int64 workspace_size_;
tensorflow::mutex engine_mutex_;
tensorflow::FunctionLibraryRuntime::Handle native_func_;
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 5adffdc3d1..695394156c 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -47,13 +47,11 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
done_(false),
calib_running_(false),
batch_is_set_(false),
- calibration_table(calib_data) {}
+ calibration_table_(calib_data) {}
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
- const cudaStream_t stream,
- tensorflow::core::RefCounted* rc) {
+ const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
- tensorflow::core::ScopedUnref SC(rc);
while ((calib_running_ || batch_is_set_) &&
!done_) { // wait while calibration is running
cond_.wait(lock);
@@ -116,9 +114,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
}
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
- if (calibration_table.empty()) return nullptr;
- length = calibration_table.size();
- return calibration_table.data();
+ if (calibration_table_.empty()) return nullptr;
+ length = calibration_table_.size();
+ return calibration_table_.data();
}
void TRTInt8Calibrator::setDone() {
@@ -129,8 +127,9 @@ void TRTInt8Calibrator::setDone() {
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
std::size_t length) {
- calibration_table = string((const char*)ptr, length);
- VLOG(1) << "Got calibration data for "<<engine_name_<<" @"<<ptr<<" length="<<length;
+ calibration_table_ = string((const char*)ptr, length);
+ VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
+ << " length=" << length;
}
TRTInt8Calibrator::~TRTInt8Calibrator() {
VLOG(1) << "Destroying calibrator for " << engine_name_;
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index eec9571418..6b59d52c70 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -47,12 +47,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
bool setBatch(const std::unordered_map<string, void*>& data,
- const cudaStream_t stream,
- tensorflow::core::RefCounted* helper);
+ const cudaStream_t stream);
void setDone();
const void* readCalibrationCache(std::size_t& length) override;
void writeCalibrationCache(const void* ptr, std::size_t length) override;
- const string& getCalibrationTableAsString(){return calibration_table;}
+ const string& getCalibrationTableAsString() { return calibration_table_; }
~TRTInt8Calibrator();
private:
@@ -68,7 +67,7 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
bool calib_running_;
bool batch_is_set_;
string engine_name_;
- string calibration_table;
+ string calibration_table_;
};
} // namespace tensorrt
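
Review note: dropping the RefCounted* parameter from setBatch works because the calibrator blocks the producer on its condition variable until getBatch has consumed the previous batch. A simplified standalone sketch of that handshake using std primitives (the real code uses tensorflow::mutex):

#include <condition_variable>
#include <mutex>

class BatchHandoff {
 public:
  // Producer: block until the previous batch is consumed, then publish.
  void SetBatch() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !batch_is_set_ || done_; });
    if (done_) return;
    batch_is_set_ = true;
    cv_.notify_all();
  }
  // Consumer: block until a batch is available; false means shutdown.
  bool GetBatch() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return batch_is_set_ || done_; });
    if (done_) return false;
    batch_is_set_ = false;
    cv_.notify_all();
    return true;
  }
  // Wake all waiters for shutdown.
  void SetDone() {
    std::lock_guard<std::mutex> lock(mu_);
    done_ = true;
    cv_.notify_all();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool batch_is_set_ = false;
  bool done_ = false;
};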
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index 584d6baee5..022639dc01 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -47,17 +47,11 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
builder_->destroy();
- builder_ = nullptr;
network_->destroy();
- network_ = nullptr;
engine_->destroy();
- engine_ = nullptr;
delete thr_;
- thr_ = nullptr;
delete logger_;
- logger_ = nullptr;
delete calibrator_;
- calibrator_ = nullptr;
}
string DebugString() override {
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index 8142872fca..9bf2a56f99 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -29,67 +29,11 @@ namespace tensorflow {
namespace shape_inference {
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
- tensorflow::tensorrt::Logger logger;
- string serialized_engine;
- if(true){
- for(int i=0;i<context->num_outputs();++i){
- context->set_output(i,context->UnknownShape());
- }
- return Status::OK();
+ for (int i = 0; i < context->num_outputs(); ++i) {
+ context->set_output(i, context->UnknownShape());
}
- TF_RETURN_IF_ERROR(context->GetAttr("serialized_segment", &serialized_engine));
- nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
- nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
- serialized_engine.c_str(), serialized_engine.size(),
- tensorrt::PluginFactoryTensorRT::GetInstance());
-
- int num_batch = -1;
- std::vector<::tensorflow::DataType> input_type;
- TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type));
- for (size_t i = 0; i < context->num_inputs(); i++) {
- // Check if input shape is legit
- auto input_shape = context->input(i);
- for (int j = 0; j < context->Rank(input_shape); j++) {
- auto dim_handler = context->Dim(input_shape, j);
- if (j == 0) {
- if (i == 0) {
- num_batch = context->Value(dim_handler);
- } else if (num_batch != context->Value(dim_handler)) {
- // TODO(jie): TensorRT engine requires consistent batch between inputs
- // tensors. Segmenter should be aware of this.
- LOG(FATAL) << "TensorRT engine requires consistent batch size";
- }
- }
- }
- }
-
- // Arrange input here
- std::vector<string> input_nodes;
- TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes));
-
- // Arrange output here
- std::vector<string> output_nodes;
- //TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes));
- for (size_t i = 0; i < output_nodes.size(); i++) {
- int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str());
- ShapeHandle output_shape;
- std::vector<DimensionHandle> dim_vec;
- dim_vec.emplace_back(context->MakeDim(num_batch));
- if (binding_index != -1) {
- auto dims = trt_engine->getBindingDimensions(binding_index);
- for (int j = 0; j < dims.nbDims; j++) {
- dim_vec.emplace_back(context->MakeDim(dims.d[j]));
- }
- } else {
- LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i];
- }
- output_shape = context->MakeShape(dim_vec);
- context->set_output(i, output_shape);
- }
-
return Status::OK();
}
-
} // namespace shape_inference
} // namespace tensorflow
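
Review note: assembled from the surviving lines, the entire shape function after this change reduces to the following; all shape resolution is deferred to runtime, since output shapes are only known once the TRT engine is built.

#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {
namespace shape_inference {

tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) {
  for (int i = 0; i < context->num_outputs(); ++i) {
    context->set_output(i, context->UnknownShape());
  }
  return Status::OK();
}

}  // namespace shape_inference
}  // namespace tensorflow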
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 861d241afb..80bb14accf 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -61,7 +61,7 @@ PyObject* version_helper(version_struct* in) {
if (!tuple) {
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_TypeError,
- "Tuple creation from pair<string,string> failed!");
+ "Tuple creation from version structure failed!");
}
return NULL;
}
@@ -69,15 +69,15 @@ PyObject* version_helper(version_struct* in) {
}
/* Define converters for vector<int> */
template<>
- bool _PyObjAs(PyObject *pyobj, int* dest) {
- *dest=PyLong_AsLong(pyobj);
- return true;
- }
+ bool _PyObjAs(PyObject *pyobj, int* dest) {
+ *dest = PyLong_AsLong(pyobj);
+ return true;
+}
- template<>
- PyObject *_PyObjFrom(const int& src) {
- return PyLong_FromLong(src);
- }
+template<>
+ PyObject *_PyObjFrom(const int& src) {
+ return PyLong_FromLong(src);
+}
%}
@@ -175,7 +175,8 @@ std::pair<string, string> trt_convert(
#endif // GOOGLE_CUDA && GOOGLE_TENSORRT
}
-std::pair<string, string> calib_convert(string graph_def_string
+std::pair<string, string> calib_convert(
+ string graph_def_string
// unfortunately we can't use TF_Status here since it
// is in c/c_api and brings in a lot of other libraries
// which in turn declare ops. These ops are included
@@ -250,8 +251,7 @@ std::pair<string, string> trt_convert(string graph_def_string,
int precision_mode, int minimum_segment_size,
bool is_dyn_op,
int max_cached_engines,
- std::vector<int> cached_engine_batches
- );
+ std::vector<int> cached_engine_batches);
version_struct get_linked_tensorrt_version();
version_struct get_loaded_tensorrt_version();