diff options
11 files changed, 178 insertions, 223 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 55a5a45692..fd0f97f3af 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -187,7 +187,6 @@ tf_py_wrap_cc( deps = [ ":trt_conversion", ":trt_engine_op_kernel", - #"//tensorflow/core:framework_lite", "//third_party/python_runtime:headers", ], ) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 36191b5cc6..6ddfb01d9f 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -189,6 +189,11 @@ tensorflow::Status ConvertGraphDefToTensorRT( VLOG(2) << "cpu_cores: " << num_cpu_cores; VLOG(2) << "gpus: " << num_gpus; tensorflow::RewriterConfig rw_cfg; + // use only const folding and layout for the time being since new optimizers + // break the graph for us + rw_cfg.add_optimizers("constfold"); + rw_cfg.add_optimizers("layout"); + tensorflow::grappler::MetaOptimizer meta_opt(nullptr, rw_cfg); tensorflow::GraphDef gdef; TF_RETURN_IF_ERROR(meta_opt.Optimize(cluster.get(), item, &gdef)); @@ -210,10 +215,13 @@ tensorflow::Status ConvertGraphDefToTensorRT( cp.minimum_segment_size = minimum_segment_size; cp.graph_properties = &static_graph_properties; cp.max_workspace_size_bytes = max_workspace_size_bytes; - // return ConvertAfterShapes(gdef, output_names, max_batch_size, - // max_workspace_size_bytes, new_graph_def, - // precision_mode, minimum_segment_size, - // static_graph_properties, nullptr); + if (VLOG_IS_ON(5)) { + std::fstream f; + f.open("TRTConversionInput.pb", + std::fstream::out | std::fstream::binary | std::fstream::trunc); + f << gdef.SerializeAsString(); + f.close(); + } return ConvertAfterShapes(cp); } diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h index 9dd4a69965..f742b8acbc 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.h +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h @@ -30,8 +30,6 @@ namespace tensorflow { namespace tensorrt { namespace convert { -// This method converts an already generated calibration graph which was used in -// calibration runs to an inference graph struct ConversionParams { ConversionParams() : input_graph_def(nullptr), diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc index af7830c4e9..68659e4ab5 100644 --- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc @@ -205,16 +205,16 @@ tensorflow::Status TRTOptimizationPass::Optimize( tensorflow::grappler::GraphProperties static_graph_properties(item); TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true)); tensorflow::tensorrt::convert::ConversionParams cp; - cp.input_graph_def=&item.graph; - cp.output_names=&item.fetch; - cp.max_batch_size=maximum_batch_size_; - cp.max_workspace_size_bytes=maximum_workspace_size_; - cp.output_graph_def=optimized_graph; - cp.precision_mode=precision_mode_; - cp.minimum_segment_size=minimum_segment_size_; - cp.graph_properties=&static_graph_properties; - cp.cluster=cluster; - cp.is_dyn_op=false; + cp.input_graph_def = &item.graph; + cp.output_names = &item.fetch; + cp.max_batch_size = maximum_batch_size_; + cp.max_workspace_size_bytes = maximum_workspace_size_; + cp.output_graph_def = optimized_graph; + cp.precision_mode = precision_mode_; + cp.minimum_segment_size = minimum_segment_size_; + cp.graph_properties = &static_graph_properties; + cp.cluster = cluster; + cp.is_dyn_op = false; auto status = tensorflow::tensorrt::convert::ConvertAfterShapes(cp); VLOG(2) << optimized_graph->DebugString(); return status; diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index c1371d4830..76153886a8 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -39,6 +39,8 @@ using Dims = nvinfer1::Dims; namespace tensorrt { using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; +// A helper class to call done() for asynchronous execution. +// Helps simultaneous execution of native and TRT engines. class AsyncHelper : public tensorflow::core::RefCounted { public: AsyncHelper(tensorflow::AsyncOpKernel::DoneCallback done) { done_ = done; } @@ -100,8 +102,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) context->GetAttr("serialized_segment", &serialized_segment_)); OP_REQUIRES_OK(context, context->GetAttr("workspace_size_bytes", &workspace_size_)); - OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine)); - if (!static_engine) { + OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_)); + if (!static_engine_) { if (!segment_graph_.ParseFromString(serialized_segment_)) { LOG(ERROR) << "Parsing segment graph failed!"; context->SetStatus(tensorflow::errors::InvalidArgument( @@ -119,14 +121,14 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) OP_REQUIRES_OK(context, context->GetAttr("segment_funcdef_name", &funcdef_name_)); if (precision_string == "FP32") { - precision_mode = tensorflow::tensorrt::convert::FP32MODE; + precision_mode_ = tensorflow::tensorrt::convert::FP32MODE; } else if (precision_string == "FP16") { - precision_mode = tensorflow::tensorrt::convert::FP16MODE; + precision_mode_ = tensorflow::tensorrt::convert::FP16MODE; } else if (precision_string == "INT8") { - precision_mode = tensorflow::tensorrt::convert::INT8MODE; + precision_mode_ = tensorflow::tensorrt::convert::INT8MODE; } - calibration_mode = - precision_mode == tensorflow::tensorrt::convert::INT8MODE && + calibration_mode_ = + precision_mode_ == tensorflow::tensorrt::convert::INT8MODE && calibration_data_.size() == 0; if (calibration_data_.size()) { calibrator_.reset(new TRTInt8Calibrator(calibration_data_)); @@ -134,15 +136,15 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) } native_func_ = tensorflow::kInvalidHandle; OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count", - &max_cached_engines)); + &max_cached_engines_)); OP_REQUIRES_OK(context, - context->GetAttr("fixed_input_size", &fixed_input_size)); + context->GetAttr("fixed_input_size", &fixed_input_size_)); OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches", - &cached_engine_batches)); - std::sort(cached_engine_batches.begin(), cached_engine_batches.end()); + &cached_engine_batches_)); + std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end()); if (VLOG_IS_ON(1)) { string s("Engine Batches= "); - for (auto i : cached_engine_batches) { + for (auto i : cached_engine_batches_) { StrAppend(&s, i, " "); } VLOG(1) << s; @@ -150,8 +152,8 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context) } void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, - AsyncHelper* ah) { - if (!calibration_mode) { + AsyncHelper* helper) { + if (!calibration_mode_) { VLOG(1) << "Executing native engine"; } std::vector<Tensor> inputs; @@ -173,11 +175,11 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, for (int i = 0; i < ctx->num_inputs(); i++) { inputs.push_back(ctx->input(i)); } - ah->Ref(); // Increment count for calculating native graph + helper->Ref(); // Increment count for calculating native graph VLOG(1) << "Executing native segment " << name(); lib->Run(opts, native_func_, inputs, outputs, - [ctx, outputs, ah](const tensorflow::Status& s) { - tensorflow::core::ScopedUnref SC(ah); + [ctx, outputs, helper](const tensorflow::Status& s) { + tensorflow::core::ScopedUnref SC(helper); VLOG(1) << "Native Segment completed"; if (!s.ok()) { ctx->SetStatus(s); @@ -192,55 +194,50 @@ void TRTEngineOp::ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, return; } -void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, - tensorflow::AsyncOpKernel::DoneCallback done) { - auto ah = new AsyncHelper(done); - tensorflow::core::ScopedUnref SC(ah); - if (calibration_mode) { - auto TRT_RM = tensorflow::tensorrt::TRTResourceManager::instance(); - auto res_mgr = TRT_RM->getManager("TRTCalibration"); - tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; - auto status = res_mgr->LookupOrCreate( - funcdef_name_, "Calibrator", &calib_res, - {[ctx, this](tensorflow::tensorrt::TRTCalibrationResource** cr) - -> tensorflow::Status { - return this->AllocateCalibrationResources(ctx, cr); - }}); - if (!status.ok()) { - ctx->SetStatus(status); - return; - } - ExecuteNativeSegment(ctx, ah); - int num_inputs = ctx->num_inputs(); - // Pass input data to calibrator - std::unordered_map<string, void*> input_data; - for (int i = 0; i < num_inputs; i++) { - const Tensor& t = ctx->input(i); - void* data_address = GetTensorAddress(&t); - const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), - device_tensor->TotalBytes()); // use the tensor so FW keeps it - input_data.emplace(StrCat("InputPH_", i), data_address); - } - VLOG(2) << "Filled map for sending"; - // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files - const cudaStream_t* stream = CHECK_NOTNULL( - reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() - ->stream() - ->implementation() - ->CudaStreamMemberHack())); - ah->Ref(); // Increment count for calculating calibration data - calib_res->calibrator_->setBatch(input_data, *stream, ah); - VLOG(2) << "Passed calibration data"; +void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper) { + tensorflow::core::ScopedUnref SC(helper); + auto TRT_RM = tensorflow::tensorrt::TRTResourceManager::instance(); + auto res_mgr = TRT_RM->getManager("TRTCalibration"); + tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->LookupOrCreate( + funcdef_name_, "Calibrator", &calib_res, + {[ctx, this](tensorflow::tensorrt::TRTCalibrationResource** cr) + -> tensorflow::Status { + return this->AllocateCalibrationResources(ctx, cr); + }}); + if (!status.ok()) { + ctx->SetStatus(status); return; } - int num_binding = ctx->num_inputs() + ctx->num_outputs(); - std::vector<void*> buffers(num_binding); + ExecuteNativeSegment(ctx, helper); + int num_inputs = ctx->num_inputs(); + // Pass input data to calibrator + std::unordered_map<string, void*> input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), + device_tensor->TotalBytes()); // use the tensor so FW keeps it + input_data.emplace(StrCat("InputPH_", i), data_address); + } + VLOG(2) << "Filled map for sending"; + // copied from cuda_kernel_helper since it seems only valid in *.cu.cc files + const cudaStream_t* stream = CHECK_NOTNULL( + reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() + ->stream() + ->implementation() + ->CudaStreamMemberHack())); + calib_res->calibrator_->setBatch(input_data, *stream); + VLOG(2) << "Passed calibration data"; + return; +} - size_t binding_index; +int TRTEngineOp::GetEngineBatch(tensorflow::OpKernelContext *ctx){ int num_batch = ctx->input(0).shape().dim_size(0); int smallest_engine = 0; - for (const auto i : cached_engine_batches) { + for (const auto i : cached_engine_batches_) { if (i >= num_batch) { smallest_engine = i; break; @@ -248,32 +245,46 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, } // TODO(sami): Need an LRU here if (smallest_engine == 0) { - if (max_cached_engines > cached_engine_batches.size()) { + if (max_cached_engines_ > cached_engine_batches_.size()) { smallest_engine = num_batch; - cached_engine_batches.push_back(num_batch); - std::sort(cached_engine_batches.begin(), cached_engine_batches.end()); + cached_engine_batches_.push_back(num_batch); VLOG(1) << "Running with batch size " << num_batch; } else { string s("Engine buffer is full. buffer limit= "); - StrAppend(&s, max_cached_engines, ", current entries= "); - for (auto i : cached_engine_batches) StrAppend(&s, i, ", "); + StrAppend(&s, max_cached_engines_, ", current entries= "); + for (auto i : cached_engine_batches_) StrAppend(&s, i, ", "); StrAppend(&s, "Requested batch= ", num_batch); LOG(ERROR) << s; ctx->SetStatus(tensorflow::errors::ResourceExhausted( "Requested batch size is not available and engine cache is full")); - return; + return -1; } } - auto engine_ctx_pair = get_engine(smallest_engine, ctx, fixed_input_size); + return smallest_engine; +} + +void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, + tensorflow::AsyncOpKernel::DoneCallback done) { + auto ah = new AsyncHelper(done); + tensorflow::core::ScopedUnref SC(ah); + if (calibration_mode_) { + ah->Ref(); + ExecuteCalibration(ctx, ah); + return; + } + int num_binding = ctx->num_inputs() + ctx->num_outputs(); + std::vector<void*> buffers(num_binding); + int smallest_engine=GetEngineBatch(ctx); + if(smallest_engine<0)return; + int num_batch=ctx->input(0).shape().dim_size(0); + size_t binding_index; + auto engine_ctx_pair = GetEngine(smallest_engine, ctx, fixed_input_size_); auto trt_engine_ptr_ = engine_ctx_pair.first; if (!trt_engine_ptr_) { LOG(WARNING) << "Engine retrieval for batch size " << num_batch << " failed Running native segment"; ExecuteNativeSegment(ctx, ah); return; - // ctx->SetStatus(tensorflow::errors::Unavailable( - // StrCat("Engine retrieval for batch ", num_batch, " Failed"))); - // return; } for (int i = 0; i < ctx->num_inputs(); i++) { string inp_name = "InputPH_"; @@ -283,17 +294,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, const Tensor& input_tensor = ctx->input(i); const TensorShape& input_shape = input_tensor.shape(); - if (i == 0) { - num_batch = input_shape.dim_size(0); - if (num_batch > trt_engine_ptr_->getMaxBatchSize()) { - LOG(ERROR) << "input tensor batch " << num_batch - << " larger than max_batch_size: " - << trt_engine_ptr_->getMaxBatchSize(); - ctx->SetStatus(tensorflow::errors::FailedPrecondition( - StrCat("Invalid batch size ", num_batch))); - return; - } - } else if (num_batch != input_shape.dim_size(0)) { + if (num_batch != input_shape.dim_size(0)) { LOG(ERROR) << "input data inconsistent batch size"; ctx->SetStatus(tensorflow::errors::FailedPrecondition( "Different batch sizes between input tensors")); @@ -393,25 +394,25 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, nullptr); VLOG(2) << "enqueue returns: " << ret; // sync should be done by TF. -} // namespace tensorrt +} + TRTEngineOp::~TRTEngineOp() { // Order matters! - for (auto eng : engine_map) { + for (auto eng : engine_map_) { eng.second.first.reset(); eng.second.second.reset(); } for (auto alloc : allocators_) alloc.second.reset(); } -// template <typename T> -// using destroyed_ptr = std::shared_ptr<T, TRTEngineOp::Destroyer<T>>; -TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size, + +TRTEngineOp::EngineCtxPair TRTEngineOp::GetEngine(int batch_size, OpKernelContext* ctx, bool ignore_dim_change) { tensorflow::mutex_lock lock(engine_mutex_); - if (static_engine) { - if (engine_map.size()) { - if (engine_map.begin()->first >= batch_size) { - return engine_map.begin()->second; + if (static_engine_) { + if (engine_map_.size()) { + if (engine_map_.begin()->first >= batch_size) { + return engine_map_.begin()->second; } else { return {nullptr, nullptr}; } @@ -432,22 +433,22 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size, infer->deserializeCudaEngine(serialized_segment_.c_str(), serialized_segment_.size(), nullptr), Destroyer<nvinfer1::ICudaEngine>()); - engine_map.insert({static_engine->getMaxBatchSize(), - {static_engine, - {static_engine->createExecutionContext(), - Destroyer<nvinfer1::IExecutionContext>()}}}); + engine_map_.insert({static_engine->getMaxBatchSize(), + {static_engine, + {static_engine->createExecutionContext(), + Destroyer<nvinfer1::IExecutionContext>()}}}); // Runtime is safe to delete after engine creation infer->destroy(); serialized_segment_.clear(); if (static_engine->getMaxBatchSize() < batch_size) { return {nullptr, nullptr}; } - return engine_map.at(static_engine->getMaxBatchSize()); + return engine_map_.at(static_engine->getMaxBatchSize()); } } else { - auto engine_it = engine_map.find(batch_size); - if (engine_it == engine_map.end() && - engine_map.size() < (size_t)max_cached_engines) { + auto engine_it = engine_map_.find(batch_size); + if (engine_it == engine_map_.end() && + engine_map_.size() < (size_t)max_cached_engines_) { auto builder_ = std::shared_ptr<nvinfer1::IBuilder>( nvinfer1::createInferBuilder(logger), Destroyer<nvinfer1::IBuilder>()); // reset the builder to ensure @@ -475,9 +476,9 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size, VLOG(1) << name() << " Constructing a new engine with batch size " << batch_size; builder_->setMaxBatchSize(batch_size); - if (precision_mode == tensorflow::tensorrt::convert::FP16MODE) { + if (precision_mode_ == tensorflow::tensorrt::convert::FP16MODE) { builder_->setHalf2Mode(true); - } else if (precision_mode == tensorflow::tensorrt::convert::INT8MODE) { + } else if (precision_mode_ == tensorflow::tensorrt::convert::INT8MODE) { builder_->setInt8Mode(true); builder_->setInt8Calibrator(calibrator_.get()); } @@ -488,9 +489,9 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size, shapes.emplace_back(ctx->input(i).shape()); } auto status = tensorflow::tensorrt::convert::ConvertSubgraphToEngine( - segment_graph_, builder_.get(), shapes, &engine, precision_mode); + segment_graph_, builder_.get(), shapes, &engine, precision_mode_); if (engine) { - engine_map[batch_size] = { + engine_map_[batch_size] = { std::shared_ptr<nvinfer1::ICudaEngine>( engine, Destroyer<nvinfer1::ICudaEngine>()), std::shared_ptr<nvinfer1::IExecutionContext>( @@ -500,11 +501,11 @@ TRTEngineOp::EngineCtxPair TRTEngineOp::get_engine(int batch_size, LOG(ERROR) << "Engine creation for batch size " << batch_size << " failed"; ctx->SetStatus(tensorflow::errors::Internal("Engine creation failed!")); - engine_map[batch_size] = {nullptr, nullptr}; + engine_map_[batch_size] = {nullptr, nullptr}; return {nullptr, nullptr}; } } - return engine_map.at(batch_size); + return engine_map_.at(batch_size); } } diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 5c9cd98cb3..1e6d7fbe93 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -54,24 +54,37 @@ class TRTEngineOp : public AsyncOpKernel { } }; + // Execute calibration + void ExecuteCalibration(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper); + + // Construct a function handle for executing native funcdef graph tensorflow::Status ConstructFunctionHandle(tensorflow::OpKernelContext* ctx); - void ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, AsyncHelper* ah); + + // Execute replaced native segment as function Op. + void ExecuteNativeSegment(tensorflow::OpKernelContext* ctx, + AsyncHelper* helper); + + // Allocate necessary resources for calibration tensorflow::Status AllocateCalibrationResources( tensorflow::OpKernelContext* ctx, tensorflow::tensorrt::TRTCalibrationResource** cr); // TODO(samikama): context should go to a resource manager! - // std::shared_ptr<nvinfer1::IExecutionContext> get_execution_context( - // int batch_size); typedef std::pair<std::shared_ptr<nvinfer1::ICudaEngine>, std::shared_ptr<nvinfer1::IExecutionContext>> EngineCtxPair; - EngineCtxPair get_engine(int batch_size, OpKernelContext* ctx, - bool ignore_dim_change = true); + EngineCtxPair GetEngine(int batch_size, OpKernelContext* ctx, + bool ignore_dim_change = true); + + // Return engine batch closest to input batch. + int GetEngineBatch(OpKernelContext* ctx); - std::unordered_map<int, EngineCtxPair> engine_map; + // map to keep engines and their execution context. + std::unordered_map<int, EngineCtxPair> engine_map_; std::vector<string> input_nodes_; std::vector<string> output_nodes_; + // keep device allocator for TRT std::unordered_map<string, std::shared_ptr<nvinfer1::IGpuAllocator>> allocators_; string serialized_segment_; @@ -80,12 +93,12 @@ class TRTEngineOp : public AsyncOpKernel { tensorflow::GraphDef segment_graph_; std::unordered_map<string, std::pair<void*, size_t>> device_buffers_; std::vector<tensorflow::PersistentTensor> dev_tensors_; - int precision_mode; - bool static_engine; - bool calibration_mode; - bool fixed_input_size; - std::vector<int> cached_engine_batches; - int max_cached_engines; + int precision_mode_; + bool static_engine_; + bool calibration_mode_; + bool fixed_input_size_; + std::vector<int> cached_engine_batches_; + int max_cached_engines_; tensorflow::int64 workspace_size_; tensorflow::mutex engine_mutex_; tensorflow::FunctionLibraryRuntime::Handle native_func_; diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index 5adffdc3d1..695394156c 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -47,13 +47,11 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) done_(false), calib_running_(false), batch_is_set_(false), - calibration_table(calib_data) {} + calibration_table_(calib_data) {} bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, - const cudaStream_t stream, - tensorflow::core::RefCounted* rc) { + const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - tensorflow::core::ScopedUnref SC(rc); while ((calib_running_ || batch_is_set_) && !done_) { // wait while calibration is running cond_.wait(lock); @@ -116,9 +114,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, } const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { - if (calibration_table.empty()) return nullptr; - length = calibration_table.size(); - return calibration_table.data(); + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); } void TRTInt8Calibrator::setDone() { @@ -129,8 +127,9 @@ void TRTInt8Calibrator::setDone() { void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, std::size_t length) { - calibration_table = string((const char*)ptr, length); - VLOG(1) << "Got calibration data for "<<engine_name_<<" @"<<ptr<<" length="<<length; + calibration_table_ = string((const char*)ptr, length); + VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr + << " length=" << length; } TRTInt8Calibrator::~TRTInt8Calibrator() { VLOG(1) << "Destroying calibrator for " << engine_name_; diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index eec9571418..6b59d52c70 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -47,12 +47,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool getBatch(void* bindings[], const char* names[], int num_bindings) override; bool setBatch(const std::unordered_map<string, void*>& data, - const cudaStream_t stream, - tensorflow::core::RefCounted* helper); + const cudaStream_t stream); void setDone(); const void* readCalibrationCache(std::size_t& length) override; void writeCalibrationCache(const void* ptr, std::size_t length) override; - const string& getCalibrationTableAsString(){return calibration_table;} + const string& getCalibrationTableAsString() { return calibration_table_; } ~TRTInt8Calibrator(); private: @@ -68,7 +67,7 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool calib_running_; bool batch_is_set_; string engine_name_; - string calibration_table; + string calibration_table_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 584d6baee5..022639dc01 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -47,17 +47,11 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); builder_->destroy(); - builder_ = nullptr; network_->destroy(); - network_ = nullptr; engine_->destroy(); - engine_ = nullptr; delete thr_; - thr_ = nullptr; delete logger_; - logger_ = nullptr; delete calibrator_; - calibrator_ = nullptr; } string DebugString() override { diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 8142872fca..9bf2a56f99 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -29,67 +29,11 @@ namespace tensorflow { namespace shape_inference { tensorflow::Status TRTEngineOpShapeInference(InferenceContext* context) { - tensorflow::tensorrt::Logger logger; - string serialized_engine; - if(true){ - for(int i=0;i<context->num_outputs();++i){ - context->set_output(i,context->UnknownShape()); - } - return Status::OK(); + for (int i = 0; i < context->num_outputs(); ++i) { + context->set_output(i, context->UnknownShape()); } - TF_RETURN_IF_ERROR(context->GetAttr("serialized_segment", &serialized_engine)); - nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger); - nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine( - serialized_engine.c_str(), serialized_engine.size(), - tensorrt::PluginFactoryTensorRT::GetInstance()); - - int num_batch = -1; - std::vector<::tensorflow::DataType> input_type; - TF_RETURN_IF_ERROR(context->GetAttr("InT", &input_type)); - for (size_t i = 0; i < context->num_inputs(); i++) { - // Check if input shape is legit - auto input_shape = context->input(i); - for (int j = 0; j < context->Rank(input_shape); j++) { - auto dim_handler = context->Dim(input_shape, j); - if (j == 0) { - if (i == 0) { - num_batch = context->Value(dim_handler); - } else if (num_batch != context->Value(dim_handler)) { - // TODO(jie): TensorRT engine requires consistent batch between inputs - // tensors. Segmenter should be aware of this. - LOG(FATAL) << "TensorRT engine requires consistent batch size"; - } - } - } - } - - // Arrange input here - std::vector<string> input_nodes; - TF_RETURN_IF_ERROR(context->GetAttr("input_nodes", &input_nodes)); - - // Arrange output here - std::vector<string> output_nodes; - //TF_RETURN_IF_ERROR(context->GetAttr("output_nodes", &output_nodes)); - for (size_t i = 0; i < output_nodes.size(); i++) { - int binding_index = trt_engine->getBindingIndex(output_nodes[i].c_str()); - ShapeHandle output_shape; - std::vector<DimensionHandle> dim_vec; - dim_vec.emplace_back(context->MakeDim(num_batch)); - if (binding_index != -1) { - auto dims = trt_engine->getBindingDimensions(binding_index); - for (int j = 0; j < dims.nbDims; j++) { - dim_vec.emplace_back(context->MakeDim(dims.d[j])); - } - } else { - LOG(FATAL) << "TensorRT engine cannot find binding: " << output_nodes[i]; - } - output_shape = context->MakeShape(dim_vec); - context->set_output(i, output_shape); - } - return Status::OK(); } - } // namespace shape_inference } // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i index 861d241afb..80bb14accf 100644 --- a/tensorflow/contrib/tensorrt/trt_conversion.i +++ b/tensorflow/contrib/tensorrt/trt_conversion.i @@ -61,7 +61,7 @@ PyObject* version_helper(version_struct* in) { if (!tuple) { if (!PyErr_Occurred()) { PyErr_SetString(PyExc_TypeError, - "Tuple creation from pair<string,string> failed!"); + "Tuple creation from version structure failed!"); } return NULL; } @@ -69,15 +69,15 @@ PyObject* version_helper(version_struct* in) { } /* Define converters for vector<int> */ template<> - bool _PyObjAs(PyObject *pyobj, int* dest) { - *dest=PyLong_AsLong(pyobj); - return true; - } + bool _PyObjAs(PyObject *pyobj, int* dest) { + *dest = PyLong_AsLong(pyobj); + return true; +} - template<> - PyObject *_PyObjFrom(const int& src) { - return PyLong_FromLong(src); - } +template<> + PyObject *_PyObjFrom(const int& src) { + return PyLong_FromLong(src); +} %} @@ -175,7 +175,8 @@ std::pair<string, string> trt_convert( #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } -std::pair<string, string> calib_convert(string graph_def_string +std::pair<string, string> calib_convert( + string graph_def_string // unfortunately we can't use TF_Status here since it // is in c/c_api and brings in a lot of other libraries // which in turn declare ops. These ops are included @@ -250,8 +251,7 @@ std::pair<string, string> trt_convert(string graph_def_string, int precision_mode, int minimum_segment_size, bool is_dyn_op, int max_cached_engines, - std::vector<int> cached_engine_batches - ); + std::vector<int> cached_engine_batches); version_struct get_linked_tensorrt_version(); version_struct get_loaded_tensorrt_version(); |