diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc | 20 |
1 file changed, 17 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc index 8a17eb02f1..646d62483f 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/contrib/tensorrt/kernels/trt_engine_op.h" #include <algorithm> + #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/contrib/tensorrt/convert/utils.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h" #include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" #include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -230,7 +232,7 @@ void TRTEngineOp::ExecuteCalibration(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); calib_res->calibrator_->setBatch(input_data, *stream); VLOG(2) << "Passed calibration data"; ExecuteNativeSegment(ctx, helper); @@ -316,6 +318,11 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, ctx->SetStatus(tensorflow::errors::InvalidArgument( "INT8 inputs are not supported!")); return; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = (void*)(input_tensor.flat<int32>().data()); + break; +#endif default: LOG(ERROR) << "Unknown TRT data type: " << int(dtype); ctx->SetStatus(tensorflow::errors::InvalidArgument( @@ -368,6 +375,12 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, ctx->SetStatus(tensorflow::errors::InvalidArgument( "INT8 outputs are not supported!")); return; +#if NV_TENSORRT_MAJOR > 3 + case nvinfer1::DataType::kINT32: + buffers[binding_index] = reinterpret_cast<void*>(output_tensor->flat<int32>().data()); + break; +#endif default: LOG(ERROR) << "Unknown TRT data type: " << static_cast<int>(dtype); ctx->SetStatus(tensorflow::errors::InvalidArgument( @@ -380,7 +393,7 @@ void TRTEngineOp::ComputeAsync(tensorflow::OpKernelContext* ctx, reinterpret_cast<const cudaStream_t*>(ctx->op_device_context() ->stream() ->implementation() - ->CudaStreamMemberHack())); + ->GpuStreamMemberHack())); // TODO(jie): trt enqueue does not return error auto& trt_execution_context_ptr = engine_ctx_pair.second; @@ -446,7 +459,8 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size, #endif TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), - serialized_segment_.size(), nullptr)); + serialized_segment_.size(), + PluginFactoryTensorRT::GetInstance())); auto raw_static_engine = static_engine.get(); const auto max_batch_size = raw_static_engine->getMaxBatchSize(); engine_map_[max_batch_size] = { |